From 693a323a70aba91e6c100dd5561d218a75b7895e Mon Sep 17 00:00:00 2001
From: scwf <wangfei1@huawei.com>
Date: Sat, 10 Jan 2015 13:53:21 -0800
Subject: [PATCH] [SPARK-4574][SQL] Adding support for defining schema in
 foreign DDL commands.

Adding support for defining schema in foreign DDL commands. Foreign DDL now
supports commands like:

```
CREATE TEMPORARY TABLE avroTable
USING org.apache.spark.sql.avro
OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
```

With this PR a user can define the schema instead of having it inferred from
the file, so DDL commands like the following are supported:

```
CREATE TEMPORARY TABLE avroTable(a int, b string)
USING org.apache.spark.sql.avro
OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
```

Author: scwf <wangfei1@huawei.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Fei Wang <wangfei1@huawei.com>
Author: wangfei <wangfei1@huawei.com>

Closes #3431 from scwf/ddl and squashes the following commits:

7e79ce5 [Fei Wang] Merge pull request #22 from yhuai/pr3431yin
38f634e [Yin Huai] Remove Option from createRelation.
65e9c73 [Yin Huai] Revert all changes since applying a given schema has not been testd.
a852b10 [scwf] remove cleanIdentifier
f336a16 [Fei Wang] Merge pull request #21 from yhuai/pr3431yin
baf79b5 [Yin Huai] Test special characters quoted by backticks.
50a03b0 [Yin Huai] Use JsonRDD.nullTypeToStringType to convert NullType to StringType.
1eeb769 [Fei Wang] Merge pull request #20 from yhuai/pr3431yin
f5c22b0 [Yin Huai] Refactor code and update test cases.
f1cffe4 [Yin Huai] Revert "minor refactory"
b621c8f [scwf] minor refactory
d02547f [scwf] fix HiveCompatibilitySuite test failure
8dfbf7a [scwf] more tests for complex data type
ddab984 [Fei Wang] Merge pull request #19 from yhuai/pr3431yin
91ad91b [Yin Huai] Parse data types in DDLParser.
cf982d2 [scwf] fixed test failure
445b57b [scwf] address comments
02a662c [scwf] style issue
44eb70c [scwf] fix decimal parser issue
83b6fc3 [scwf] minor fix
9bf12f8 [wangfei] adding test case
7787ec7 [wangfei] added SchemaRelationProvider
0ba70df [wangfei] draft version
---
 .../apache/spark/sql/json/JSONRelation.scala  |  35 +++-
 .../org/apache/spark/sql/sources/ddl.scala    | 138 +++++++++++--
 .../apache/spark/sql/sources/interfaces.scala |  29 ++-
 .../spark/sql/sources/TableScanSuite.scala    | 192 ++++++++++++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala | 114 +++--------
 .../sql/hive/HiveMetastoreCatalogSuite.scala  |   5 +-
 6 files changed, 400 insertions(+), 113 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index fc70c18343..a9a6696cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -18,31 +18,48 @@ package org.apache.spark.sql.json
 
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends RelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio)(sqlContext)
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
 
-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+    fileName: String,
+    samplingRatio: Double,
+    userSpecifiedSchema: Option[StructType])(
     @transient val sqlContext: SQLContext)
   extends TableScan {
 
   private def baseRDD = sqlContext.sparkContext.textFile(fileName)
 
-  override val schema =
-    JsonRDD.inferSchema(
-      baseRDD,
-      samplingRatio,
-      sqlContext.columnNameOfCorruptRecord)
+  override val schema = userSpecifiedSchema.getOrElse(
+    JsonRDD.nullTypeToStringType(
+      JsonRDD.inferSchema(
+        baseRDD,
+        samplingRatio,
+        sqlContext.columnNameOfCorruptRecord)))
 
   override def buildScan() =
     JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 364bacec83..fe2c4d8436 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,15 @@ package org.apache.spark.sql.sources
 
-import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.util.Utils
-
 import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
 import scala.util.parsing.combinator.syntactical.StandardTokenParsers
 import scala.util.parsing.combinator.PackratParsers
 
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.SqlLexical
 
@@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
     }
   }
 
+  def parseType(input: String): DataType = {
+    phrase(dataType)(new lexical.Scanner(input)) match {
+      case Success(r, x) => r
+      case x =>
+        sys.error(s"Unsupported dataType: $x")
+    }
+  }
+
   protected case class Keyword(str: String)
 
   protected implicit def asParser(k: Keyword): Parser[String] =
@@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected val USING = Keyword("USING")
   protected val OPTIONS = Keyword("OPTIONS")
 
+  // Data types.
+  protected val STRING = Keyword("STRING")
+  protected val BINARY = Keyword("BINARY")
+  protected val BOOLEAN = Keyword("BOOLEAN")
+  protected val TINYINT = Keyword("TINYINT")
+  protected val SMALLINT = Keyword("SMALLINT")
+  protected val INT = Keyword("INT")
+  protected val BIGINT = Keyword("BIGINT")
+  protected val FLOAT = Keyword("FLOAT")
+  protected val DOUBLE = Keyword("DOUBLE")
+  protected val DECIMAL = Keyword("DECIMAL")
+  protected val DATE = Keyword("DATE")
+  protected val TIMESTAMP = Keyword("TIMESTAMP")
+  protected val VARCHAR = Keyword("VARCHAR")
+  protected val ARRAY = Keyword("ARRAY")
+  protected val MAP = Keyword("MAP")
+  protected val STRUCT = Keyword("STRUCT")
+
   // Use reflection to find the reserved words defined in this class.
   protected val reservedWords =
     this.getClass
@@ -67,15 +92,25 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val ddl: Parser[LogicalPlan] = createTable
 
   /**
-   * CREATE TEMPORARY TABLE avroTable
+   * `CREATE TEMPORARY TABLE avroTable
    * USING org.apache.spark.sql.avro
-   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+   * or
+   * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+   * USING org.apache.spark.sql.avro
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
    */
   protected lazy val createTable: Parser[LogicalPlan] =
-    CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
-      case tableName ~ provider ~ opts =>
-        CreateTableUsing(tableName, provider, opts)
+  (
+    CREATE ~ TEMPORARY ~ TABLE ~> ident
+      ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+      case tableName ~ columns ~ provider ~ opts =>
+        val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+        CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
     }
+  )
+
+  protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
 
   protected lazy val options: Parser[Map[String, String]] =
     "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -83,10 +118,66 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
 
   protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+
+  protected lazy val column: Parser[StructField] =
+    ident ~ dataType ^^ { case columnName ~ typ =>
+      StructField(columnName, typ)
+    }
+
+  protected lazy val primitiveType: Parser[DataType] =
+    STRING ^^^ StringType |
+    BINARY ^^^ BinaryType |
+    BOOLEAN ^^^ BooleanType |
+    TINYINT ^^^ ByteType |
+    SMALLINT ^^^ ShortType |
+    INT ^^^ IntegerType |
+    BIGINT ^^^ LongType |
+    FLOAT ^^^ FloatType |
+    DOUBLE ^^^ DoubleType |
+    fixedDecimalType |                   // decimal with precision/scale
+    DECIMAL ^^^ DecimalType.Unlimited |  // decimal with no precision/scale
+    DATE ^^^ DateType |
+    TIMESTAMP ^^^ TimestampType |
+    VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
+
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+    }
+
+  protected lazy val arrayType: Parser[DataType] =
+    ARRAY ~> "<" ~> dataType <~ ">" ^^ {
+      case tpe => ArrayType(tpe)
+    }
+
+  protected lazy val mapType: Parser[DataType] =
+    MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
" ^^ {">
+      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    }
+
+  protected lazy val structField: Parser[StructField] =
+    ident ~ ":" ~ dataType ^^ {
+      case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
+    }
+
+  protected lazy val structType: Parser[DataType] =
+    (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+      case fields => new StructType(fields)
+    }) |
+    (STRUCT ~> "<>" ^^ {
+      case fields => new StructType(Nil)
+    })
+
+  private[sql] lazy val dataType: Parser[DataType] =
+    arrayType |
+    mapType |
+    structType |
+    primitiveType
 }
 
 private[sql] case class CreateTableUsing(
     tableName: String,
+    userSpecifiedSchema: Option[StructType],
     provider: String,
     options: Map[String, String]) extends RunnableCommand {
 
@@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing(
         sys.error(s"Failed to load class for data source: $provider")
       }
     }
-    val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-    val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
+    }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
     Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 02eff80456..990f7e0e74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
 
 import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
+import org.apache.spark.sql.{Row, SQLContext, StructType}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
 
 /**
@@ -44,6 +44,33 @@ trait RelationProvider {
   def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
 }
 
+/**
+ * ::DeveloperApi::
+ * Implemented by objects that produce relations for a specific kind of data source. When
+ * Spark SQL is given a DDL operation with:
+ * 1. USING clause: to specify the implemented SchemaRelationProvider
+ * 2. User defined schema: users can optionally define a schema when creating the table
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait SchemaRelationProvider {
+  /**
+   * Returns a new base relation with the given parameters and user defined schema.
+   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
+   * by the Map that is passed to the function.
+   */
+  def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation
+}
+
 /**
  * ::DeveloperApi::
  * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 3cd7b0115d..605190f5ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -17,7 +17,10 @@ package org.apache.spark.sql.sources
 
+import java.sql.{Timestamp, Date}
+
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.types.DecimalType
 
 class DefaultSource extends SimpleScanSource
 
@@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
   override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
 }
 
+class AllDataTypesScanSource extends SchemaRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation = {
+    AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
+  }
+}
+
+case class AllDataTypesScan(
+    from: Int,
+    to: Int,
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
+  extends TableScan {
+
+  override def schema = userSpecifiedSchema
+
+  override def buildScan() = {
+    sqlContext.sparkContext.parallelize(from to to).map { i =>
+      Row(
+        s"str_$i",
+        s"str_$i".getBytes(),
+        i % 2 == 0,
+        i.toByte,
+        i.toShort,
+        i,
+        i.toLong,
+        i.toFloat,
+        i.toDouble,
+        BigDecimal(i),
+        BigDecimal(i),
+        new Date((i + 1) * 8640000),
+        new Timestamp(20000 + i),
+        s"varchar_$i",
+        Seq(i, i + 1),
+        Seq(Map(s"str_$i" -> Row(i.toLong))),
+        Map(i -> i.toString),
+        Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+        Row(i, i.toString),
+        Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+    }
+  }
+}
+
 class TableScanSuite extends DataSourceTest {
   import caseInsensisitiveContext._
 
+  var tableWithSchemaExpected = (1 to 10).map { i =>
+    Row(
+      s"str_$i",
+      s"str_$i",
+      i % 2 == 0,
+      i.toByte,
+      i.toShort,
+      i,
+      i.toLong,
+      i.toFloat,
+      i.toDouble,
+      BigDecimal(i),
+      BigDecimal(i),
+      new Date((i + 1) * 8640000),
+      new Timestamp(20000 + i),
+      s"varchar_$i",
+      Seq(i, i + 1),
+      Seq(Map(s"str_$i" -> Row(i.toLong))),
+      Map(i -> i.toString),
+      Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+      Row(i, i.toString),
+      Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+  }.toSeq
+
   before {
     sql(
       """
         |  To '10'
         |)
      """.stripMargin)
+
+    sql(
+      """
+        |CREATE TEMPORARY TABLE tableWithSchema (
+        |`string$%Field` stRIng,
+        |binaryField binary,
+        |`booleanField` boolean,
+        |ByteField tinyint,
+        |shortField smaLlint,
+        |int_Field iNt,
+        |`longField_:,<>=+/~^` Bigint,
+        |floatField flOat,
+        |doubleField doubLE,
+        |decimalField1 decimal,
+        |decimalField2 decimal(9,2),
+        |dateField dAte,
+        |timestampField tiMestamp,
+        |varcharField varchaR(12),
+        |arrayFieldSimple Array<inT>,
+        |arrayFieldComplex Array<Map<String, Struct<key:bigInt>>>,
+        |mapFieldSimple MAP<iNt, StRing>,
+        |mapFieldComplex Map<Map<stRING, fLOAT>, Struct<key:bigInt>>,
+        |structFieldSimple StRuct<key:INt, Value:STrINg>,
+        |structFieldComplex StRuct<key:Array<String>, Value:struct<`value_(2)`:Array<date>>>
+        |)
+        |USING org.apache.spark.sql.sources.AllDataTypesScanSource
+        |OPTIONS (
+        |  From '1',
+        |  To '10'
+        |)
+      """.stripMargin)
   }
 
   sqlTest(
@@ -73,6 +175,96 @@ class TableScanSuite extends DataSourceTest {
     "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
     (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
+  test("Schema and all fields") {
+    val expectedSchema = StructType(
+      StructField("string$%Field", StringType, true) ::
+      StructField("binaryField", BinaryType, true) ::
+      StructField("booleanField", BooleanType, true) ::
+      StructField("ByteField", ByteType, true) ::
+      StructField("shortField", ShortType, true) ::
+      StructField("int_Field", IntegerType, true) ::
+      StructField("longField_:,<>=+/~^", LongType, true) ::
+      StructField("floatField", FloatType, true) ::
+      StructField("doubleField", DoubleType, true) ::
+      StructField("decimalField1", DecimalType.Unlimited, true) ::
+      StructField("decimalField2", DecimalType(9, 2), true) ::
+      StructField("dateField", DateType, true) ::
+      StructField("timestampField", TimestampType, true) ::
+      StructField("varcharField", StringType, true) ::
+      StructField("arrayFieldSimple", ArrayType(IntegerType), true) ::
+      StructField("arrayFieldComplex",
+        ArrayType(
+          MapType(StringType, StructType(StructField("key", LongType, true) :: Nil))), true) ::
+      StructField("mapFieldSimple", MapType(IntegerType, StringType), true) ::
+      StructField("mapFieldComplex",
+        MapType(
+          MapType(StringType, FloatType),
+          StructType(StructField("key", LongType, true) :: Nil)), true) ::
+      StructField("structFieldSimple",
+        StructType(
+          StructField("key", IntegerType, true) ::
+          StructField("Value", StringType, true) :: Nil), true) ::
+      StructField("structFieldComplex",
+        StructType(
+          StructField("key", ArrayType(StringType), true) ::
+          StructField("Value",
+            StructType(
+              StructField("value_(2)", ArrayType(DateType), true) :: Nil), true) :: Nil), true) ::
+      Nil
+    )
+
+    assert(expectedSchema == table("tableWithSchema").schema)
+
+    checkAnswer(
+      sql(
+        """SELECT
+          | `string$%Field`,
+          | cast(binaryField as string),
+          | booleanField,
+          | byteField,
+          | shortField,
+          | int_Field,
+          | `longField_:,<>=+/~^`,
+          | floatField,
+          | doubleField,
+          | decimalField1,
+          | decimalField2,
+          | dateField,
+          | timestampField,
+          | varcharField,
+          | arrayFieldSimple,
+          | arrayFieldComplex,
+          | mapFieldSimple,
+          | mapFieldComplex,
+          | structFieldSimple,
+          | structFieldComplex FROM tableWithSchema""".stripMargin),
+      tableWithSchemaExpected
+    )
+  }
+
+  sqlTest(
+    "SELECT count(*) FROM tableWithSchema",
+    10)
+
+  sqlTest(
+    "SELECT `string$%Field` FROM tableWithSchema",
+    (1 to 10).map(i => Row(s"str_$i")).toSeq)
+
+  sqlTest(
+    "SELECT int_Field FROM tableWithSchema WHERE int_Field < 5",
+    (1 to 4).map(Row(_)).toSeq)
+
+  sqlTest(
+    "SELECT `longField_:,<>=+/~^` * 2 FROM tableWithSchema",
+    (1 to 10).map(i => Row(i * 2.toLong)).toSeq)
+
+  sqlTest(
+    "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1",
+    Seq(Seq(1, 2)))
+
+  sqlTest(
+    "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
+    (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
 
   test("Caching") {
     // Cached Query Execution
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2c859894cf..c25288e000 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,12 +20,7 @@ package org.apache.spark.sql.hive
 import java.io.IOException
 import java.util.{List => JList}
 
-import org.apache.spark.sql.execution.SparkPlan
-
-import scala.util.parsing.combinator.RegexParsers
-
 import org.apache.hadoop.util.ReflectionUtils
-
 import org.apache.hadoop.hive.metastore.TableType
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
@@ -37,7 +32,6 @@ import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
 import org.apache.spark.sql.catalyst.expressions._
@@ -412,88 +406,6 @@ private[hive] case class InsertIntoHiveTable(
   }
 }
 
-/**
- * :: DeveloperApi ::
- * Provides conversions between Spark SQL data types and Hive Metastore types.
- */
-@DeveloperApi
-object HiveMetastoreTypes extends RegexParsers {
-  protected lazy val primitiveType: Parser[DataType] =
-    "string" ^^^ StringType |
-    "float" ^^^ FloatType |
-    "int" ^^^ IntegerType |
-    "tinyint" ^^^ ByteType |
-    "smallint" ^^^ ShortType |
-    "double" ^^^ DoubleType |
-    "bigint" ^^^ LongType |
-    "binary" ^^^ BinaryType |
-    "boolean" ^^^ BooleanType |
-    fixedDecimalType |                     // Hive 0.13+ decimal with precision/scale
-    "decimal" ^^^ DecimalType.Unlimited |  // Hive 0.12 decimal with no precision/scale
-    "date" ^^^ DateType |
-    "timestamp" ^^^ TimestampType |
-    "varchar\\((\\d+)\\)".r ^^^ StringType
-
-  protected lazy val fixedDecimalType: Parser[DataType] =
-    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
-      case precision ~ scale =>
-        DecimalType(precision.toInt, scale.toInt)
-    }
-
-  protected lazy val arrayType: Parser[DataType] =
-    "array" ~> "<" ~> dataType <~ ">" ^^ {
-      case tpe => ArrayType(tpe)
-    }
-
-  protected lazy val mapType: Parser[DataType] =
-    "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
-      case t1 ~ _ ~ t2 => MapType(t1, t2)
-    }
-
-  protected lazy val structField: Parser[StructField] =
-    "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
-      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
-    }
-
-  protected lazy val structType: Parser[DataType] =
-    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
-      case fields => new StructType(fields)
-    }
-
-  protected lazy val dataType: Parser[DataType] =
-    arrayType |
-    mapType |
-    structType |
-    primitiveType
-
-  def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
-    case Success(result, _) => result
-    case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
-  }
-
-  def toMetastoreType(dt: DataType): String = dt match {
-    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
-    case StructType(fields) =>
-      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
-    case MapType(keyType, valueType, _) =>
-      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
-    case StringType => "string"
-    case FloatType => "float"
"float" - case IntegerType => "int" - case ByteType => "tinyint" - case ShortType => "smallint" - case DoubleType => "double" - case LongType => "bigint" - case BinaryType => "binary" - case BooleanType => "boolean" - case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) - case TimestampType => "timestamp" - case NullType => "void" - case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) - } -} - private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) @@ -551,7 +463,7 @@ private[hive] case class MetastoreRelation implicit class SchemaAttribute(f: FieldSchema) { def toAttribute = AttributeReference( f.getName, - HiveMetastoreTypes.toDataType(f.getType), + sqlContext.ddlParser.parseType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -571,3 +483,27 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) } + +object HiveMetastoreTypes { + def toMetastoreType(dt: DataType): String = dt match { + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" + case StructType(fields) => + s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" + case MapType(keyType, valueType, _) => + s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" + case StringType => "string" + case FloatType => "float" + case IntegerType => "int" + case ByteType => "tinyint" + case ShortType => "smallint" + case DoubleType => "double" + case LongType => "bigint" + case BinaryType => "binary" + case BooleanType => "boolean" + case DateType => "date" + case d: DecimalType => HiveShim.decimalMetastoreString(d) + case TimestampType => "timestamp" + case NullType => "void" + case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 86535f8dd4..041a36f129 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.sources.DDLParser import org.apache.spark.sql.test.ExamplePointUDT class HiveMetastoreCatalogSuite extends FunSuite { @@ -27,7 +28,9 @@ class HiveMetastoreCatalogSuite extends FunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct<a: int, b_1: string, c: string>" - val datatype = HiveMetastoreTypes.toDataType(metastr) + val ddlParser = new DDLParser + + val datatype = ddlParser.parseType(metastr) assert(datatype.isInstanceOf[StructType]) } -- GitLab