Skip to content
Snippets Groups Projects
Commit c238fb42 authored by Cheng Lian's avatar Cheng Lian Committed by Michael Armbrust
Browse files

[SPARK-4202][SQL] Simple DSL support for Scala UDF

This feature is based on an offline discussion with mengxr, hopefully can be useful for the new MLlib pipeline API.

For the following test snippet

```scala
case class KeyValue(key: Int, value: String)
val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD
def foo(a: Int, b: String) => a.toString + b
```

the newly introduced DSL enables the following syntax

```scala
import org.apache.spark.sql.catalyst.dsl._
testData.select(Star(None), foo.call('key, 'value) as 'result)
```

which is equivalent to

```scala
testData.registerTempTable("testData")
sqlContext.registerFunction("foo", foo)
sql("SELECT *, foo(key, value) AS result FROM testData")
```

Author: Cheng Lian <lian@databricks.com>

Closes #3067 from liancheng/udf-dsl and squashes the following commits:

f132818 [Cheng Lian] Adds DSL support for Scala UDF
parent 24544fbc
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
......@@ -285,4 +286,62 @@ package object dsl {
def writeToFile(path: String) = WriteToFile(path, logicalPlan)
}
}
case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
}
// scalastyle:off
/** functionToUdfBuilder 1-22 were generated by this script
(1 to 22).map { x =>
val argTypes = Seq.fill(x)("_").mkString(", ")
s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
}
*/
implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
// scalastyle:on
}
......@@ -19,14 +19,13 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.test.TestSQLContext._
class DslQuerySuite extends QueryTest {
import TestData._
import org.apache.spark.sql.TestData._
test("table scan") {
checkAnswer(
......@@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest {
(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
test("udf") {
val foo = (a: Int, b: String) => a.toString + b
checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select(Star(None), foo.call('key, 'value)).limit(3),
(1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment