From 3c0d2365d57fc49ac9bf0d7cc9bd2ef633fb5fb6 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Sat, 16 Jan 2016 10:29:27 -0800
Subject: [PATCH] [SPARK-12796] [SQL] Whole stage codegen

This is the initial work for whole stage codegen; it supports
Projection/Filter/Range, and we will continue this work to support more
physical operators. A micro benchmark shows that a query with range, filter
and projection could be 3X faster than before.

It's turned on by default. For a tree that has at least two chained plans,
a WholeStageCodegen will be inserted into it. For example, the following plan

```
Limit 10
+- Project [(id#5L + 1) AS (id + 1)#6L]
   +- Filter ((id#5L & 1) = 1)
      +- Range 0, 1, 4, 10, [id#5L]
```

will be translated into

```
Limit 10
+- WholeStageCodegen
   +- Project [(id#1L + 1) AS (id + 1)#2L]
      +- Filter ((id#1L & 1) = 1)
         +- Range 0, 1, 4, 10, [id#1L]
```

Here is the call graph to generate Java source for A and B (A supports
codegen, but B does not):

```
 *   WholeStageCodegen       Plan A          FakeInput        Plan B
 * =========================================================================
 *
 * -> execute()
 *     |
 *  doExecute() -------->   produce()
 *                             |
 *                          doProduce()  -------> produce()
 *                                                   |
 *                                                doProduce() ---> execute()
 *                                                                    |
 *                                                                 consume()
 *                          doConsume() ------------|
 *                             |
 *  doConsume()  <-----  consume()
```

A SparkPlan that supports codegen needs to implement doProduce() and
doConsume():

```
def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
```

Author: Davies Liu <davies@databricks.com>

Closes #10735 from davies/whole2.
---
 .../catalyst/expressions/BoundAttribute.scala | 6 +-
 .../expressions/codegen/CodeGenerator.scala | 76 +++--
 .../expressions/codegen/CodegenFallback.scala | 8 +-
 .../codegen/GenerateMutableProjection.scala | 8 +-
 .../codegen/GenerateOrdering.scala | 8 +-
 .../codegen/GeneratePredicate.scala | 8 +-
 .../codegen/GenerateSafeProjection.scala | 8 +-
 .../codegen/GenerateUnsafeProjection.scala | 10 +-
 .../codegen/GenerateUnsafeRowJoiner.scala | 2 +-
 .../expressions/conditionalExpressions.scala | 1 +
 .../spark/sql/catalyst/expressions/misc.scala | 3 +-
 .../spark/sql/catalyst/trees/TreeNode.scala | 2 +-
 .../org/apache/spark/sql/DataFrame.scala | 3 -
 .../scala/org/apache/spark/sql/SQLConf.scala | 9 +
 .../org/apache/spark/sql/SQLContext.scala | 3 +-
 .../sql/execution/BufferedRowIterator.java | 64 ++++
 .../spark/sql/execution/SparkPlan.scala | 1 -
 .../sql/execution/WholeStageCodegen.scala | 299 ++++++++++++++++++
 .../spark/sql/execution/basicOperators.scala | 114 ++++++-
 .../columnar/GenerateColumnAccessor.scala | 6 +-
 .../apache/spark/sql/CachedTableSuite.scala | 6 +-
 .../spark/sql/ColumnExpressionSuite.scala | 2 +-
 .../apache/spark/sql/DataFrameJoinSuite.scala | 6 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala | 6 +-
 .../BenchmarkWholeStageCodegen.scala | 60 ++++
 .../spark/sql/execution/PlannerSuite.scala | 6 +-
 .../execution/WholeStageCodegenSuite.scala | 38 +++
 .../columnar/InMemoryColumnarQuerySuite.scala | 6 +-
 .../columnar/PartitionBatchPruningSuite.scala | 2 +-
 .../execution/joins/BroadcastJoinSuite.scala | 2 +-
 .../spark/sql/hive/CachedTableSuite.scala | 6 +-
 .../hive/execution/HiveComparisonTest.scala | 2 +-
 .../execution/HiveTypeCoercionSuite.scala | 2 +-
 .../sql/hive/execution/PruningSuite.scala | 2 +-
 .../apache/spark/sql/hive/parquetSuites.scala | 8 +-
 .../ParquetHadoopFsRelationSuite.scala | 2 +-
 .../SimpleTextHadoopFsRelationSuite.scala | 6 +-
 37 files changed, 694 insertions(+), 107
deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index dda822d054..4727ff1885 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - if (nullable) { + if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { + ev.isNull = ctx.currentVars(ordinal).isNull + ev.value = ctx.currentVars(ordinal).value + "" + } else if (nullable) { s""" boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3a39a0e75..683029ff14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,12 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Holding a list of generated columns as input of current operator, will be used by + * BoundReference to generate code. + */ + var currentVars: Seq[ExprCode] = null + /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. @@ -77,6 +83,16 @@ class CodegenContext { mutableStates += ((javaType, variableName, initCode)) } + def declareMutableStates(): String = { + mutableStates.map { case (javaType, variableName, _) => + s"private $javaType $variableName;" + }.mkString("\n") + } + + def initMutableStates(): String = { + mutableStates.map(_._3).mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ @@ -111,6 +127,10 @@ class CodegenContext { // The collection of sub-exression result resetting methods that need to be called on each row. val subExprResetVariables = mutable.ArrayBuffer.empty[String] + def declareAddedFunctions(): String = { + addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + } + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -120,7 +140,7 @@ class CodegenContext { final val JAVA_DOUBLE = "double" /** The variable name of the input row in generated code. 
*/ - final val INPUT_ROW = "i" + final var INPUT_ROW = "i" private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -476,20 +496,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodegenContext): String = { - ctx.mutableStates.map { case (javaType, variableName, _) => - s"private $javaType $variableName;" - }.mkString("\n") - } - - protected def initMutableStates(ctx: CodegenContext): String = { - ctx.mutableStates.map(_._3).mkString("\n") - } - - protected def declareAddedFunctions(ctx: CodegenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim - } - /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -505,16 +511,33 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = create(canonicalize(expressions)) + /** - * Compile the Java source code into a Java class, using Janino. + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen */ - protected def compile(code: String): GeneratedClass = { + def newCodeGenContext(): CodegenContext = { + new CodegenContext + } +} + +object CodeGenerator extends Logging { + /** + * Compile the Java source code into a Java class, using Janino. + */ + def compile(code: String): GeneratedClass = { cache.get(code) } /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) @@ -577,19 +600,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin result } }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) - - /** Generates the requested evaluator given already bound expression(s). 
*/ - def generate(expressions: InType): OutType = create(canonicalize(expressions)) - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodegenContext = { - new CodegenContext - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index cface21e5f..f58a2daf90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -30,13 +30,15 @@ trait CodegenFallback extends Expression { case _ => } + // LeafNode does not need `input` + val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this val objectTerm = ctx.freshName("obj") if (nullable) { s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { @@ -47,7 +49,7 @@ trait CodegenFallback extends Expression { ev.isNull = "false" s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 63d13a8b87..59ef0f5836 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -107,13 +107,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private Object[] references; private MutableRow mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificMutableProjection(Object[] references) { this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { @@ -138,7 +138,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) () => { c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index e033f62170..6de57537ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -118,12 +118,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private Object[] references; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificOrdering(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public int compare(InternalRow a, InternalRow b) { @@ -135,6 +135,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") - compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 6fbe12fc65..58065d956f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -47,12 +47,12 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final Object[] references; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificPredicate(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 10bd9c6103..e750ad9c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -160,13 +160,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private Object[] references; private MutableRow mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificSafeProjection(Object[] references) { this.references = references; mutableRow 
= new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public java.lang.Object apply(java.lang.Object _i) { @@ -179,7 +179,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 1a0565a8eb..61e7469ee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -338,14 +338,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { private Object[] references; - - ${declareMutableStates(ctx)} - - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificUnsafeProjection(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } // Scala.Function1 need this @@ -362,7 +360,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 8781cc77f4..b1ffbaa3e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 2a24235a29..1eff2c4dd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -224,6 +224,7 @@ object CaseWhen { } } + /** * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". * When a = b, returns c; when a = d, returns e; else returns f. 
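The hunks above move mutable-state declaration/initialization onto `CodegenContext` and turn `compile()` into a public, cached entry point on a `CodeGenerator` object, so that both expression generators and physical operators can assemble and compile a class body the same way. Below is a minimal sketch of that pattern, mirroring the `GeneratePredicate` change; the `ExampleGenerator`/`SpecificExample` names and the `evalCode` parameter are placeholders, and only the `ctx.*` helpers and `CodeGenerator.compile` come from this patch.

```
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext}

object ExampleGenerator {
  // Assemble a class body using the helpers that now live on CodegenContext,
  // then compile it through the shared, cached CodeGenerator.compile().
  def create(ctx: CodegenContext, evalCode: String): Any = {
    val code = s"""
      public Object generate(Object[] references) {
        return new SpecificExample(references);
      }

      class SpecificExample {
        private final Object[] references;
        ${ctx.declareMutableStates()}
        ${ctx.declareAddedFunctions()}

        public SpecificExample(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
        }

        public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
          $evalCode
        }
      }"""
    CodeGenerator.compile(code).generate(ctx.references.toArray)
  }
}
```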
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 2c12de08f4..493e0aae01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -351,8 +351,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression val hasher = classOf[Murmur3_x86_32].getName def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = - ExprCode(code = "", isNull = "false", value = v) + def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) dataType match { case NullType => inlineValue(seed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index d0b29aa01f..d74f3ef2ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -452,7 +452,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and * `lastChildren` for the root node should be empty. */ - protected def generateTreeString( + def generateTreeString( depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { if (depth > 0) { lastChildren.init.foreach { isLast => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3422d0ead4..95e5fbb119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.util.Properties import scala.language.implicitConversions import scala.reflect.ClassTag @@ -39,12 +38,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator -import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { new DataFrame(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4e3662724c..4c1eb0b30b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -489,6 +489,13 @@ private[spark] object SQLConf { isPublic = false, doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") + val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", + defaultValue = Some(true), + doc = "When true, the whole stage (of multiple 
operators) will be compiled into single java" + + " method", + isPublic = false) + + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -561,6 +568,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) + private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) private[spark] def subexpressionEliminationEnabled: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a0939adb6d..18ddffe1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,7 +904,8 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)) + Batch("Add exchange", Once, EnsureRequirements(self)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java new file mode 100644 index 0000000000..b1bbb1da10 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import scala.collection.Iterator; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; + +/** + * An iterator interface used to pull the output from generated function for multiple operators + * (whole stage codegen). + * + * TODO: replaced it by batched columnar format. + */ +public class BufferedRowIterator { + protected InternalRow currentRow; + protected Iterator<InternalRow> input; + // used when there is no column in output + protected UnsafeRow unsafeRow = new UnsafeRow(0); + + public boolean hasNext() { + if (currentRow == null) { + processNext(); + } + return currentRow != null; + } + + public InternalRow next() { + InternalRow r = currentRow; + currentRow = null; + return r; + } + + public void setInput(Iterator<InternalRow> iter) { + input = iter; + } + + /** + * Processes the input until have a row as output (currentRow). + * + * After it's called, if currentRow is still null, it means no more rows left. 
+ */ + protected void processNext() { + if (input.hasNext()) { + currentRow = input.next(); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2355de3d05..75101ea0fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,7 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala new file mode 100644 index 0000000000..c15fabab80 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * An interface for those physical operators that support codegen. + */ +trait CodegenSupport extends SparkPlan { + + /** + * Whether this SparkPlan support whole stage codegen or not. + */ + def supportCodegen: Boolean = true + + /** + * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. + */ + private var parent: CodegenSupport = null + + /** + * Returns an input RDD of InternalRow and Java source code to process them. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + this.parent = parent + doProduce(ctx) + } + + /** + * Generate the Java source code to process, should be overrided by subclass to support codegen. 
+ * + * doProduce() usually generate the framework, for example, aggregation could generate this: + * + * if (!initialized) { + * # create a hash map, then build the aggregation hash map + * # call child.produce() + * initialized = true; + * } + * while (hashmap.hasNext()) { + * row = hashmap.next(); + * # build the aggregation results + * # create varialbles for results + * # call consume(), wich will call parent.doConsume() + * } + */ + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + + /** + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + */ + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) + } + + + /** + * Generate the Java source code to process the rows from child SparkPlan. + * + * This should be override by subclass to support codegen. + * + * For example, Filter will generate the code like this: + * + * # code to evaluate the predicate expression, result is isNull1 and value2 + * if (isNull1 || value2) { + * # call consume(), which will call parent.doConsume() + * } + */ + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String +} + + +/** + * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * + * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes + * an RDD iterator of InternalRow. + */ +case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { + + override def output: Seq[Attribute] = child.output + + override def supportCodegen: Boolean = true + + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columns = exprs.map(_.gen(ctx)) + val code = s""" + | while (input.hasNext()) { + | InternalRow $row = (InternalRow) input.next(); + | ${columns.map(_.code).mkString("\n")} + | ${consume(ctx, columns)} + | } + """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + override def simpleString: String = "INPUT" +} + +/** + * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * function. + * + * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * + * WholeStageCodegen Plan A FakeInput Plan B + * ========================================================================= + * + * -> execute() + * | + * doExecute() --------> produce() + * | + * doProduce() -------> produce() + * | + * doProduce() ---> execute() + * | + * consume() + * doConsume() ------------| + * | + * doConsume() <----- consume() + * + * SparkPlan A should override doProduce() and doConsume(). + * + * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, + * used to generated code for BoundReference. 
+ */ +case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) + extends SparkPlan with CodegenSupport { + + override def output: Seq[Attribute] = plan.output + + override def doExecute(): RDD[InternalRow] = { + val ctx = new CodegenContext + val (rdd, code) = plan.produce(ctx, this) + val references = ctx.references.toArray + val source = s""" + public Object generate(Object[] references) { + return new GeneratedIterator(references); + } + + class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + + private Object[] references; + ${ctx.declareMutableStates()} + + public GeneratedIterator(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } + + protected void processNext() { + $code + } + } + """ + // try to compile, helpful for debug + // println(s"${CodeFormatter.format(source)}") + CodeGenerator.compile(source) + + rdd.mapPartitions { iter => + val clazz = CodeGenerator.compile(source) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.setInput(iter) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } + } + } + + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + throw new UnsupportedOperationException + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } + } + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + + builder.append(simpleString) + builder.append("\n") + + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + + builder + } + + override def simpleString: String = "WholeStageCodegen" +} + + +/** + * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
+ */ +private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { + + private def supportCodegen(plan: SparkPlan): Boolean = plan match { + case plan: CodegenSupport if plan.supportCodegen => + // Non-leaf with CodegenFallback does not work with whole stage codegen + val willFallback = plan.expressions.exists( + _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined + ) + // the generated code will be huge if there are too many columns + val haveManyColumns = plan.output.length > 200 + !willFallback && !haveManyColumns + case _ => false + } + + def apply(plan: SparkPlan): SparkPlan = { + if (sqlContext.conf.wholeStageEnabled) { + plan.transform { + case plan: CodegenSupport if supportCodegen(plan) && + // Whole stage codegen is only useful when there are at least two levels of operators that + // support it (save at least one projection/iterator). + plan.children.exists(supportCodegen) => + + var inputs = ArrayBuffer[SparkPlan]() + val combined = plan.transform { + case p if !supportCodegen(p) => + inputs += p + InputAdapter(p) + }.asInstanceOf[CodegenSupport] + WholeStageCodegen(combined, inputs) + } + } else { + plan + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 92c9a56131..9e2e0357c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -22,19 +22,37 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler -case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryNode with CodegenSupport { override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) override def output: Seq[Attribute] = projectList.map(_.toAttribute) + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + val exprs = projectList.map(x => + ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) + ctx.currentVars = input + val output = exprs.map(_.gen(ctx)) + s""" + | ${output.map(_.code).mkString("\n")} + | + | ${consume(ctx, output)} + """.stripMargin + } + protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitionsInternal { iter => @@ -51,13 +69,30 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends } -case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { +case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport { override def 
output: Seq[Attribute] = child.output private[sql] override lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + val expr = ExpressionCanonicalizer.execute( + BindReferences.bindReference(condition, child.output)) + ctx.currentVars = input + val eval = expr.gen(ctx) + s""" + | ${eval.code} + | if (!${eval.isNull} && ${eval.value}) { + | ${consume(ctx, ctx.currentVars)} + | } + """.stripMargin + } + protected override def doExecute(): RDD[InternalRow] = { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") @@ -116,7 +151,80 @@ case class Range( numSlices: Int, numElements: BigInt, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with CodegenSupport { + + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") + ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + val partitionEnd = ctx.freshName("range_partitionEnd") + ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") + val number = ctx.freshName("range_number") + ctx.addMutableState("long", number, s"$number = 0L;") + val overflow = ctx.freshName("range_overflow") + ctx.addMutableState("boolean", overflow, s"$overflow = false;") + + val value = ctx.freshName("range_value") + val ev = ExprCode("", "false", value) + val BigInt = classOf[java.math.BigInteger].getName + val checkEnd = if (step > 0) { + s"$number < $partitionEnd" + } else { + s"$number > $partitionEnd" + } + + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) + + val code = s""" + | // initialize Range + | if (!$initTerm) { + | $initTerm = true; + | if (input.hasNext()) { + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } else { + | return; + | } + | } + | + | while (!$overflow && $checkEnd) { + | long $value = $number; + | $number += ${step}L; + | if ($number < $value ^ ${step}L < 0) { + | $overflow = true; + | } + | ${consume(ctx, Seq(ev))} + | } + """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } protected override def 
doExecute(): RDD[InternalRow] = { sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 7888e34e8a..72eb1f6cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -143,14 +143,14 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private DataType[] columnTypes = null; private int[] columnIndexes = null; - ${declareMutableStates(ctx)} + ${ctx.declareMutableStates()} public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; this.mutableRow = new MutableUnsafeRow(rowWriter); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { @@ -190,6 +190,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") - compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 89b9a68768..e8d0678989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -36,12 +36,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan - executedPlan.collect { + val plan = sqlContext.table(tableName).queryExecution.sparkPlan + plan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id case _ => - fail(s"Table $tableName is not cached\n" + executedPlan) + fail(s"Table $tableName is not cached\n" + plan) }.head } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index eb4efcd1d4..b349bb6dc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -629,7 +629,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { - val projects = df.queryExecution.executedPlan.collect { + val projects = df.queryExecution.sparkPlan.collect { case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 39a65413bd..c17be8ace9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -123,15 +123,15 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // equijoin - should be converted into broadcast join - val plan1 = 
df1.join(broadcast(df2), "key").queryExecution.executedPlan + val plan1 = df1.join(broadcast(df2), "key").queryExecution.sparkPlan assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) // no join key -- should not be a broadcast join - val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) // planner should not crash without a join - broadcast(df1).queryExecution.executedPlan + broadcast(df1).queryExecution.sparkPlan // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 75e81b9c91..bdb9421cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -247,7 +247,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { val df = sql(sqlText) // First, check if we have GeneratedAggregate. - val hasGeneratedAgg = df.queryExecution.executedPlan + val hasGeneratedAgg = df.queryExecution.sparkPlan .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { @@ -792,11 +792,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-11111 null-safe join should not use cartesian product") { val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)") - val cp = df.queryExecution.executedPlan.collect { + val cp = df.queryExecution.sparkPlan.collect { case cp: CartesianProduct => cp } assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") - val smj = df.queryExecution.executedPlan.collect { + val smj = df.queryExecution.sparkPlan.collect { case smj: SortMergeJoin => smj } assert(smj.size > 0, "should use SortMergeJoin") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala new file mode 100644 index 0000000000..788b04fcf8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure whole stage codegen performance. 
+ * To run this:
+ *  build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
+ */
+class BenchmarkWholeStageCodegen extends SparkFunSuite {
+  def testWholeStage(values: Int): Unit = {
+    val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+    val sc = SparkContext.getOrCreate(conf)
+    val sqlContext = SQLContext.getOrCreate(sc)
+
+    val benchmark = new Benchmark("Single Int Column Scan", values)
+
+    benchmark.addCase("Without whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).filter("(id & 1) = 1").count()
+    }
+
+    benchmark.addCase("With whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).filter("(id & 1) = 1").count()
+    }
+
+    /*
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      Single Int Column Scan:       Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------
+      Without whole stage codegen        6725.52           31.18         1.00 X
+      With whole stage codegen           2233.05           93.91         3.01 X
+    */
+    benchmark.run()
+  }
+
+  ignore("benchmark") {
+    testWholeStage(1024 * 1024 * 200)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 03a1b8e11d..49feeaf17d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -94,7 +94,7 @@ class PlannerSuite extends SharedSQLContext {
         """
           |SELECT l.a, l.b
           |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
-        """.stripMargin).queryExecution.executedPlan
+        """.stripMargin).queryExecution.sparkPlan
 
       val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
       val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
@@ -147,7 +147,7 @@ class PlannerSuite extends SharedSQLContext {
 
       val a = testData.as("a")
       val b = sqlContext.table("tiny").as("b")
-      val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
+      val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan
 
       val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
       val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
@@ -168,7 +168,7 @@ class PlannerSuite extends SharedSQLContext {
     sqlContext.registerDataFrameAsTable(df, "testPushed")
 
     withTempTable("testPushed") {
-      val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan
+      val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
       assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]"))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
new file mode 100644
index 0000000000..c54fc6ba2d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
+
+  test("range/filter should be combined") {
+    val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
+
+    checkThatPlansAgree(
+      sqlContext.range(100),
+      (p: SparkPlan) =>
+        WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
+      (p: SparkPlan) => Filter('a == 1, p),
+      sortAnswers = false
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 25afed25c8..6e21d5a061 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   setupTestData()
 
   test("simple columnar query") {
-    val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
+    val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().toSeq)
@@ -48,7 +48,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("projection") {
-    val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+    val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().map {
@@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
-    val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
+    val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index d762f7bfe9..647a7e9a4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -114,7 +114,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
         df.collect().map(_(0)).toArray
       }
 
-      val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
+      val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
         case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
       }.head
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 58581d71e1..aee8e84db5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -62,7 +62,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
     val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-    val plan = df3.queryExecution.executedPlan
+    val plan = df3.queryExecution.sparkPlan
     assert(plan.collect { case p: T => p }.size === 1)
     plan.executeCollect()
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 9b37dd1103..11863caffe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -30,12 +30,12 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
   import hiveContext._
 
   def rddIdOf(tableName: String): Int = {
-    val executedPlan = table(tableName).queryExecution.executedPlan
-    executedPlan.collect {
+    val plan = table(tableName).queryExecution.sparkPlan
+    plan.collect {
       case InMemoryColumnarTableScan(_, _, relation) =>
         relation.cachedColumnBuffers.id
       case _ =>
-        fail(s"Table $tableName is not cached\n" + executedPlan)
+        fail(s"Table $tableName is not cached\n" + plan)
     }.head
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index fd3339a66b..2e0a8698e6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -485,7 +485,7 @@ abstract class HiveComparisonTest
             val executions = queryList.map(new TestHive.QueryExecution(_))
             executions.foreach(_.toRdd)
             val tablesGenerated = queryList.zip(executions).flatMap {
-              case (q, e) => e.executedPlan.collect {
+              case (q, e) => e.sparkPlan.collect {
                 case i: InsertIntoHiveTable if tablesRead contains i.table.tableName =>
                   (q, e, i)
               }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index 5bd323ea09..d2f91861ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -43,7 +43,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
 
   test("[SPARK-2210] boolean cast on boolean value should be removed") {
     val q = "select cast(cast(key=0 as boolean) as boolean) from src"
-    val project = TestHive.sql(q).queryExecution.executedPlan.collect {
+    val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
       case e: Project => e
     }.head
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 210d566745..b91248bfb3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -144,7 +144,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
       expectedScannedColumns: Seq[String],
       expectedPartValues: Seq[Seq[String]]): Unit = {
     test(s"$testCaseName - pruning test") {
-      val plan = new TestHive.QueryExecution(sql).executedPlan
+      val plan = new TestHive.QueryExecution(sql).sparkPlan
       val actualOutputColumns = plan.output.map(_.name)
       val (actualScannedColumns, actualPartValues) = plan.collect {
        case p @ HiveTableScan(columns, relation, _) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index ed544c6380..c997453803 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -190,11 +190,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
 
   test(s"conversion is working") {
     assert(
-      sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect {
+      sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect {
        case _: HiveTableScan => true
      }.isEmpty)
     assert(
-      sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect {
+      sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect {
        case _: PhysicalRDD => true
      }.nonEmpty)
   }
@@ -305,7 +305,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
        """.stripMargin)
 
     val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
-    df.queryExecution.executedPlan match {
+    df.queryExecution.sparkPlan match {
       case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK
       case o => fail("test_insert_parquet should be converted to a " +
         s"${classOf[ParquetRelation].getCanonicalName} and " +
@@ -335,7 +335,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
        """.stripMargin)
 
     val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
-    df.queryExecution.executedPlan match {
+    df.queryExecution.sparkPlan match {
       case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK
       case o => fail("test_insert_parquet should be converted to a " +
         s"${classOf[ParquetRelation].getCanonicalName} and " +
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index e866493ee6..ba2a483bba 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -149,7 +149,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
       sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
 
       val df = sqlContext.read.parquet(path).filter('a === 0).select('b)
-      val physicalPlan = df.queryExecution.executedPlan
+      val physicalPlan = df.queryExecution.sparkPlan
 
       assert(physicalPlan.collect { case p: execution.Project => p }.length === 1)
       assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 058c101eeb..9ab3e11609 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -156,9 +156,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
     test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") {
       val df = partitionedDF.where(filter).select(projections: _*)
       val queryExecution = df.queryExecution
-      val executedPlan = queryExecution.executedPlan
+      val sparkPlan = queryExecution.sparkPlan
 
-      val rawScan = executedPlan.collect {
+      val rawScan = sparkPlan.collect {
         case p: PhysicalRDD => p
       } match {
         case Seq(scan) => scan
@@ -177,7 +177,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
       assert(requiredColumns === SimpleTextRelation.requiredColumns)
 
       val nonPushedFilters = {
-        val boundFilters = executedPlan.collect {
+        val boundFilters = sparkPlan.collect {
          case f: execution.Filter => f
        } match {
          case Nil => Nil
--
GitLab
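Note on the test changes above: the suites now inspect `queryExecution.sparkPlan` rather than `queryExecution.executedPlan`, presumably because the executed plan can now carry WholeStageCodegen wrappers around the operators those tests pattern-match on, while `sparkPlan` still exposes the bare operators. The standalone sketch below is not part of the patch (the object name is illustrative); it only reuses APIs that appear in BenchmarkWholeStageCodegen and WholeStageCodegenSuite to show how the new `spark.sql.codegen.wholeStage` flag and the WholeStageCodegen wrapper can be observed from user code on this branch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.WholeStageCodegen

// Hypothetical driver program, not part of the patch.
object WholeStageCodegenCheck {
  def main(args: Array[String]): Unit = {
    // Local context, mirroring the setup used in BenchmarkWholeStageCodegen.
    val conf = new SparkConf().setMaster("local[1]").setAppName("codegen-check")
    val sc = SparkContext.getOrCreate(conf)
    val sqlContext = SQLContext.getOrCreate(sc)

    // The feature is on by default in this patch; this conf key (used in the benchmark) toggles it.
    sqlContext.setConf("spark.sql.codegen.wholeStage", "true")

    val df = sqlContext.range(10).filter("(id & 1) = 1").selectExpr("id + 1")

    // executedPlan contains the WholeStageCodegen wrapper inserted by the planner;
    // sparkPlan (what the updated tests inspect) is the plan before that wrapping.
    val wrapped = df.queryExecution.executedPlan
      .find(_.isInstanceOf[WholeStageCodegen])
      .isDefined
    println(s"whole stage codegen used: $wrapped")

    sc.stop()
  }
}
```

Running the same snippet with the flag set to "false" before building the DataFrame should leave the executed plan free of WholeStageCodegen nodes, which mirrors the two cases measured in the benchmark above.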