Skip to content
Snippets Groups Projects
Commit 1cce1a3b authored by 10129659's avatar 10129659 Committed by gatorsmile
Browse files

[SPARK-21603][SQL] The wholestage codegen will be much slower then that is...

[SPARK-21603][SQL] The wholestage codegen will be much slower then that is closed when the function is too long

## What changes were proposed in this pull request?
Close the whole stage codegen when the function lines is longer than the maxlines which will be setted by
spark.sql.codegen.MaxFunctionLength parameter, because when the function is too long , it will not get the JIT  optimizing.
A benchmark test result is 10x slower when the generated function is too long :

ignore("max function length of wholestagecodegen") {
    val N = 20 << 15

    val benchmark = new Benchmark("max function length of wholestagecodegen", N)
    def f(): Unit = sparkSession.range(N)
      .selectExpr(
        "id",
        "(id & 1023) as k1",
        "cast(id & 1023 as double) as k2",
        "cast(id & 1023 as int) as k3",
        "case when id > 100 and id <= 200 then 1 else 0 end as v1",
        "case when id > 200 and id <= 300 then 1 else 0 end as v2",
        "case when id > 300 and id <= 400 then 1 else 0 end as v3",
        "case when id > 400 and id <= 500 then 1 else 0 end as v4",
        "case when id > 500 and id <= 600 then 1 else 0 end as v5",
        "case when id > 600 and id <= 700 then 1 else 0 end as v6",
        "case when id > 700 and id <= 800 then 1 else 0 end as v7",
        "case when id > 800 and id <= 900 then 1 else 0 end as v8",
        "case when id > 900 and id <= 1000 then 1 else 0 end as v9",
        "case when id > 1000 and id <= 1100 then 1 else 0 end as v10",
        "case when id > 1100 and id <= 1200 then 1 else 0 end as v11",
        "case when id > 1200 and id <= 1300 then 1 else 0 end as v12",
        "case when id > 1300 and id <= 1400 then 1 else 0 end as v13",
        "case when id > 1400 and id <= 1500 then 1 else 0 end as v14",
        "case when id > 1500 and id <= 1600 then 1 else 0 end as v15",
        "case when id > 1600 and id <= 1700 then 1 else 0 end as v16",
        "case when id > 1700 and id <= 1800 then 1 else 0 end as v17",
        "case when id > 1800 and id <= 1900 then 1 else 0 end as v18")
      .groupBy("k1", "k2", "k3")
      .sum()
      .collect()

    benchmark.addCase(s"codegen = F") { iter =>
      sparkSession.conf.set("spark.sql.codegen.wholeStage", "false")
      f()
    }

    benchmark.addCase(s"codegen = T") { iter =>
      sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
      sparkSession.conf.set("spark.sql.codegen.MaxFunctionLength", "10000")
      f()
    }

    benchmark.run()

    /*
    Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1
    Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
    max function length of wholestagecodegen: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    codegen = F                                    443 /  507          1.5         676.0       1.0X
    codegen = T                                   3279 / 3283          0.2        5002.6       0.1X
     */
  }

## How was this patch tested?
Run the unit test

Author: 10129659 <chen.yanshan@zte.com.cn>

Closes #18810 from eatoncys/codegen.
parent adf005da
No related branches found
No related tags found
No related merge requests found
......@@ -89,6 +89,14 @@ object CodeFormatter {
}
new CodeAndComment(code.result().trim(), map)
}
def stripExtraNewLinesAndComments(input: String): String = {
val commentReg =
("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/
"""([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment
val codeWithoutComment = commentReg.replaceAllIn(input, "")
codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines
}
}
private class CodeFormatter {
......
......@@ -355,6 +355,20 @@ class CodegenContext {
*/
private val placeHolderToComments = new mutable.HashMap[String, String]
/**
* It will count the lines of every Java function generated by whole-stage codegen,
* if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction,
* it will return true.
*/
def isTooLongGeneratedFunction: Boolean = {
classFunctions.values.exists { _.values.exists {
code =>
val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code)
codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction
}
}
}
/**
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
......
......@@ -572,6 +572,16 @@ object SQLConf {
"disable logging or -1 to apply no limit.")
.createWithDefault(1000)
val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction")
.internal()
.doc("The maximum lines of a single Java function generated by whole-stage codegen. " +
"When the generated function exceeds this threshold, " +
"the whole-stage codegen is deactivated for this subtree of the current query plan. " +
"The default value 2667 is the max length of byte code JIT supported " +
"for a single function(8000) divided by 3.")
.intConf
.createWithDefault(2667)
val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
......@@ -1037,6 +1047,8 @@ class SQLConf extends Serializable with Logging {
def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION)
def tableRelationCacheSize: Int =
getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)
......
......@@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite {
assert(reducedCode.body === "/*project_c4*/")
}
test("removing extra new lines and comments") {
val code =
"""
|/*
| * multi
| * line
| * comments
| */
|
|public function() {
|/*comment*/
| /*comment_with_space*/
|code_body
|//comment
|code_body
| //comment_with_space
|
|code_body
|}
""".stripMargin
val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code)
assert(reducedCode ===
"""
|public function() {
|code_body
|code_body
|code_body
|}
""".stripMargin)
}
testCase("basic example") {
"""
|class A {
......
......@@ -370,6 +370,14 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
override def doExecute(): RDD[InternalRow] = {
val (ctx, cleanedSource) = doCodeGen()
if (ctx.isTooLongGeneratedFunction) {
logWarning("Found too long generated codes and JIT optimization might not work, " +
"Whole-stage codegen disabled for this plan, " +
"You can change the config spark.sql.codegen.MaxFunctionLength " +
"to adjust the function length limit:\n "
+ s"$treeString")
return child.execute()
}
// try to compile and fallback if it failed
try {
CodeGenerator.compile(cleanedSource)
......
......@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
......@@ -149,4 +150,60 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
assert(df.collect() === Array(Row(1), Row(2)))
}
}
def genGroupByCodeGenContext(caseNum: Int): CodegenContext = {
val caseExp = (1 to caseNum).map { i =>
s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i"
}.toList
val keyExp = List(
"id",
"(id & 1023) as k1",
"cast(id & 1023 as double) as k2",
"cast(id & 1023 as int) as k3")
val ds = spark.range(10)
.selectExpr(keyExp:::caseExp: _*)
.groupBy("k1", "k2", "k3")
.sum()
val plan = ds.queryExecution.executedPlan
val wholeStageCodeGenExec = plan.find(p => p match {
case wp: WholeStageCodegenExec => wp.child match {
case hp: HashAggregateExec if (hp.child.isInstanceOf[ProjectExec]) => true
case _ => false
}
case _ => false
})
assert(wholeStageCodeGenExec.isDefined)
wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1
}
test("SPARK-21603 check there is a too long generated function") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") {
val ctx = genGroupByCodeGenContext(30)
assert(ctx.isTooLongGeneratedFunction === true)
}
}
test("SPARK-21603 check there is not a too long generated function") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") {
val ctx = genGroupByCodeGenContext(1)
assert(ctx.isTooLongGeneratedFunction === false)
}
}
test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) {
val ctx = genGroupByCodeGenContext(30)
assert(ctx.isTooLongGeneratedFunction === false)
}
}
test("SPARK-21603 check there is a too long generated function when threshold is 0") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") {
val ctx = genGroupByCodeGenContext(1)
assert(ctx.isTooLongGeneratedFunction === true)
}
}
}
......@@ -301,6 +301,68 @@ class AggregateBenchmark extends BenchmarkBase {
*/
}
ignore("max function length of wholestagecodegen") {
val N = 20 << 15
val benchmark = new Benchmark("max function length of wholestagecodegen", N)
def f(): Unit = sparkSession.range(N)
.selectExpr(
"id",
"(id & 1023) as k1",
"cast(id & 1023 as double) as k2",
"cast(id & 1023 as int) as k3",
"case when id > 100 and id <= 200 then 1 else 0 end as v1",
"case when id > 200 and id <= 300 then 1 else 0 end as v2",
"case when id > 300 and id <= 400 then 1 else 0 end as v3",
"case when id > 400 and id <= 500 then 1 else 0 end as v4",
"case when id > 500 and id <= 600 then 1 else 0 end as v5",
"case when id > 600 and id <= 700 then 1 else 0 end as v6",
"case when id > 700 and id <= 800 then 1 else 0 end as v7",
"case when id > 800 and id <= 900 then 1 else 0 end as v8",
"case when id > 900 and id <= 1000 then 1 else 0 end as v9",
"case when id > 1000 and id <= 1100 then 1 else 0 end as v10",
"case when id > 1100 and id <= 1200 then 1 else 0 end as v11",
"case when id > 1200 and id <= 1300 then 1 else 0 end as v12",
"case when id > 1300 and id <= 1400 then 1 else 0 end as v13",
"case when id > 1400 and id <= 1500 then 1 else 0 end as v14",
"case when id > 1500 and id <= 1600 then 1 else 0 end as v15",
"case when id > 1600 and id <= 1700 then 1 else 0 end as v16",
"case when id > 1700 and id <= 1800 then 1 else 0 end as v17",
"case when id > 1800 and id <= 1900 then 1 else 0 end as v18")
.groupBy("k1", "k2", "k3")
.sum()
.collect()
benchmark.addCase(s"codegen = F") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "false")
f()
}
benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000")
f()
}
benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500")
f()
}
benchmark.run()
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1
Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
----------------------------------------------------------------------------------------------
codegen = F 462 / 533 1.4 704.4 1.0X
codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X
codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X
*/
}
ignore("cube") {
val N = 5 << 20
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment