From 7e24249af1e2f896328ef0402fa47db78cb6f9ec Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Tue, 10 Feb 2015 19:50:44 -0800
Subject: [PATCH] [SQL][DataFrame] Fix column computability bug.

Do not recursively strip out projects; strip only the first-level project.
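
Recursively stripping all projects (and subqueries) made columns built on top of different projections appear to share the same base plan; such combinations are now reported as not computable. A short illustration, mirroring the updated cases in ColumnExpressionSuite:

```scala
// A column from a derived projection combined with one from the original
// DataFrame is no longer computable standalone, though the expression can
// still be valid as part of, e.g., a join condition.
shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
```

By contrast, an expression over columns of the same DataFrame, such as the one below, should be computable: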

```scala
df("colA") + df("colB").as("colC")
```

Previously, the above would construct an invalid plan.
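
With the fix, only the first-level project is stripped before re-projecting, and the resulting plan is valid. A minimal sketch, mirroring the new regression test added in this patch:

```scala
val df = Seq((1, 2)).toDataFrame("colA", "colB")
// The aliased column's top-level Project is stripped before the new Project
// is applied, so the combined expression resolves against the base plan.
checkAnswer(df("colA") + df("colB").as("colC"), Seq(Row(3)))
```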

Author: Reynold Xin <rxin@databricks.com>

Closes #4519 from rxin/computability and squashes the following commits:

87ff763 [Reynold Xin] Code review feedback.
015c4fc [Reynold Xin] [SQL][DataFrame] Fix column computability.
---
 .../MatrixFactorizationModel.scala            |  2 +-
 .../scala/org/apache/spark/sql/Column.scala   | 35 ++++++++++++++-----
 .../org/apache/spark/sql/SQLContext.scala     |  4 +--
 .../spark/sql/ColumnExpressionSuite.scala     | 13 +++++--
 4 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9ff06ac362..16979c9ed4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -180,7 +180,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
-      import sqlContext.implicits.createDataFrame
+      import sqlContext.implicits._
       val metadata = (thisClassName, thisFormatVersion, model.rank)
       val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
       metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b0e95908ee..9d5d6e78bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -66,27 +66,44 @@ trait Column extends DataFrame {
    */
   def isComputable: Boolean
 
+  /** Removes the top project so we can get to the underlying plan. */
+  private def stripProject(p: LogicalPlan): LogicalPlan = p match {
+    case Project(_, child) => child
+    case p => sys.error("Unexpected logical plan (expected Project): " + p)
+  }
+
   private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
-    val plan = Project(Seq(expr match {
+    val namedExpr = expr match {
       case named: NamedExpression => named
       case unnamed: Expression => Alias(unnamed, "col")()
-    }), baseCol.plan)
+    }
+    val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
     Column(baseCol.sqlContext, plan, expr)
   }
 
+  /**
+   * Constructs a new column based on the given expression and the other column value.
+   *
+   * Two cases can arise here:
+   * If otherValue is a constant, it is first turned into a Column.
+   * If otherValue is a Column, then:
+   *   - If this column and otherValue are both computable and come from the same logical plan,
+   *     then we can construct a ComputableColumn by applying a Project on top of the base plan.
+   *   - If this column is not computable, but otherValue is computable, then we can construct
+   *     a ComputableColumn based on otherValue's base plan.
+   *   - If this column is computable, but otherValue is not, then we can construct a
+   *     ComputableColumn based on this column's base plan.
+   *   - If neither column is computable, then we create an IncomputableColumn.
+   */
   private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
-    // Removes all the top level projection and subquery so we can get to the underlying plan.
-    @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
-      case Project(_, child) => stripProject(child)
-      case Subquery(_, child) => stripProject(child)
-      case _ => p
-    }
-
+    // lit(otherValue) always returns a Column.
     (this, lit(otherValue)) match {
       case (left: ComputableColumn, right: ComputableColumn) =>
         if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
           computableCol(right, newExpr(right))
         } else {
+          // We don't throw an exception here because "df1("a") === df2("b")" can be a valid
+          // expression for join conditions, even though it is not valid standalone.
           Column(newExpr(right))
         }
       case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 523911d108..05ac1623d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
      *
      * @group userf
      */
-    implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+    implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
       self.createDataFrame(rdd)
     }
 
     /**
      * Creates a DataFrame from a local Seq of Product.
      */
-    implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+    implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
       self.createDataFrame(data)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1d71039872..e3e6f652ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
 
 
@@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest {
     shouldBeComputable(-testData2("a"))
     shouldBeComputable(!testData2("a"))
 
-    shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
-    shouldBeComputable(
+    shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
+    shouldNotBeComputable(
       testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
-    shouldBeComputable(
+    shouldNotBeComputable(
       testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
 
     // Literals and unresolved columns should not be computable.
@@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest {
     shouldNotBeComputable(sum(testData2("a")))
   }
 
+  test("collect on column produced by a binary operator") {
+    val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+    checkAnswer(df("a") + df("b"), Seq(Row(3)))
+    checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
+  }
+
   test("star") {
     checkAnswer(testData.select($"*"), testData.collect().toSeq)
   }
-- 
GitLab