Commit 1995c2e6 authored by Xiangrui Meng

[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer

## What changes were proposed in this pull request?

Use a random table name instead of `__THIS__` when registering the temp table in SQLTransformer, and add a test for `transformSchema`. The problems with registering the table as `__THIS__` are:

* It doesn't work under HiveContext (in Spark 1.6).
* Race conditions: every SQLTransformer instance registers the same table name, so concurrent transforms can collide (see the sketch below).
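
For illustration, a minimal sketch of the substitution this change performs (the statement and column names below are made up; the user-facing API, which still uses the `__THIS__` placeholder, is unchanged):

```scala
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.ml.util.Identifiable

// Users keep writing statements against the __THIS__ placeholder.
val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3 FROM __THIS__")

// Internally, each transform()/transformSchema() call now derives a fresh
// table name from the transformer's uid (e.g. "sql_a1b2..."), registers the
// input under that name, and rewrites the statement before executing it.
val tableName = Identifiable.randomUID(sqlTrans.uid)
val realStatement = sqlTrans.getStatement.replace("__THIS__", tableName)
```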

## How was this patch tested?

* Manual test with HiveContext.
* Added a unit test for `transformSchema` to improve coverage.

cc: yhuai

Author: Xiangrui Meng <meng@databricks.com>

Closes #12330 from mengxr/SPARK-14563.
parent 7f024c47
mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val tableName = Identifiable.randomUID(uid)
     dataset.registerTempTable(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
-    val outputDF = dataset.sqlContext.sql(realStatement)
-    outputDF
+    dataset.sqlContext.sql(realStatement)
   }
 
   @Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
     val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
-    dummyDF.registerTempTable(tableIdentifier)
-    val outputSchema = sqlContext.sql($(statement)).schema
+    val tableName = Identifiable.randomUID(uid)
+    val realStatement = $(statement).replace(tableIdentifier, tableName)
+    dummyDF.registerTempTable(tableName)
+    val outputSchema = sqlContext.sql(realStatement).schema
+    sqlContext.dropTempTable(tableName)
     outputSchema
   }
mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
 class SQLTransformerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
       .setStatement("select * from __THIS__")
     testDefaultReadWrite(t)
   }
+
+  test("transformSchema") {
+    val df = sqlContext.range(10)
+    val outputSchema = new SQLTransformer()
+      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+      .transformSchema(df.schema)
+    val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+    assert(outputSchema === expected)
+  }
 }
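
For reference, a rough usage sketch exercising the same path as the new test (not part of the patch; assumes an active `sqlContext`, as provided by `MLlibTestSparkContext` in the suite):

```scala
import org.apache.spark.ml.feature.SQLTransformer

// DataFrame with a single LongType column "id" (0 through 9).
val df = sqlContext.range(10)

val sqlTrans = new SQLTransformer()
  .setStatement("SELECT id + 1 AS id1 FROM __THIS__")

// transformSchema now registers a dummy table under a random name and drops it
// afterwards, so __THIS__ itself is never registered as a temp table.
val outputSchema = sqlTrans.transformSchema(df.schema)

// transform also substitutes a fresh random table name on every call, so two
// transformers (or threads) no longer race on the same temp table.
sqlTrans.transform(df).show()
```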