diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6c0196c21a0d16d04a9818f30855b7066ba5d500..0cff21ca618b44abb6cc051e8dc96ccb062d606a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,7 +38,7 @@ import org.apache.spark.{SparkConf, SparkEnv} * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[joins] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation { def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by @@ -111,7 +111,7 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. -private[joins] object HashedRelation { +private[execution] object HashedRelation { def apply( input: Iterator[InternalRow], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala new file mode 100644 index 0000000000000000000000000000000000000000..b31c5a863832e466086233c26d90fa6f4030215b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} + +case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToSafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToSafe = FromUnsafeProjection(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToSafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala new file mode 100644 index 0000000000000000000000000000000000000000..de2f4e661ab440d3e4060737a9af7cac227ef220 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} + +case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToUnsafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToUnsafe = UnsafeProjection.create(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToUnsafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index 81dd37c7da733f7446dfa33f0545c9a53fe5d810..dd1113b6726cfedd41756a6e6a473d0185c2df1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { +case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var predicate: (InternalRow) => Boolean = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala new file mode 100644 index 0000000000000000000000000000000000000000..a3e68d6a7c3414bbae3a89677d08e923815e764e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. 
+ */ +case class HashJoinNode( + conf: SQLConf, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: LocalNode, + right: LocalNode) extends BinaryLocalNode(conf) { + + private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (left, leftKeys, right, rightKeys) + case BuildRight => (right, rightKeys, left, leftKeys) + } + + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ + private[this] var currentMatchPosition: Int = -1 + + private[this] var joinRow: JoinedRow = _ + private[this] var resultProjection: (InternalRow) => InternalRow = _ + + private[this] var hashed: HashedRelation = _ + private[this] var joinKeys: Projection = _ + + override def output: Seq[Attribute] = left.output ++ right.output + + private[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(schema)) + } + + private[this] def buildSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildNode.output) + } else { + newMutableProjection(buildKeys, buildNode.output)() + } + } + + private[this] def streamSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(streamedKeys, streamedNode.output) + } else { + newMutableProjection(streamedKeys, streamedNode.output)() + } + } + + override def open(): Unit = { + buildNode.open() + hashed = HashedRelation.apply( + new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator) + streamedNode.open() + joinRow = new JoinedRow + resultProjection = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + joinKeys = streamSideKeyGenerator + } + + override def next(): Boolean = { + currentMatchPosition += 1 + if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { + fetchNextMatch() + } else { + true + } + } + + /** + * Populate `currentHashMatches` with build-side rows matching the next streamed row. + * @return whether matches are found such that subsequent calls to `fetch` are valid. 
+ */ + private def fetchNextMatch(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamedNode.next()) { + currentStreamedRow = streamedNode.fetch() + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashed.get(key) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + + override def fetch(): InternalRow = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + resultProjection(ret) + } + + override def close(): Unit = { + left.close() + right.close() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala index fffc52abf6dd52b37442c8f333d90d91abab9df0..401b10a5ed307a13a9eb3105c2ea942fa0aa1772 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode { +case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { private[this] var count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 1c4469acbf264282b27d454a7a3286221005e7a1..c4f8ae304db39ba3633f8fe24fe848f14acf8e06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Row +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. 
*/ -abstract class LocalNode extends TreeNode[LocalNode] { +abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { + + protected val codegenEnabled: Boolean = conf.codegenEnabled + + protected val unsafeEnabled: Boolean = conf.unsafeEnabled + + lazy val schema: StructType = StructType.fromAttributes(output) + + private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") def output: Seq[Attribute] @@ -73,17 +85,78 @@ abstract class LocalNode extends TreeNode[LocalNode] { } result } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } -abstract class LeafLocalNode extends LocalNode { +abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { override def children: Seq[LocalNode] = Seq.empty } -abstract class UnaryLocalNode extends LocalNode { +abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { def child: LocalNode override def children: Seq[LocalNode] = Seq(child) } + +abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { + + def left: LocalNode + + def right: LocalNode + + override def children: Seq[LocalNode] = Seq(left, right) +} + +/** + * A thin wrapper around a [[LocalNode]] that provides an `Iterator` interface.
+ */ +private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { + private var nextRow: InternalRow = _ + + override def hasNext: Boolean = { + if (nextRow == null) { + val res = localNode.next() + if (res) { + nextRow = localNode.fetch() + } + res + } else { + true + } + } + + override def next(): InternalRow = { + if (hasNext) { + val res = nextRow + nextRow = null + res + } else { + throw new NoSuchElementException + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index 9b8a4fe4930261d16790b0cfc53c1b63d98e3269..11529d6dd9b837febbdb0b29e1b5347faf3ef3bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} -case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { +case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var project: UnsafeProjection = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index 242cb66e07b7fb84d0562f05dc9e29cb02a97c4e..b8467f6ae58e08dbe14a6f56240aa88dfd0bcdb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute /** * An operator that scans some local data collection in the form of Scala Seq. 
*/ -case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { +case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) + extends LeafLocalNode(conf) { private[this] var iterator: Iterator[InternalRow] = _ private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala index ba4aa7671aebd012cea8fda8422baa8141f623bb..0f2b8303e7372281ac18e790b573de910724ad2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class UnionNode(children: Seq[LocalNode]) extends LocalNode { +case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { override def output: Seq[Attribute] = children.head.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index 07209f3779248ef88dd3d751e87cfd1478f74562..a12670e347c25a556d3d4563748135b91342f74b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (testData.col("key") % 2) === 0 checkAnswer( testData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), testData.filter(condition).collect() ) } @@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (emptyTestData.col("key") % 2) === 0 checkAnswer( emptyTestData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), emptyTestData.filter(condition).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..43b6f06aead88964f3962ba48a5b8883cef12c34 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,130 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.execution.joins + +class HashJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + + def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { + test(s"$suiteName: inner join with one match per row") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(upperCaseData.col("N").expr), + Seq(lowerCaseData.col("n").expr), + joins.BuildLeft, + node1, + node2) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N").collect() + ) + } + } + + test(s"$suiteName: inner join with multiple matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + x.join(y).where($"x.a" === $"y.a").collect() + ) + } + } + + test(s"$suiteName: inner join, no matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + Nil + ) + } + } + + test(s"$suiteName: big inner join, 4 matches per row") { + withSQLConf(confPairs: _*) { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") + + checkAnswer2( + bigDataX, + bigDataY, + wrapForUnsafe( + (node1, node2) => + HashJoinNode( + conf, + Seq(bigDataX.col("key").expr), + Seq(bigDataY.col("key").expr), + joins.BuildLeft, + node1, + node2) + ), + bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) + } + } + } + + joinSuite( + "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 523c02f4a6014b97f982f10d9131bf43893d89e9..3b183902007e47ea7456d879e439d176da7a54fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("basic") { checkAnswer( testData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), testData.limit(10).collect() ) } @@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("empty") { checkAnswer( emptyTestData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), emptyTestData.limit(10).collect() ) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 95f06081bd0a8479a3d0c445f39f601e055408a2..b95d4ea7f8f2af275fe67ffa17aefce0b90dd908 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class LocalNodeTest extends SparkFunSuite { +class LocalNodeTest extends SparkFunSuite with SharedSQLContext { + + def conf: SQLConf = sqlContext.conf /** * Runs the LocalNode and makes sure the answer matches the expected result. @@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite { protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { new SeqScanNode( + conf, df.queryExecution.sparkPlan.output, df.queryExecution.toRdd.map(_.copy()).collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index ffcf092e2c66a6ad13e2ffc6680e065d636287d3..38e0a230c46d8db0d0c3c9e9076b36f76a6d2567 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( testData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), testData.select("value", "key").collect() ) } @@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( emptyTestData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), emptyTestData.select("value", "key").collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index 34670287c3e1d270d704f0d00f18849ca2122c41..eedd7320900f90faa6ff51d26fd5ac7e5074cf68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( testData, testData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), testData.unionAll(testData).collect() ) } @@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( emptyTestData, emptyTestData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), emptyTestData.unionAll(emptyTestData).collect() ) } @@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { emptyTestData, emptyTestData, testData, emptyTestData) doCheckAnswer( dfs, - nodes => UnionNode(nodes), + nodes => UnionNode(conf, nodes), dfs.reduce(_.unionAll(_)).collect() 
) }
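
Note for reviewers: the sketch below is not part of the patch. It shows how the new local operators are intended to compose and be drained through the open()/next()/fetch()/close() contract described on LocalNode. The SQLConf handle, the attribute, and the input rows are illustrative assumptions (the test suites obtain conf from sqlContext.conf).

// Illustrative sketch only, not part of this patch.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.execution.local.{FilterNode, LimitNode, SeqScanNode}
import org.apache.spark.sql.types.IntegerType

// Assumed inputs: a SQLConf (e.g. sqlContext.conf), one attribute, and some local rows.
val key = AttributeReference("key", IntegerType)()
val rows = (1 to 100).map(i => InternalRow(i))

val scan = SeqScanNode(conf, Seq(key), rows)                         // leaf: scans a local Seq[InternalRow]
val filtered = FilterNode(conf, GreaterThan(key, Literal(90)), scan) // unary: evaluates the predicate per row
val limited = LimitNode(conf, 5, filtered)                           // root of the local plan

limited.open()                                 // must be called before consuming
while (limited.next()) {
  val row: InternalRow = limited.fetch()       // valid until the next call to next()
  // consume or copy `row` here
}
limited.close()                                // must be called after consuming

HashJoinNode follows the same contract internally: during open() it drains its build side through the package-private LocalNodeIterator into a HashedRelation, then streams the other side one row at a time via next()/fetch().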