Commit d88abb7e authored by zsxwing, committed by Andrew Or

[SPARK-9990] [SQL] Create local hash join operator

This PR includes the following changes:
- Add SQLConf to LocalNode
- Add HashJoinNode
- Add ConvertToUnsafeNode and ConvertToSafeNode.scala to test unsafe hash join.

Author: zsxwing <zsxwing@gmail.com>

Closes #8535 from zsxwing/SPARK-9990.
parent a5ef2d06
Showing 455 additions and 24 deletions
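
Before the per-file diffs, here is a rough sketch of how the reworked constructors fit together once a SQLConf is threaded through every local node. This is only an illustration under the API introduced in this commit; `conf`, `attrs`, `rows`, and `cond` are hypothetical placeholders, and `collect()` is the pre-existing LocalNode helper that drains a node.

// Illustrative wiring only: every node now receives the SQLConf as its first argument.
val scan = SeqScanNode(conf, attrs, rows)   // leaf node scanning an in-memory Seq[InternalRow]
val kept = FilterNode(conf, cond, scan)     // FilterNode(conf, condition, child)
val firstTen = LimitNode(conf, 10, kept)    // LimitNode(conf, limit, child)
val result = firstTen.collect()             // opens the node and drains it via next()/fetch()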
HashedRelation.scala
@@ -38,7 +38,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
  * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
  * object.
  */
-private[joins] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation {
   def get(key: InternalRow): Seq[InternalRow]

   // This is a helper method to implement Externalizable, and is used by
@@ -111,7 +111,7 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
 // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-private[joins] object HashedRelation {
+private[execution] object HashedRelation {
   def apply(
       input: Iterator[InternalRow],
...

New file: ConvertToSafeNode.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection}

case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {

  override def output: Seq[Attribute] = child.output

  private[this] var convertToSafe: Projection = _

  override def open(): Unit = {
    child.open()
    convertToSafe = FromUnsafeProjection(child.schema)
  }

  override def next(): Boolean = child.next()

  override def fetch(): InternalRow = convertToSafe(child.fetch())

  override def close(): Unit = child.close()
}

New file: ConvertToUnsafeNode.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection}

case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {

  override def output: Seq[Attribute] = child.output

  private[this] var convertToUnsafe: Projection = _

  override def open(): Unit = {
    child.open()
    convertToUnsafe = UnsafeProjection.create(child.schema)
  }

  override def next(): Boolean = child.next()

  override def fetch(): InternalRow = convertToUnsafe(child.fetch())

  override def close(): Unit = child.close()
}
FilterNode.scala
@@ -17,12 +17,14 @@
 package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

-case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode {
+case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
+  extends UnaryLocalNode(conf) {

   private[this] var predicate: (InternalRow) => Boolean = _
...

New file: HashJoinNode.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
 */
case class HashJoinNode(
    conf: SQLConf,
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    buildSide: BuildSide,
    left: LocalNode,
    right: LocalNode) extends BinaryLocalNode(conf) {

  private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
    case BuildLeft => (left, leftKeys, right, rightKeys)
    case BuildRight => (right, rightKeys, left, leftKeys)
  }

  private[this] var currentStreamedRow: InternalRow = _
  private[this] var currentHashMatches: Seq[InternalRow] = _
  private[this] var currentMatchPosition: Int = -1

  private[this] var joinRow: JoinedRow = _
  private[this] var resultProjection: (InternalRow) => InternalRow = _

  private[this] var hashed: HashedRelation = _
  private[this] var joinKeys: Projection = _

  override def output: Seq[Attribute] = left.output ++ right.output

  private[this] def isUnsafeMode: Boolean = {
    (codegenEnabled && unsafeEnabled
      && UnsafeProjection.canSupport(buildKeys)
      && UnsafeProjection.canSupport(schema))
  }

  private[this] def buildSideKeyGenerator: Projection = {
    if (isUnsafeMode) {
      UnsafeProjection.create(buildKeys, buildNode.output)
    } else {
      newMutableProjection(buildKeys, buildNode.output)()
    }
  }

  private[this] def streamSideKeyGenerator: Projection = {
    if (isUnsafeMode) {
      UnsafeProjection.create(streamedKeys, streamedNode.output)
    } else {
      newMutableProjection(streamedKeys, streamedNode.output)()
    }
  }

  override def open(): Unit = {
    buildNode.open()
    hashed = HashedRelation.apply(
      new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator)
    streamedNode.open()
    joinRow = new JoinedRow
    resultProjection = {
      if (isUnsafeMode) {
        UnsafeProjection.create(schema)
      } else {
        identity[InternalRow]
      }
    }
    joinKeys = streamSideKeyGenerator
  }

  override def next(): Boolean = {
    currentMatchPosition += 1
    if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) {
      fetchNextMatch()
    } else {
      true
    }
  }

  /**
   * Populate `currentHashMatches` with build-side rows matching the next streamed row.
   * @return whether matches are found such that subsequent calls to `fetch` are valid.
   */
  private def fetchNextMatch(): Boolean = {
    currentHashMatches = null
    currentMatchPosition = -1

    while (currentHashMatches == null && streamedNode.next()) {
      currentStreamedRow = streamedNode.fetch()
      val key = joinKeys(currentStreamedRow)
      if (!key.anyNull) {
        currentHashMatches = hashed.get(key)
      }
    }

    if (currentHashMatches == null) {
      false
    } else {
      currentMatchPosition = 0
      true
    }
  }

  override def fetch(): InternalRow = {
    val ret = buildSide match {
      case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
      case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
    }
    resultProjection(ret)
  }

  override def close(): Unit = {
    left.close()
    right.close()
  }
}
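
For orientation, the operator above is exercised in HashJoinNodeSuite further down in this commit roughly as sketched below. This is a condensed illustration, not code from the commit: `leftNode`, `rightNode`, `leftKeyExpr`, and `rightKeyExpr` are hypothetical placeholders, and the unsafe-row wrapping mirrors the suite's wrapForUnsafe helper.

// The build side (here the left input) is hashed; the other side is streamed.
val join = HashJoinNode(
  conf,
  Seq(leftKeyExpr),                        // join keys evaluated against the left input
  Seq(rightKeyExpr),                       // join keys evaluated against the right input
  joins.BuildLeft,
  ConvertToUnsafeNode(conf, leftNode),     // feed UnsafeRows so isUnsafeMode can take effect
  ConvertToUnsafeNode(conf, rightNode))
val result = ConvertToSafeNode(conf, join) // convert back to safe rows before consumption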
LimitNode.scala
@@ -17,11 +17,12 @@
 package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute

-case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) {

   private[this] var count = 0
...
LocalNode.scala
@@ -17,9 +17,13 @@
 package org.apache.spark.sql.execution.local

-import org.apache.spark.sql.Row
+import scala.util.control.NonFatal
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.StructType
@@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType
  * Before consuming the iterator, open function must be called.
  * After consuming the iterator, close function must be called.
  */
-abstract class LocalNode extends TreeNode[LocalNode] {
+abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+
+  protected val codegenEnabled: Boolean = conf.codegenEnabled
+
+  protected val unsafeEnabled: Boolean = conf.unsafeEnabled
+
+  lazy val schema: StructType = StructType.fromAttributes(output)
+
+  private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")

   def output: Seq[Attribute]
@@ -73,17 +85,78 @@ abstract class LocalNode extends TreeNode[LocalNode] {
     }
     result
   }
+
+  protected def newMutableProjection(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): () => MutableProjection = {
+    log.debug(
+      s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      try {
+        GenerateMutableProjection.generate(expressions, inputSchema)
+      } catch {
+        case NonFatal(e) =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate mutable projection, fallback to interpreted", e)
+            () => new InterpretedMutableProjection(expressions, inputSchema)
+          }
+      }
+    } else {
+      () => new InterpretedMutableProjection(expressions, inputSchema)
+    }
+  }
 }

-abstract class LeafLocalNode extends LocalNode {
+abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) {
   override def children: Seq[LocalNode] = Seq.empty
 }

-abstract class UnaryLocalNode extends LocalNode {
+abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
+
   def child: LocalNode

   override def children: Seq[LocalNode] = Seq(child)
 }
+
+abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
+
+  def left: LocalNode
+
+  def right: LocalNode
+
+  override def children: Seq[LocalNode] = Seq(left, right)
+}
+
+/**
+ * A thin wrapper around a [[LocalNode]] that provides an `Iterator` interface.
+ */
+private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
+  private var nextRow: InternalRow = _
+
+  override def hasNext: Boolean = {
+    if (nextRow == null) {
+      val res = localNode.next()
+      if (res) {
+        nextRow = localNode.fetch()
+      }
+      res
+    } else {
+      true
+    }
+  }
+
+  override def next(): InternalRow = {
+    if (hasNext) {
+      val res = nextRow
+      nextRow = null
+      res
+    } else {
+      throw new NoSuchElementException
+    }
+  }
+}
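
As a quick illustration of the contract spelled out above (open before consuming, close afterwards), a caller inside this package could drive any LocalNode either directly or through the new LocalNodeIterator, which is how HashJoinNode feeds its build side into HashedRelation. This is only a sketch; `node` stands in for an arbitrary LocalNode instance.

// Direct use of the open()/next()/fetch()/close() protocol.
node.open()
while (node.next()) {
  val row: InternalRow = node.fetch()   // current row; advance with another call to next()
  // ... consume row ...
}
node.close()

// Or adapt the node to a plain Scala iterator (LocalNodeIterator is private[local]).
val it: Iterator[InternalRow] = new LocalNodeIterator(node)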
ProjectNode.scala
@@ -17,11 +17,13 @@
 package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression}

-case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode {
+case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode)
+  extends UnaryLocalNode(conf) {

   private[this] var project: UnsafeProjection = _
...
SeqScanNode.scala
@@ -17,13 +17,15 @@
 package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute

 /**
  * An operator that scans some local data collection in the form of Scala Seq.
  */
-case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode {
+case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow])
+  extends LeafLocalNode(conf) {

   private[this] var iterator: Iterator[InternalRow] = _
   private[this] var currentRow: InternalRow = _
...
UnionNode.scala
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute

-case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) {

   override def output: Seq[Attribute] = children.head.output
...
FilterNodeSuite.scala
@@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
     val condition = (testData.col("key") % 2) === 0
     checkAnswer(
       testData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       testData.filter(condition).collect()
     )
   }
@@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
     val condition = (emptyTestData.col("key") % 2) === 0
     checkAnswer(
       emptyTestData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       emptyTestData.filter(condition).collect()
     )
   }
...

New file: HashJoinNodeSuite.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.execution.joins

class HashJoinNodeSuite extends LocalNodeTest {

  import testImplicits._

  private def wrapForUnsafe(
      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
    if (conf.unsafeEnabled) {
      (left: LocalNode, right: LocalNode) => {
        val _left = ConvertToUnsafeNode(conf, left)
        val _right = ConvertToUnsafeNode(conf, right)
        val r = f(_left, _right)
        ConvertToSafeNode(conf, r)
      }
    } else {
      f
    }
  }

  def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
    test(s"$suiteName: inner join with one match per row") {
      withSQLConf(confPairs: _*) {
        checkAnswer2(
          upperCaseData,
          lowerCaseData,
          wrapForUnsafe(
            (node1, node2) => HashJoinNode(
              conf,
              Seq(upperCaseData.col("N").expr),
              Seq(lowerCaseData.col("n").expr),
              joins.BuildLeft,
              node1,
              node2)
          ),
          upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
        )
      }
    }

    test(s"$suiteName: inner join with multiple matches") {
      withSQLConf(confPairs: _*) {
        val x = testData2.where($"a" === 1).as("x")
        val y = testData2.where($"a" === 1).as("y")
        checkAnswer2(
          x,
          y,
          wrapForUnsafe(
            (node1, node2) => HashJoinNode(
              conf,
              Seq(x.col("a").expr),
              Seq(y.col("a").expr),
              joins.BuildLeft,
              node1,
              node2)
          ),
          x.join(y).where($"x.a" === $"y.a").collect()
        )
      }
    }

    test(s"$suiteName: inner join, no matches") {
      withSQLConf(confPairs: _*) {
        val x = testData2.where($"a" === 1).as("x")
        val y = testData2.where($"a" === 2).as("y")
        checkAnswer2(
          x,
          y,
          wrapForUnsafe(
            (node1, node2) => HashJoinNode(
              conf,
              Seq(x.col("a").expr),
              Seq(y.col("a").expr),
              joins.BuildLeft,
              node1,
              node2)
          ),
          Nil
        )
      }
    }

    test(s"$suiteName: big inner join, 4 matches per row") {
      withSQLConf(confPairs: _*) {
        val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
        val bigDataX = bigData.as("x")
        val bigDataY = bigData.as("y")

        checkAnswer2(
          bigDataX,
          bigDataY,
          wrapForUnsafe(
            (node1, node2) =>
              HashJoinNode(
                conf,
                Seq(bigDataX.col("key").expr),
                Seq(bigDataY.col("key").expr),
                joins.BuildLeft,
                node1,
                node2)
          ),
          bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
      }
    }
  }

  joinSuite(
    "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
  joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
}
LimitNodeSuite.scala
@@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
   test("basic") {
     checkAnswer(
       testData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       testData.limit(10).collect()
     )
   }
@@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
   test("empty") {
     checkAnswer(
       emptyTestData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       emptyTestData.limit(10).collect()
     )
   }
...
LocalNodeTest.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local
 import scala.util.control.NonFatal

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

-class LocalNodeTest extends SparkFunSuite {
+class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
+
+  def conf: SQLConf = sqlContext.conf

   /**
    * Runs the LocalNode and makes sure the answer matches the expected result.
@@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite {
   protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
     new SeqScanNode(
+      conf,
       df.queryExecution.sparkPlan.output,
       df.queryExecution.toRdd.map(_.copy()).collect())
   }
...
ProjectNodeSuite.scala
@@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       testData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       testData.select("value", "key").collect()
     )
   }
@@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       emptyTestData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       emptyTestData.select("value", "key").collect()
     )
   }
...
UnionNodeSuite.scala
@@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
     checkAnswer2(
       testData,
       testData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       testData.unionAll(testData).collect()
     )
   }
@@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
     checkAnswer2(
       emptyTestData,
       emptyTestData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       emptyTestData.unionAll(emptyTestData).collect()
     )
   }
@@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
       emptyTestData, emptyTestData, testData, emptyTestData)
     doCheckAnswer(
       dfs,
-      nodes => UnionNode(nodes),
+      nodes => UnionNode(conf, nodes),
       dfs.reduce(_.unionAll(_)).collect()
     )
   }
...