diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8ddac19bf1b57bc418b5bf6343148c4f9f49efab..05c5e2f4cd77bda29fd5b033dd4eeee137a7aa85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -63,45 +63,16 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = HashJoin.rewriteKeyExpr(rightKeys) + .map(BindReferences.bindReference(_, right.output)) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } - /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ - private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width -= bits - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } - } - keyExpr :: Nil - } + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(buildKeys) @@ -247,3 +218,31 @@ trait HashJoin { } } } + +object HashJoin { + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. + */ + private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + assert(keys.nonEmpty) + // TODO: support BooleanType, DateType and TimestampType + if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) + || keys.map(_.dataType.defaultSize).sum > 8) { + return keys + } + + var keyExpr: Expression = if (keys.head.dataType != LongType) { + Cast(keys.head, LongType) + } else { + keys.head + } + keys.tail.foreach { e => + val bits = e.dataType.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 97adffa8ce1012f91df160278b5b3d42d0ddcc67..83db81ea3f1c268859424afdc66811cdf366dda1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,11 +21,13 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{LongType, ShortType} /** * Test various broadcast join operators. @@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { cases.foreach(assertBroadcastJoin) } } + + test("join key rewritten") { + val l = Literal(1L) + val i = Literal(2) + val s = Literal.create(3, ShortType) + val ss = Literal("hello") + + assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === + BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), + BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === + s :: s :: s :: s :: s :: Nil) + + assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) + } }