diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d3a2249d7006c1d9882a9f7dac9a8daab381a6f0..6336dee7be6a37d2daf32aa48e4001be57e7a8c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -147,14 +147,6 @@ class DataFrame private[sql](
     queryExecution.analyzed
   }
 
-  /**
-   * An implicit conversion function internal to this class for us to avoid doing
-   * "new DataFrame(...)" everywhere.
-   */
-  @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
-    new DataFrame(sqlContext, logicalPlan)
-  }
-
   protected[sql] def resolve(colName: String): NamedExpression = {
     queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
       throw new AnalysisException(
@@ -235,7 +227,7 @@ class DataFrame private[sql](
     // For Data that has more than "numRows" records
     if (hasMoreData) {
       val rowsString = if (numRows == 1) "row" else "rows"
-      sb.append(s"only showing top $numRows ${rowsString}\n")
+      sb.append(s"only showing top $numRows $rowsString\n")
     }
 
     sb.toString()
@@ -332,7 +324,7 @@ class DataFrame private[sql](
    */
   def explain(extended: Boolean): Unit = {
     val explain = ExplainCommand(queryExecution.logical, extended = extended)
-    explain.queryExecution.executedPlan.executeCollect().foreach {
+    withPlan(explain).queryExecution.executedPlan.executeCollect().foreach {
       // scalastyle:off println
       r => println(r.getString(0))
       // scalastyle:on println
@@ -370,7 +362,7 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def show(numRows: Int): Unit = show(numRows, true)
+  def show(numRows: Int): Unit = show(numRows, truncate = true)
 
   /**
    * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
@@ -445,7 +437,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def join(right: DataFrame): DataFrame = {
+  def join(right: DataFrame): DataFrame = withPlan {
     Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
   }
 
@@ -520,21 +512,25 @@ class DataFrame private[sql](
       Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
 
     // Project only one of the join columns.
-    val joinedCols = usingColumns.map(col => joined.right.resolve(col))
+    val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col))
     val condition = usingColumns.map { col =>
-      catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col))
+      catalyst.expressions.EqualTo(
+        withPlan(joined.left).resolve(col),
+        withPlan(joined.right).resolve(col))
     }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
       catalyst.expressions.And(cond, eqTo)
     }
 
-    Project(
-      joined.output.filterNot(joinedCols.contains(_)),
-      Join(
-        joined.left,
-        joined.right,
-        joinType = JoinType(joinType),
-        condition)
-    )
+    withPlan {
+      Project(
+        joined.output.filterNot(joinedCols.contains(_)),
+        Join(
+          joined.left,
+          joined.right,
+          joinType = JoinType(joinType),
+          condition)
+      )
+    }
   }
 
   /**
@@ -581,19 +577,20 @@ class DataFrame private[sql](
 
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
-    val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+    val plan = withPlan(
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
       .queryExecution.analyzed.asInstanceOf[Join]
 
     // If auto self join alias is disabled, return the plan.
     if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
-      return plan
+      return withPlan(plan)
     }
 
     // If left/right have no output set intersection, return the plan.
-    val lanalyzed = this.logicalPlan.queryExecution.analyzed
-    val ranalyzed = right.logicalPlan.queryExecution.analyzed
+    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
+    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
     if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
-      return plan
+      return withPlan(plan)
     }
 
     // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
@@ -602,9 +599,14 @@ class DataFrame private[sql](
     val cond = plan.condition.map { _.transform {
       case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
+        catalyst.expressions.EqualTo(
+          withPlan(plan.left).resolve(a.name),
+          withPlan(plan.right).resolve(b.name))
     }}
-    plan.copy(condition = cond)
+
+    withPlan {
+      plan.copy(condition = cond)
+    }
   }
 
   /**
@@ -707,7 +709,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
+  def as(alias: String): DataFrame = withPlan {
+    Subquery(alias, logicalPlan)
+  }
 
   /**
    * (Scala-specific) Returns a new [[DataFrame]] with an alias set.
@@ -739,7 +743,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def select(cols: Column*): DataFrame = {
+  def select(cols: Column*): DataFrame = withPlan {
     val namedExpressions = cols.map {
       // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
       // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
@@ -798,7 +802,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan)
+  def filter(condition: Column): DataFrame = withPlan {
+    Filter(condition.expr, logicalPlan)
+  }
 
   /**
    * Filters rows using the given SQL expression.
@@ -1039,7 +1045,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan)
+  def limit(n: Int): DataFrame = withPlan {
+    Limit(Literal(n), logicalPlan)
+  }
 
   /**
    * Returns a new [[DataFrame]] containing union of rows in this frame and another frame.
@@ -1047,7 +1055,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
+  def unionAll(other: DataFrame): DataFrame = withPlan {
+    Union(logicalPlan, other.logicalPlan)
+  }
 
   /**
    * Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
@@ -1055,7 +1065,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
+  def intersect(other: DataFrame): DataFrame = withPlan {
+    Intersect(logicalPlan, other.logicalPlan)
+  }
 
   /**
    * Returns a new [[DataFrame]] containing rows in this frame but not in another frame.
@@ -1063,7 +1075,9 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+  def except(other: DataFrame): DataFrame = withPlan {
+    Except(logicalPlan, other.logicalPlan)
+  }
 
   /**
    * Returns a new [[DataFrame]] by sampling a fraction of rows.
@@ -1074,7 +1088,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
     Sample(0.0, fraction, withReplacement, seed, logicalPlan)
   }
 
@@ -1102,7 +1116,7 @@ class DataFrame private[sql](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
     }.toArray
   }
 
@@ -1162,8 +1176,10 @@ class DataFrame private[sql](
       f.andThen(_.map(convert(_).asInstanceOf[InternalRow]))
     val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
 
-    Generate(generator, join = true, outer = false,
-      qualifier = None, generatorOutput = Nil, logicalPlan)
+    withPlan {
+      Generate(generator, join = true, outer = false,
+        qualifier = None, generatorOutput = Nil, logicalPlan)
+    }
   }
 
   /**
@@ -1190,8 +1206,10 @@ class DataFrame private[sql](
     }
     val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
 
-    Generate(generator, join = true, outer = false,
-      qualifier = None, generatorOutput = Nil, logicalPlan)
+    withPlan {
+      Generate(generator, join = true, outer = false,
+        qualifier = None, generatorOutput = Nil, logicalPlan)
+    }
   }
 
   /////////////////////////////////////////////////////////////////////////////
@@ -1309,7 +1327,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def dropDuplicates(colNames: Seq[String]): DataFrame = {
+  def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan {
     val groupCols = colNames.map(resolve)
     val groupColExprIds = groupCols.map(_.exprId)
     val aggCols = logicalPlan.output.map { attr =>
@@ -1355,7 +1373,7 @@ class DataFrame private[sql](
    * @since 1.3.1
    */
   @scala.annotation.varargs
-  def describe(cols: String*): DataFrame = {
+  def describe(cols: String*): DataFrame = withPlan {
 
     // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
@@ -1505,7 +1523,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def repartition(numPartitions: Int): DataFrame = {
+  def repartition(numPartitions: Int): DataFrame = withPlan {
     Repartition(numPartitions, shuffle = true, logicalPlan)
   }
 
@@ -1519,7 +1537,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = {
+  def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan {
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
   }
 
@@ -1533,7 +1551,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def repartition(partitionExprs: Column*): DataFrame = {
+  def repartition(partitionExprs: Column*): DataFrame = withPlan {
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
   }
 
@@ -1545,7 +1563,7 @@ class DataFrame private[sql](
    * @group rdd
    * @since 1.4.0
    */
-  def coalesce(numPartitions: Int): DataFrame = {
+  def coalesce(numPartitions: Int): DataFrame = withPlan {
     Repartition(numPartitions, shuffle = false, logicalPlan)
   }
 
@@ -2066,7 +2084,14 @@ class DataFrame private[sql](
           SortOrder(expr, Ascending)
       }
     }
-    Sort(sortOrder, global = global, logicalPlan)
+    withPlan {
+      Sort(sortOrder, global = global, logicalPlan)
+    }
+  }
+
+  /** A convenient function to wrap a logical plan and produce a DataFrame. */
+  @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
+    new DataFrame(sqlContext, logicalPlan)
   }
 
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7b75aeec4cf3a466a5f3b23b4dc3c9316adb741f..500227e93a472ab5a13fc42393547de6615b0af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -107,13 +107,16 @@ class Dataset[T] private(
    * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]]
    * objects that allow fields to be accessed by ordinal or name.
    */
+  // This is declared with parentheses to prevent the Scala compiler from treating
+  // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
   def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
 
-
   /**
    * Returns this Dataset.
    * @since 1.6.0
    */
+  // This is declared with parentheses to prevent the Scala compiler from treating
+  // `ds.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
   def toDS(): Dataset[T] = this
 
   /**
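
For reviewers who want the pattern in isolation: below is a minimal, self-contained Scala sketch of the two ideas in this diff — the explicit `withPlan` factory that replaces the removed implicit `LogicalPlan => DataFrame` conversion, and the reason `toDF()`/`toDS()` keep their empty parentheses. `Plan`, `Context`, `Frame`, and `WithPlanDemo` are invented stand-ins, not Spark's API; the real helper is the private `withPlan` added at the bottom of DataFrame.scala.

```scala
// Hypothetical stand-ins for LogicalPlan, SQLContext, and DataFrame.
class Plan(val describe: String)
class Context

class Frame(val ctx: Context, val plan: Plan) {

  // The `withPlan` factory: wrapping a plan in a new frame stays terse, but is
  // now explicit and greppable at every call site. The by-name parameter is
  // what allows the `withPlan { ... }` block syntax used throughout the diff.
  @inline private def withPlan(p: => Plan): Frame = new Frame(ctx, p)

  // A one-expression operator written in the new style (cf. `limit` above).
  def limit(n: Int): Frame = withPlan {
    new Plan(s"Limit($n, ${plan.describe})")
  }

  // Declared with `()` deliberately: if this were the parameterless
  // `def toDF: Frame`, then `f.toDF("a")` would silently compile as
  // `f.toDF.apply("a")` via `apply` below. With `()`, it fails to compile.
  def toDF(): Frame = this

  def apply(colName: String): String = s"$colName@${plan.describe}"
}

object WithPlanDemo extends App {
  val f = new Frame(new Context, new Plan("Scan"))
  println(f.limit(10).plan.describe) // prints: Limit(10, Scan)
  println(f.toDF()("a"))             // prints: a@Scan
  // f.toDF("a")                     // does not compile: too many arguments
}
```

The trade-off against the removed implicit conversion is deliberate: every point where a logical plan becomes a DataFrame is now spelled out, at the cost of one extra wrapper call per operator.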