diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1b13c8fd22cb17b174262efe54eafbef2ffe7560..da3ee46b7d1b679950408525e57e4dd227870cd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -297,7 +297,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
       WholeStageCodegen.PIPELINE_DURATION_METRIC))
 
-  override def doExecute(): RDD[InternalRow] = {
+  /**
+   * Generates code for this subtree.
+   *
+   * @return the tuple of the codegen context and the actual generated source.
+   */
+  def doCodeGen(): (CodegenContext, String) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
     val references = ctx.references.toArray
@@ -334,6 +339,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     val cleanedSource = CodeFormatter.stripExtraNewLines(source)
     logDebug(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)
+    (ctx, cleanedSource)
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    val (ctx, cleanedSource) = doCodeGen()
+    val references = ctx.references.toArray
 
     val durationMs = longMetric("pipelineTime")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5e573b315931175f0e61a9ff6f8a350aeb230af0..9916482a6804f1f015b3c569dfa085bae771ff73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
 
@@ -41,6 +41,13 @@ import org.apache.spark.sql.internal.SQLConf
  */
 package object debug {
 
+  /** Helper function to evade the println() linter. */
+  private def debugPrint(msg: String): Unit = {
+    // scalastyle:off println
+    println(msg)
+    // scalastyle:on println
+  }
+
   /**
    * Augments [[SQLContext]] with debug methods.
    */
@@ -62,12 +69,41 @@ package object debug {
           visited += new TreeNodeRef(s)
           DebugNode(s)
       }
-      logDebug(s"Results returned: ${debugPlan.execute().count()}")
+      debugPrint(s"Results returned: ${debugPlan.execute().count()}")
       debugPlan.foreach {
         case d: DebugNode => d.dumpStats()
         case _ =>
       }
     }
+
+    /**
+     * Prints to stdout all the generated code found in this plan (i.e. the output of each
+     * WholeStageCodegen subtree).
+     */
+    def debugCodegen(): Unit = {
+      debugPrint(debugCodegenString())
+    }
+
+    /** Visible for testing. */
+    def debugCodegenString(): String = {
+      val plan = query.queryExecution.executedPlan
+      val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]()
+      plan transform {
+        case s: WholeStageCodegen =>
+          codegenSubtrees += s
+          s
+        case s => s
+      }
+      var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n"
+      for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) {
+        output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n"
+        output += s
+        output += "\nGenerated code:\n"
+        val (_, source) = s.doCodeGen()
+        output += s"${CodeFormatter.format(source)}\n"
+      }
+      output
+    }
   }
 
   private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
@@ -99,11 +135,11 @@ package object debug {
     val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
 
     def dumpStats(): Unit = {
-      logDebug(s"== ${child.simpleString} ==")
-      logDebug(s"Tuples output: ${tupleCount.value}")
+      debugPrint(s"== ${child.simpleString} ==")
+      debugPrint(s"Tuples output: ${tupleCount.value}")
       child.output.zip(columnStats).foreach { case (attr, metric) =>
         val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
-        logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
+        debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 22189477d277dc6dca189bd2e08b36ea70926d9a..979265e274214b3a583680c23aa05abf9b0196bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -25,4 +25,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
   test("DataFrame.debug()") {
     testData.debug()
   }
+
+  test("debugCodegen") {
+    val res = sqlContext.range(10).groupBy("id").count().debugCodegenString()
+    assert(res.contains("Subtree 1 / 2"))
+    assert(res.contains("Subtree 2 / 2"))
+    assert(res.contains("Object[]"))
+  }
 }
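As a usage sketch (not part of the patch): the snippet below shows how the new `debugCodegen()` entry point is meant to be driven, assuming a `spark-shell` session built with this change. The query mirrors the one in `DebuggingSuite`: the partial and final aggregates land in separate `WholeStageCodegen` subtrees because the exchange between them breaks the pipeline.

```scala
import org.apache.spark.sql.execution.debug._

// groupBy("id").count() plans as partial aggregate -> exchange -> final
// aggregate, so the executed plan contains two codegen subtrees.
val df = sqlContext.range(10).groupBy("id").count()

// Prints "Found 2 WholeStageCodegen subtrees." and, for each subtree, the
// plan fragment followed by its formatted generated source.
df.debugCodegen()

// debug() is unchanged in behavior, but its per-column stats now go to
// stdout via debugPrint() rather than logDebug().
df.debug()
```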