Skip to content
Snippets Groups Projects
Commit a814eeac authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Herman van Hovell
Browse files

[SPARK-18125][SQL] Fix a compilation error in codegen due to splitExpression

## What changes were proposed in this pull request?

As reported in the jira, sometimes the generated java code in codegen will cause compilation error.

Code snippet to test it:

    case class Route(src: String, dest: String, cost: Int)
    case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])

    val ds = sc.parallelize(Array(
      Route("a", "b", 1),
      Route("a", "b", 2),
      Route("a", "c", 2),
      Route("a", "d", 10),
      Route("b", "a", 1),
      Route("b", "a", 5),
      Route("b", "c", 6))
    ).toDF.as[Route]

    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
      .groupByKey(r => (r.src, r.dest))
      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
      }.map(_._2)

The problem here is, in `ReferenceToExpressions` we evaluate the children vars to local variables. Then the result expression is evaluated to use those children variables. In the above case, the result expression code is too long and will be split by `CodegenContext.splitExpression`. So those local variables cannot be accessed and cause compilation error.

## How was this patch tested?

Jenkins tests.

Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #15693 from viirya/fix-codege-compilation-error.
parent 57626a55
No related branches found
No related tags found
No related merge requests found
......@@ -63,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childrenGen = children.map(_.genCode(ctx))
val childrenVars = childrenGen.zip(children).map {
case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
}
val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
case (childGen, child) =>
// SPARK-18125: The children vars are local variables. If the result expression uses
// splitExpression, those variables cannot be accessed so compilation fails.
// To fix it, we use class variables to hold those local variables.
val classChildVarName = ctx.freshName("classChildVar")
val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
ctx.addMutableState("boolean", classChildVarIsNull, "")
val classChildVar =
LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
s"${classChildVar.isNull} = ${childGen.isNull};"
(classChildVar, initCode)
}.unzip
val resultGen = result.transform {
case b: BoundReference => childrenVars(b.ordinal)
case b: BoundReference => classChildrenVars(b.ordinal)
}.genCode(ctx)
ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
isNull = resultGen.isNull, value = resultGen.value)
ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
}
}
......@@ -923,6 +923,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
.groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
}
test("SPARK-18125: Spark generated code causes CompileException") {
val data = Array(
Route("a", "b", 1),
Route("a", "b", 2),
Route("a", "c", 2),
Route("a", "d", 10),
Route("b", "a", 1),
Route("b", "a", 5),
Route("b", "c", 6))
val ds = sparkContext.parallelize(data).toDF.as[Route]
val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
.groupByKey(r => (r.src, r.dest))
.reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
}.map(_._2)
val expected = Seq(
GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
)
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
x.toString.compareTo(y.toString)
}
}
checkDatasetUnorderly(grped, expected: _*)
}
test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
val resultValue = 12345
val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
......@@ -1071,3 +1105,6 @@ object DatasetTransform {
ds.map(_ + 1)
}
}
case class Route(src: String, dest: String, cost: Int)
case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment