diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a309ee35ee58216bebd705d2b70f45c7f8dee70c..a6ea0cc0a83a8c907ab7cc6c1e97263deb2dc058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -928,12 +928,17 @@ class Analyzer( // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { + case n: Nondeterministic => n + } + leafNondeterministic.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne } - new TreeNodeRef(e) -> ne }.toMap val newPlan = p.transformExpressions { case e => nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 03e36c7871bcfa316b7303f8c039e55a596dbead..8fc182607ce682ccab94af6d84cbd17187a0d476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -201,11 +201,9 @@ trait Nondeterministic extends Expression { private[this] var initialized = false - final def initialize(): Unit = { - if (!initialized) { - initInternal() - initialized = true - } + final def setInitialValues(): Unit = { + initInternal() + initialized = true } protected def initInternal(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 27d6ff587ab716173fee6f092ad922a368ee59fc..b3beb7e28f208af2f873e13413bab5a4c944f091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -32,7 +32,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) @@ -63,7 +63,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5bfe1cad24a3ebb40291d9d2193a79f155b9ca12..ab7d3afce8f2ec4a211012da75a8aef3eae10ec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,7 +31,7 @@ object InterpretedPredicate { def create(expression: Expression): (InternalRow => Boolean) = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ed645b618dc9ba97d358d6562d923e4b7e4ba1e0..4589facb49b766973528c2854c0193012ba30fbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -153,7 +153,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(4).dataType == DoubleType) } - test("pull out nondeterministic expressions from unary LogicalPlan") { + test("pull out nondeterministic expressions from RepartitionByExpression") { val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) val projected = Alias(Rand(33), "_nondeterministic")() val expected = @@ -162,4 +162,14 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("pull out nondeterministic expressions from Sort") { + val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false, + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 0c8611d5ddefae69f32f49933540e164bb54500d..3c05e5c3b833c27b603c7e9714dd5ddda334e38e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -65,7 +65,7 @@ trait ExpressionEvalHelper { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } expression.eval(inputRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3151e071b19eabd913fd36c8265446fd6015f4fc..97beae2f85c50721850ec2a5099e2d42ee039de4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,33 +33,28 @@ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - def sqlContext: SQLContext = ctx + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - testData.select('nonExistentName) - - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) + } } test("dataframe toString") { @@ -77,21 +72,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() - val badPlan = testData.select('badColumn) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() - assert(badPlan.toString contains badPlan.queryExecution.toString, - "toString on bad query plans should include the query execution but was:\n" + - badPlan.toString) + val badPlan = testData.select('badColumn) - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + } } test("access complex data") { @@ -107,8 +99,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -344,7 +336,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -425,7 +417,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -519,7 +511,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -609,21 +601,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } - test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + test("SPARK-6899: type should match when using codegen") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -635,14 +623,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -662,7 +650,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = sqlContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -710,49 +698,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = sqlContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = sqlContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = sqlContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = sqlContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = ctx.range(10).select("id") + val res10 = sqlContext.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = ctx.range(-1).select("id") + val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) } @@ -819,13 +807,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = ctx.read.json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -845,11 +833,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(ctx, OneRowRelation).registerTempTable("one_row") + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) + } + + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } }