From b8aec6cd236f09881cad0fff9a6f1a5692934e21 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Sat, 18 Jul 2015 11:08:18 -0700
Subject: [PATCH] [SPARK-9143] [SQL] Add planner rule for automatically
 inserting Unsafe <-> Safe row format converters

Now that we have two different internal row formats, UnsafeRow and the old Java-object-based row format, we end up having to perform conversions between these two formats. These conversions should not be performed by the operators themselves; instead, the planner should be responsible for inserting appropriate format conversions when they are needed.
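
For illustration, the conversion primitives that the new operators wrap can also be used directly. A rough sketch follows (the schema and values are invented for the example, but UnsafeProjection.create and FromUnsafeProjection are the same helpers the converter operators use):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    // Invented two-column schema, used only for this sketch.
    val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))

    // Safe (Java-object-based) row -> UnsafeRow.
    val toUnsafe = UnsafeProjection.create(schema)
    val unsafeRow = toUnsafe(InternalRow(1, UTF8String.fromString("x")))

    // UnsafeRow -> safe row; FromUnsafeProjection takes the field data types.
    val toSafe = FromUnsafeProjection(schema.map(_.dataType))
    val safeRow = toSafe(unsafeRow)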

This patch makes the following changes:

- Add two new physical operators for performing row format conversions, `ConvertToUnsafe` and `ConvertFromUnsafe`.
- Add new methods to `SparkPlan` to allow operators to express whether they output UnsafeRows and whether they can handle safe or unsafe rows as inputs.
- Implement an `EnsureRowFormats` rule to automatically insert converter operators where necessary (a sketch of the resulting operator contract appears below).
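
As a sketch of that contract, a hypothetical unsafe-only operator (HypotheticalUnsafeOp is illustrative only, not part of this patch, and is assumed to live in org.apache.spark.sql.execution so that SparkPlan and UnaryNode are in scope) would declare itself as:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute

    // Hypothetical operator showing how a physical operator opts in to the
    // new row-format contract.
    case class HypotheticalUnsafeOp(child: SparkPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output
      // Emits UnsafeRows downstream...
      override def outputsUnsafeRows: Boolean = true
      // ...and consumes only UnsafeRows, so EnsureRowFormats will wrap any
      // safe-row child in ConvertToUnsafe before execute() is called.
      override def canProcessUnsafeRows: Boolean = true
      override def canProcessSafeRows: Boolean = false
      override protected def doExecute(): RDD[InternalRow] = child.execute()
    }

Given such an operator, the rule rewrites a plan like HypotheticalUnsafeOp(someSafeChild) into HypotheticalUnsafeOp(ConvertToUnsafe(someSafeChild)), and symmetrically inserts ConvertToSafe in front of safe-only operators.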

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7482 from JoshRosen/unsafe-converter-planning and squashes the following commits:

7450fa5 [Josh Rosen] Resolve conflicts in favor of choosing UnsafeRow
5220cce [Josh Rosen] Add roundtrip converter test
2bb8da8 [Josh Rosen] Add Union unsafe support + tests to bump up test coverage
6f79449 [Josh Rosen] Add even more assertions to execute()
08ce199 [Josh Rosen] Rename ConvertFromUnsafe -> ConvertToSafe
0e2d548 [Josh Rosen] Add assertion if operators' input rows are in different formats
cabb703 [Josh Rosen] Add tests for Filter
3b11ce3 [Josh Rosen] Add missing test file.
ae2195a [Josh Rosen] Fixes
0fef0f8 [Josh Rosen] Rename file.
d5f9005 [Josh Rosen] Finish writing EnsureRowFormats planner rule
b5df19b [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-converter-planning
9ba3038 [Josh Rosen] WIP
---
 .../org/apache/spark/sql/SQLContext.scala     |   9 +-
 .../spark/sql/execution/SparkPlan.scala       |  24 ++++
 .../spark/sql/execution/basicOperators.scala  |  11 ++
 .../sql/execution/rowFormatConverters.scala   | 107 ++++++++++++++++++
 .../execution/RowFormatConvertersSuite.scala  |  91 +++++++++++++++
 5 files changed, 239 insertions(+), 3 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 46bd60daa1..2dda3ad121 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -921,12 +921,15 @@ class SQLContext(@transient val sparkContext: SparkContext)
   protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1)
 
   /**
-   * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
+   * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal
+   * row format conversions as needed.
    */
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
-    val batches =
-      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
+    val batches = Seq(
+      Batch("Add exchange", Once, EnsureRequirements(self)),
+      Batch("Add row converters", Once, EnsureRowFormats)
+    )
   }
 
   protected[sql] def openSession(): SQLSession = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ba12056ee7..f363e9947d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -79,12 +79,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Product
   /** Specifies the required sort ordering for each partition of this operator's input data. */
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
+  /** Specifies whether this operator outputs UnsafeRows. */
+  def outputsUnsafeRows: Boolean = false
+
+  /** Specifies whether this operator is capable of processing UnsafeRows. */
+  def canProcessUnsafeRows: Boolean = false
+
+  /**
+   * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+   * that are not UnsafeRows).
+   */
+  def canProcessSafeRows: Boolean = true
+
   /**
    * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
    * after adding query plan information to created RDDs for visualization.
    * Concrete implementations of SparkPlan should override doExecute instead.
    */
   final def execute(): RDD[InternalRow] = {
+    if (children.nonEmpty) {
+      val hasUnsafeInputs = children.exists(_.outputsUnsafeRows)
+      val hasSafeInputs = children.exists(!_.outputsUnsafeRows)
+      assert(!(hasSafeInputs && hasUnsafeInputs),
+        "Child operators should output rows in the same format")
+      assert(canProcessSafeRows || canProcessUnsafeRows,
+        "Operator must be able to process at least one row format")
+      assert(!hasSafeInputs || canProcessSafeRows,
+        "Operator will receive safe rows as input but cannot process safe rows")
+      assert(!hasUnsafeInputs || canProcessUnsafeRows,
+        "Operator will receive unsafe rows as input but cannot process unsafe rows")
+    }
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       doExecute()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4c063c299b..82bef269b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -64,6 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def canProcessSafeRows: Boolean = true
 }
 
 /**
@@ -104,6 +110,9 @@ case class Sample(
 case class Union(children: Seq[SparkPlan]) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
   override def output: Seq[Attribute] = children.head.output
+  override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows)
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
   protected override def doExecute(): RDD[InternalRow] =
     sparkContext.union(children.map(_.execute()))
 }
@@ -306,6 +315,8 @@ case class UnsafeExternalSort(
   override def output: Seq[Attribute] = child.output
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def outputsUnsafeRows: Boolean = true
 }
 
 @DeveloperApi
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
new file mode 100644
index 0000000000..421d510e67
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * :: DeveloperApi ::
+ * Converts Java-object-based rows into [[UnsafeRow]]s.
+ */
+@DeveloperApi
+case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputsUnsafeRows: Boolean = true
+  override def canProcessUnsafeRows: Boolean = false
+  override def canProcessSafeRows: Boolean = true
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      val convertToUnsafe = UnsafeProjection.create(child.schema)
+      iter.map(convertToUnsafe)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Converts [[UnsafeRow]]s back into Java-object-based rows.
+ */
+@DeveloperApi
+case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputsUnsafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = false
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
+      iter.map(convertToSafe)
+    }
+  }
+}
+
+private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
+
+  private def onlyHandlesSafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessSafeRows && !operator.canProcessUnsafeRows
+
+  private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessUnsafeRows && !operator.canProcessSafeRows
+
+  private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessSafeRows && operator.canProcessUnsafeRows
+
+  override def apply(operator: SparkPlan): SparkPlan = operator.transformUp {
+    case operator: SparkPlan if onlyHandlesSafeRows(operator) =>
+      if (operator.children.exists(_.outputsUnsafeRows)) {
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+    case operator: SparkPlan if onlyHandlesUnsafeRows(operator) =>
+      if (operator.children.exists(!_.outputsUnsafeRows)) {
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+    case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
+      if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
+        // If this operator's children produce both unsafe and safe rows, then convert everything
+        // to unsafe rows
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
new file mode 100644
index 0000000000..7b75f75591
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.test.TestSQLContext
+
+class RowFormatConvertersSuite extends SparkPlanTest {
+
+  private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
+    case c: ConvertToUnsafe => c
+    case c: ConvertToSafe => c
+  }
+
+  private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+  assert(!outputsSafe.outputsUnsafeRows)
+  private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+  assert(outputsUnsafe.outputsUnsafeRows)
+
+  test("planner should insert unsafe->safe conversions when required") {
+    val plan = Limit(10, outputsUnsafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
+  }
+
+  test("filter can process unsafe rows") {
+    val plan = Filter(IsNull(null), outputsUnsafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(getConverters(preparedPlan).isEmpty)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("filter can process safe rows") {
+    val plan = Filter(IsNull(null), outputsSafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(getConverters(preparedPlan).isEmpty)
+    assert(!preparedPlan.outputsUnsafeRows)
+  }
+
+  test("execute() fails an assertion if inputs rows are of different formats") {
+    val e = intercept[AssertionError] {
+      Union(Seq(outputsSafe, outputsUnsafe)).execute()
+    }
+    assert(e.getMessage.contains("format"))
+  }
+
+  test("union requires all of its input rows' formats to agree") {
+    val plan = Union(Seq(outputsSafe, outputsUnsafe))
+    assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("union can process safe rows") {
+    val plan = Union(Seq(outputsSafe, outputsSafe))
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(!preparedPlan.outputsUnsafeRows)
+  }
+
+  test("union can process unsafe rows") {
+    val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("round trip with ConvertToUnsafe and ConvertToSafe") {
+    val input = Seq(("hello", 1), ("world", 2))
+    checkAnswer(
+      TestSQLContext.createDataFrame(input),
+      plan => ConvertToSafe(ConvertToUnsafe(plan)),
+      input.map(Row.fromTuple)
+    )
+  }
+}
-- 
GitLab