From c4008480b781379ac0451b9220300d83c054c60d Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro <yamamuro@apache.org>
Date: Wed, 29 Mar 2017 12:37:49 -0700
Subject: [PATCH] [SPARK-20009][SQL] Support DDL strings for defining schema in
 functions.from_json

## What changes were proposed in this pull request?
This PR adds `StructType.fromDDL` to convert a DDL-formatted string into a `StructType`, so that schemas can be defined as DDL strings in `functions.from_json`.
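
For example, a schema can now be given as a DDL string instead of a JSON string. A minimal usage sketch (it assumes a `spark-shell`-style `SparkSession` named `spark`; the column name and sample data are illustrative):

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.StructType
import spark.implicits._

// Parse a DDL-formatted string directly into a StructType.
val schema = StructType.fromDDL("a INT, b STRING")
println(schema)  // prints the parsed two-field struct

// The string-schema overload of from_json first tries DataType.fromJson and,
// on failure, falls back to StructType.fromDDL, so DDL strings work here too.
val df = Seq("""{"a": 1, "b": "haa"}""").toDS()
df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]()))
  .show()
```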

## How was this patch tested?
Added tests in `JsonFunctionsSuite` and `DataTypeSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17406 from maropu/SPARK-20009.
---
 .../apache/spark/sql/types/StructType.scala   |  6 ++
 .../spark/sql/types/DataTypeSuite.scala       | 85 ++++++++++++++-----
 .../org/apache/spark/sql/functions.scala      | 15 +++-
 .../apache/spark/sql/JsonFunctionsSuite.scala |  7 ++
 .../sql/sources/SimpleTextRelation.scala      |  2 +-
 5 files changed, 90 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 8d8b5b86d5..54006e20a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -417,6 +417,12 @@ object StructType extends AbstractDataType {
     }
   }
 
+  /**
+   * Creates a StructType for a given DDL-formatted string, which is a comma-separated list of field
+   * definitions, e.g., a INT, b STRING.
+   */
+  def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl)
+
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
   def apply(fields: java.util.List[StructField]): StructType = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 61e1ec7c7a..05cb999af6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite {
     assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
   }
 
-  def checkDataTypeJsonRepr(dataType: DataType): Unit = {
-    test(s"JSON - $dataType") {
+  def checkDataTypeFromJson(dataType: DataType): Unit = {
+    test(s"from Json - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)
     }
   }
 
-  checkDataTypeJsonRepr(NullType)
-  checkDataTypeJsonRepr(BooleanType)
-  checkDataTypeJsonRepr(ByteType)
-  checkDataTypeJsonRepr(ShortType)
-  checkDataTypeJsonRepr(IntegerType)
-  checkDataTypeJsonRepr(LongType)
-  checkDataTypeJsonRepr(FloatType)
-  checkDataTypeJsonRepr(DoubleType)
-  checkDataTypeJsonRepr(DecimalType(10, 5))
-  checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
-  checkDataTypeJsonRepr(DateType)
-  checkDataTypeJsonRepr(TimestampType)
-  checkDataTypeJsonRepr(StringType)
-  checkDataTypeJsonRepr(BinaryType)
-  checkDataTypeJsonRepr(ArrayType(DoubleType, true))
-  checkDataTypeJsonRepr(ArrayType(StringType, false))
-  checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
-  checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+  def checkDataTypeFromDDL(dataType: DataType): Unit = {
+    test(s"from DDL - $dataType") {
+      val parsed = StructType.fromDDL(s"a ${dataType.sql}")
+      val expected = new StructType().add("a", dataType)
+      assert(parsed.sameType(expected))
+    }
+  }
+
+  checkDataTypeFromJson(NullType)
+
+  checkDataTypeFromJson(BooleanType)
+  checkDataTypeFromDDL(BooleanType)
+
+  checkDataTypeFromJson(ByteType)
+  checkDataTypeFromDDL(ByteType)
+
+  checkDataTypeFromJson(ShortType)
+  checkDataTypeFromDDL(ShortType)
+
+  checkDataTypeFromJson(IntegerType)
+  checkDataTypeFromDDL(IntegerType)
+
+  checkDataTypeFromJson(LongType)
+  checkDataTypeFromDDL(LongType)
+
+  checkDataTypeFromJson(FloatType)
+  checkDataTypeFromDDL(FloatType)
+
+  checkDataTypeFromJson(DoubleType)
+  checkDataTypeFromDDL(DoubleType)
+
+  checkDataTypeFromJson(DecimalType(10, 5))
+  checkDataTypeFromDDL(DecimalType(10, 5))
+
+  checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT)
+  checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT)
+
+  checkDataTypeFromJson(DateType)
+  checkDataTypeFromDDL(DateType)
+
+  checkDataTypeFromJson(TimestampType)
+  checkDataTypeFromDDL(TimestampType)
+
+  checkDataTypeFromJson(StringType)
+  checkDataTypeFromDDL(StringType)
+
+  checkDataTypeFromJson(BinaryType)
+  checkDataTypeFromDDL(BinaryType)
+
+  checkDataTypeFromJson(ArrayType(DoubleType, true))
+  checkDataTypeFromDDL(ArrayType(DoubleType, true))
+
+  checkDataTypeFromJson(ArrayType(StringType, false))
+  checkDataTypeFromDDL(ArrayType(StringType, false))
+
+  checkDataTypeFromJson(MapType(IntegerType, StringType, true))
+  checkDataTypeFromDDL(MapType(IntegerType, StringType, true))
+
+  checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
+  checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
 
   val metadata = new MetadataBuilder()
     .putString("name", "age")
@@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite {
     StructField("a", IntegerType, nullable = true),
     StructField("b", ArrayType(DoubleType), nullable = false),
     StructField("c", DoubleType, nullable = false, metadata)))
-  checkDataTypeJsonRepr(structType)
+  checkDataTypeFromJson(structType)
+  checkDataTypeFromDDL(structType)
 
   def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
     test(s"Check the default size of $dataType") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acdb8e2d3e..0f9203065e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -3055,13 +3056,21 @@ object functions {
    * with the specified schema. Returns `null`, in the case of an unparseable string.
    *
    * @param e a string column containing JSON data.
-   * @param schema the schema to use when parsing the json string as a json string
+   * @param schema the schema to use when parsing the json string. In Spark 2.1, the
+   *               user-provided schema has to be in JSON format. Since Spark 2.2, the DDL
+   *               format is also supported for the schema.
    *
    * @group collection_funcs
    * @since 2.1.0
    */
-  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
-    from_json(e, DataType.fromJson(schema), options)
+  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
+    val dataType = try {
+      DataType.fromJson(schema)
+    } catch {
+      case NonFatal(_) => StructType.fromDDL(schema)
+    }
+    from_json(e, dataType, options)
+  }
 
   /**
    * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 170c238c53..8465e8d036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
   }
 
+  test("from_json uses DDL strings for defining a schema") {
+    val df = Seq("""{"a": 1, "b": "haa"}""").toDS()
+    checkAnswer(
+      df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())),
+      Row(Row(1, "haa")) :: Nil)
+  }
+
   test("to_json - struct") {
     val df = Seq(Tuple1(Tuple1(1))).toDF("a")
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 1607c97cd6..9f4009bfe4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.sql.{sources, Row, SparkSession}
+import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-- 
GitLab