Commit 84acd08e authored by Reynold Xin

[SPARK-5602][SQL] Better support for creating DataFrame from local data collection

1. Added methods to create DataFrames from Seq[Product]
2. Added executeTake to avoid running a Spark job on LocalRelations.

Author: Reynold Xin <rxin@databricks.com>

Closes #4372 from rxin/localDataFrame and squashes the following commits:

f696858 [Reynold Xin] style checker.
839ef7f [Reynold Xin] [SPARK-5602][SQL] Better support for creating DataFrame from local data collection.
parent 206f9bc3
Showing 170 additions and 88 deletions
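
For orientation before the diff, here is a minimal, hypothetical usage sketch of the two changes above (the case class Person, the SparkContext value sc, and the sample values are illustrative assumptions, not code from this commit):

// Hypothetical sketch: build a DataFrame from a local Scala collection.
case class Person(name: String, age: Int)                 // example class, not in this patch

val sqlContext = new org.apache.spark.sql.SQLContext(sc)  // assumes an existing SparkContext `sc`

// (1) New method added in this commit: DataFrame from a local Seq of Products (case classes).
val df = sqlContext.createDataFrame(Seq(Person("Alice", 30), Person("Bob", 25)))

// (2) Because the data is a LocalRelation, a limited collect should be answered by
//     executeTake on the local rows rather than by launching a Spark job.
df.limit(1).collect()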
@@ -211,7 +211,7 @@ trait ScalaReflection {
      */
     def asRelation: LocalRelation = {
       val output = attributesFor[A]
-      LocalRelation(output, data)
+      LocalRelation.fromProduct(output, data)
     }
   }
 }
@@ -17,31 +17,34 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructType, StructField}
+import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField}
 
 object LocalRelation {
   def apply(output: Attribute*): LocalRelation = new LocalRelation(output)
 
-  def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
-    StructType(output1 +: output).toAttributes
-  )
+  def apply(output1: StructField, output: StructField*): LocalRelation = {
+    new LocalRelation(StructType(output1 +: output).toAttributes)
+  }
+
+  def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
+    val schema = StructType.fromAttributes(output)
+    LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema)))
+  }
 }
 
-case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
+case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
   extends LeafNode with analysis.MultiInstanceRelation {
 
-  // TODO: Validate schema compliance.
-  def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData)
-
   /**
    * Returns an identical copy of this relation with new exprIds for all attributes. Different
    * attributes are required when a relation is going to be included multiple times in the same
    * query.
    */
-  override final def newInstance: this.type = {
-    LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type]
+  override final def newInstance(): this.type = {
+    LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
   }
 
   override protected def stringArgs = Iterator(output)
@@ -19,11 +19,27 @@ package org.apache.spark.sql.types
 
 import java.text.SimpleDateFormat
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 
 protected[sql] object DataTypeConversions {
 
+  def productToRow(product: Product, schema: StructType): Row = {
+    val mutableRow = new GenericMutableRow(product.productArity)
+    val schemaFields = schema.fields.toArray
+
+    var i = 0
+    while (i < mutableRow.length) {
+      mutableRow(i) =
+        ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType)
+      i += 1
+    }
+
+    mutableRow
+  }
+
   def stringToTime(s: String): java.util.Date = {
     if (!s.contains('T')) {
       // JDBC escape string
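
As a rough, standalone illustration of what productToRow above relies on (plain Scala only; Record and productToValues are hypothetical names, and the real method additionally applies ScalaReflection.convertToCatalyst per field):

// Sketch: every case class is a Product, so its fields can be read positionally.
case class Record(name: String, count: Int)

def productToValues(p: Product): Array[Any] = {
  val values = new Array[Any](p.productArity)
  var i = 0
  while (i < values.length) {
    values(i) = p.productElement(i)  // productToRow converts this value to the Catalyst type
    i += 1
  }
  values
}

// productToValues(Record("a", 1)) yields Array("a", 1)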
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.json._
@@ -163,17 +163,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
   /** Removes the specified table from the in-memory cache. */
   def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
 
+  // scalastyle:off
+  // Disable style checker so "implicits" object can start with lowercase i
+  /**
+   * Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
+   */
+  object implicits {
+    // scalastyle:on
+    /**
+     * Creates a DataFrame from an RDD of case classes.
+     *
+     * @group userf
+     */
+    implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+      self.createDataFrame(rdd)
+    }
+
+    /**
+     * Creates a DataFrame from a local Seq of Product.
+     */
+    implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+      self.createDataFrame(data)
+    }
+  }
+
   /**
    * Creates a DataFrame from an RDD of case classes.
    *
    * @group userf
    */
-  implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
+  // TODO: Remove implicit here.
+  implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
     SparkPlan.currentContext.set(self)
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
     val attributeSeq = schema.toAttributes
     val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
-    DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
+    DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self))
+  }
+
+  /**
+   * Creates a DataFrame from a local Seq of Product.
+   */
+  def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+    SparkPlan.currentContext.set(self)
+    val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+    val attributeSeq = schema.toAttributes
+    DataFrame(self, LocalRelation.fromProduct(attributeSeq, data))
   }
 
   /**
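
A hedged usage sketch of the new implicits object (Person and sc are assumptions as before; only the import and the two implicit conversions come from this commit):

case class Person(name: String, age: Int)        // hypothetical example class

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._                    // brings both implicit createDataFrame conversions into scope

// A local Seq of case classes converts wherever a DataFrame is expected:
val local: org.apache.spark.sql.DataFrame = Seq(Person("Ann", 34), Person("Ben", 29))

// An RDD of case classes converts the same way; both conversions forward to self.createDataFrame(...):
val distributed: org.apache.spark.sql.DataFrame = sc.parallelize(Seq(Person("Cara", 23)))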
@@ -54,12 +54,13 @@ object RDDConversions {
   }
 }
 
+/** Logical plan node for scanning data from an RDD. */
 case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
   extends LogicalPlan with MultiInstanceRelation {
 
-  def children = Nil
+  override def children = Nil
 
-  def newInstance() =
+  override def newInstance() =
     LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
 
   override def sameResult(plan: LogicalPlan) = plan match {
@@ -74,39 +75,28 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont
   )
 }
 
+/** Physical plan node for scanning data from an RDD. */
 case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
 
-@deprecated("Use LogicalRDD", "1.2.0")
-case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
-  override def execute() = rdd
-}
-
-@deprecated("Use LogicalRDD", "1.2.0")
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
+/** Logical plan node for scanning data from a local collection. */
+case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
   extends LogicalPlan with MultiInstanceRelation {
 
-  def output = alreadyPlanned.output
-
   override def children = Nil
 
-  override final def newInstance(): this.type = {
-    SparkLogicalPlan(
-      alreadyPlanned match {
-        case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
-        case _ => sys.error("Multiple instance of the same relation detected.")
-      })(sqlContext).asInstanceOf[this.type]
-  }
+  override def newInstance() =
+    LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type]
 
   override def sameResult(plan: LogicalPlan) = plan match {
-    case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
-      rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
+    case LogicalRDD(_, otherRDD) => rows == rows
     case _ => false
   }
 
   @transient override lazy val statistics = Statistics(
-    // TODO: Instead of returning a default value here, find a way to return a meaningful size
-    // estimate for RDDs. See PR 1238 for more discussions.
-    sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
+    // TODO: Improve the statistics estimation.
+    // This is made small enough so it can be broadcasted.
+    sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1
   )
 }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute


/**
 * Physical plan node for scanning data from a local collection.
 */
case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode {

  private lazy val rdd = sqlContext.sparkContext.parallelize(rows)

  override def execute() = rdd

  override def executeCollect() = rows.toArray

  override def executeTake(limit: Int) = rows.take(limit).toArray
}
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 
+import scala.collection.mutable.ArrayBuffer
+
 object SparkPlan {
   protected[sql] val currentContext = new ThreadLocal[SQLContext]()
 }
@@ -77,8 +79,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Runs this query returning the result as an array.
    */
-  def executeCollect(): Array[Row] =
+  def executeCollect(): Array[Row] = {
     execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
+  }
+
+  /**
+   * Runs this query returning the first `n` rows as an array.
+   *
+   * This is modeled after RDD.take but never runs any job locally on the driver.
+   */
+  def executeTake(n: Int): Array[Row] = {
+    if (n == 0) {
+      return new Array[Row](0)
+    }
+
+    val childRDD = execute().map(_.copy())
+
+    val buf = new ArrayBuffer[Row]
+    val totalParts = childRDD.partitions.length
+    var partsScanned = 0
+    while (buf.size < n && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
+      val left = n - buf.size
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val sc = sqlContext.sparkContext
+      val res =
+        sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+      res.foreach(buf ++= _.take(n - buf.size))
+      partsScanned += numPartsToTry
+    }
+
+    buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
+  }
+
   protected def newProjection(
       expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
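
To make the loop above easier to follow, here is a standalone sketch (hypothetical helper, not part of the patch) of just the partition schedule executeTake uses: try one partition first, fall back to all remaining partitions if nothing was found, otherwise interpolate how many partitions should hold n rows and overestimate by 50%:

// Sketch of executeTake's partition scheduling; `rowsFound` plays the role of buf.size.
def nextPartsToTry(n: Int, partsScanned: Int, rowsFound: Int, totalParts: Int): Int = {
  if (partsScanned == 0) {
    1                                  // first pass: scan a single partition
  } else if (rowsFound == 0) {
    totalParts - 1                     // nothing found yet: try everything that is left
  } else {
    // interpolate the number of partitions needed for n rows, overestimating by 50%
    math.max(0, (1.5 * n * partsScanned / rowsFound).toInt)
  }
}

// Example: wanting n = 100 rows after finding 10 rows in the first partition,
// nextPartsToTry(100, 1, 10, 64) == 15, so the next pass scans the next 15 partitions.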
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, Strategy, execution}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.parquet._
@@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
-      case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
       case logical.LocalRelation(output, data) =>
-        val nPartitions = if (data.isEmpty) 1 else numPartitions
-        PhysicalRDD(
-          output,
-          RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
-            StructType.fromAttributes(output))) :: Nil
+        LocalTableScan(output, data) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
         execution.Limit(limit, planLater(child)) :: Nil
       case Unions(unionChildren) =>
@@ -103,49 +103,7 @@ case class Limit(limit: Int, child: SparkPlan)
   override def output = child.output
   override def outputPartitioning = SinglePartition
 
-  /**
-   * A custom implementation modeled after the take function on RDDs but which never runs any job
-   * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
-   * rows.
-   */
-  override def executeCollect(): Array[Row] = {
-    if (limit == 0) {
-      return new Array[Row](0)
-    }
-    val childRDD = child.execute().map(_.copy())
-
-    val buf = new ArrayBuffer[Row]
-    val totalParts = childRDD.partitions.length
-    var partsScanned = 0
-    while (buf.size < limit && partsScanned < totalParts) {
-      // The number of partitions to try in this iteration. It is ok for this number to be
-      // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1
-      if (partsScanned > 0) {
-        // If we didn't find any rows after the first iteration, just try all partitions next.
-        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-        // by 50%.
-        if (buf.size == 0) {
-          numPartsToTry = totalParts - 1
-        } else {
-          numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
-        }
-      }
-      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
-
-      val left = limit - buf.size
-      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
-      val sc = sqlContext.sparkContext
-      val res =
-        sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
-
-      res.foreach(buf ++= _.take(limit - buf.size))
-      partsScanned += numPartsToTry
-    }
-
-    buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
-  }
+  override def executeCollect(): Array[Row] = child.executeTake(limit)
 
   override def execute() = {
     val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
@@ -58,6 +58,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
 
   override def executeCollect(): Array[Row] = sideEffectResult.toArray
 
+  override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
+
   override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
 }