From 30c074308f723f95823b43fbc54e2e4742d52840 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Mon, 5 Dec 2016 10:49:22 -0800
Subject: [PATCH] Revert "[SPARK-18284][SQL] Make ExpressionEncoder.serializer.nullable precise"

This reverts commit fce1be6cc81b1fe3991a4df91128f4fcd14ff615 from branch-2.1.
---
 .../sql/catalyst/JavaTypeInference.scala      |  4 +-
 .../spark/sql/catalyst/ScalaReflection.scala  |  7 +--
 .../catalyst/encoders/ExpressionEncoder.scala |  7 ++-
 .../expressions/ReferenceToExpressions.scala  |  2 +-
 .../expressions/objects/objects.scala         | 24 ++++-----
 .../encoders/ExpressionEncoderSuite.scala     | 19 +------
 .../org/apache/spark/sql/DatasetSuite.scala   | 52 +------------------
 .../sql/streaming/FileStreamSinkSuite.scala   |  2 +-
 8 files changed, 21 insertions(+), 96 deletions(-)
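Notes (this section, between the diffstat and the first "diff --git" line, is
ignored by git am; it is review commentary, not part of the commit):

SPARK-18284 made encoder serializers report precise nullability, e.g.
nullable = false for a flat Int column. With this revert, serializer
expressions fall back to a conservative nullable = true, and the schema of a
flat encoder is again derived from ScalaReflection.schemaFor[T] rather than
from serializer.dataType. The sketch below, adapted from the "nullable of
encoder serializer" test that this patch removes, shows the observable
difference. It is a minimal sketch, assuming a Spark 2.1 build with
spark-catalyst on the REPL classpath; the implicit encoder derivation here
stands in for the one the test suite defines:

    import scala.reflect.runtime.universe.TypeTag

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Derive an ExpressionEncoder reflectively, as ExpressionEncoderSuite does.
    implicit def enc[T: TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()

    // A serializer is "precisely nullable" when every top-level serializer
    // expression agrees with the expected flag.
    def checkNullable[T](expected: Boolean)(implicit e: ExpressionEncoder[T]): Unit =
      assert(e.serializer.forall(_.nullable == expected))

    checkNullable[Int](false)              // held under SPARK-18284; fails after this revert
    checkNullable[java.lang.Integer](true) // boxed values may be null either way
    checkNullable[String](true)            // strings may be null either way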
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7e8e4dab72..04f0cfce88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -396,14 +396,12 @@ object JavaTypeInference {

       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-
         ExternalMapToCatalyst(
           inputObject,
           ObjectType(keyType.getRawType),
           serializerFor(_, keyType),
           ObjectType(valueType.getRawType),
-          serializerFor(_, valueType),
-          valueNullable = true
+          serializerFor(_, valueType)
         )

       case other =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6e20096901..0aa21b9347 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -498,8 +498,7 @@ object ScalaReflection extends ScalaReflection {
           dataTypeFor(keyType),
           serializerFor(_, keyType, keyPath),
           dataTypeFor(valueType),
-          serializerFor(_, valueType, valuePath),
-          valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
+          serializerFor(_, valueType, valuePath))

       case t if t <:< localTypeOf[String] =>
         StaticInvoke(
@@ -591,9 +590,7 @@ object ScalaReflection extends ScalaReflection {
               "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
           }

-          val fieldValue = Invoke(
-            AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
-            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
+          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 3757eccfa2..9c4818db63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -60,7 +60,7 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)

-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
     val nullSafeInput = if (flat) {
       inputObject
     } else {
@@ -71,7 +71,10 @@ object ExpressionEncoder {
     val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]

-    val schema = serializer.dataType
+    val schema = ScalaReflection.schemaFor[T] match {
+      case ScalaReflection.Schema(s: StructType, _) => s
+      case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
+    }

     new ExpressionEncoder[T](
       schema,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 2ca77e8394..6c75a7a502 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
       ctx.addMutableState("boolean", classChildVarIsNull, "")

       val classChildVar =
-        LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
+        LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

       val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
         s"${classChildVar.isNull} = ${childGen.isNull};"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index a8aa1e7255..e517ec18eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -171,18 +171,15 @@ case class StaticInvoke(
  * @param arguments An optional list of expressions, whos evaluation will be passed to the function.
  * @param propagateNull When true, and any of the arguments is null, null will be returned instead
  *                      of calling the function.
- * @param returnNullable When false, indicating the invoked method will always return
- *                       non-null value.
  */
 case class Invoke(
     targetObject: Expression,
     functionName: String,
     dataType: DataType,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true,
-    returnNullable : Boolean = true) extends InvokeLike {
+    propagateNull: Boolean = true) extends InvokeLike {

-  override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
+  override def nullable: Boolean = true
   override def children: Seq[Expression] = targetObject +: arguments

   override def eval(input: InternalRow): Any =
@@ -408,15 +405,13 @@ case class WrapOption(child: Expression, optType: DataType)
 * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
 * manually, but will instead be passed into the provided lambda function.
 */
-case class LambdaVariable(
-    value: String,
-    isNull: String,
-    dataType: DataType,
-    nullable: Boolean = true) extends LeafExpression
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
   with Unevaluable with NonSQLExpression {

+  override def nullable: Boolean = true
+
   override def genCode(ctx: CodegenContext): ExprCode = {
-    ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
+    ExprCode(code = "", value = value, isNull = isNull)
   }
 }

@@ -597,8 +592,7 @@ object ExternalMapToCatalyst {
       keyType: DataType,
       keyConverter: Expression => Expression,
       valueType: DataType,
-      valueConverter: Expression => Expression,
-      valueNullable: Boolean): ExternalMapToCatalyst = {
+      valueConverter: Expression => Expression): ExternalMapToCatalyst = {
     val id = curId.getAndIncrement()
     val keyName = "ExternalMapToCatalyst_key" + id
     val valueName = "ExternalMapToCatalyst_value" + id
@@ -607,11 +601,11 @@ object ExternalMapToCatalyst {
     ExternalMapToCatalyst(
       keyName,
       keyType,
-      keyConverter(LambdaVariable(keyName, "false", keyType, false)),
+      keyConverter(LambdaVariable(keyName, "false", keyType)),
       valueName,
       valueIsNull,
       valueType,
-      valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
+      valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
       inputMap
     )
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 080f11b769..4d896c2e38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -24,7 +24,7 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -300,11 +300,6 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")

-  encodeDecodeTest(Option(31), "option of int")
-  encodeDecodeTest(Option.empty[Int], "empty option of int")
-  encodeDecodeTest(Option("abc"), "option of string")
-  encodeDecodeTest(Option.empty[String], "empty option of string")
-
   productTest(("UDT", new ExamplePoint(0.1, 0.2)))

   test("nullable of encoder schema") {
@@ -343,18 +338,6 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     }
   }

-  test("nullable of encoder serializer") {
-    def checkNullable[T: Encoder](nullable: Boolean): Unit = {
-      assert(encoderFor[T].serializer.forall(_.nullable === nullable))
-    }
-
-    // test for flat encoders
-    checkNullable[Int](false)
-    checkNullable[Option[Int]](true)
-    checkNullable[java.lang.Integer](true)
-    checkNullable[String](true)
-  }
-
   test("null check for map key") {
     val encoder = ExpressionEncoder[Map[String, Int]]()
     val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d31c766cb7..1174d7354f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -28,10 +28,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types._
-
-case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
-case class TestDataPoint2(x: Int, s: String)
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -972,53 +969,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(dataset.collect() sameElements Array(resultValue, resultValue))
   }

-  test("SPARK-18284: Serializer should have correct nullable value") {
-    val df1 = Seq(1, 2, 3, 4).toDF
-    assert(df1.schema(0).nullable == false)
-    val df2 = Seq(Integer.valueOf(1), Integer.valueOf(2)).toDF
-    assert(df2.schema(0).nullable == true)
-
-    val df3 = Seq(Seq(1, 2), Seq(3, 4)).toDF
-    assert(df3.schema(0).nullable == true)
-    assert(df3.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
-    val df4 = Seq(Seq("a", "b"), Seq("c", "d")).toDF
-    assert(df4.schema(0).nullable == true)
-    assert(df4.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
-
-    val df5 = Seq((0, 1.0), (2, 2.0)).toDF("id", "v")
-    assert(df5.schema(0).nullable == false)
-    assert(df5.schema(1).nullable == false)
-    val df6 = Seq((0, 1.0, "a"), (2, 2.0, "b")).toDF("id", "v1", "v2")
-    assert(df6.schema(0).nullable == false)
-    assert(df6.schema(1).nullable == false)
-    assert(df6.schema(2).nullable == true)
-
-    val df7 = (Tuple1(Array(1, 2, 3)) :: Nil).toDF("a")
-    assert(df7.schema(0).nullable == true)
-    assert(df7.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
-
-    val df8 = (Tuple1(Array((null: Integer), (null: Integer))) :: Nil).toDF("a")
-    assert(df8.schema(0).nullable == true)
-    assert(df8.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
-
-    val df9 = (Tuple1(Map(2 -> 3)) :: Nil).toDF("m")
-    assert(df9.schema(0).nullable == true)
-    assert(df9.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == false)
-
-    val df10 = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("m")
-    assert(df10.schema(0).nullable == true)
-    assert(df10.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == true)
-
-    val df11 = Seq(TestDataPoint(1, 2.2, "a", null),
-      TestDataPoint(3, 4.4, "null", (TestDataPoint2(33, "b")))).toDF
-    assert(df11.schema(0).nullable == false)
-    assert(df11.schema(1).nullable == false)
-    assert(df11.schema(2).nullable == true)
-    assert(df11.schema(3).nullable == true)
-    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(0).nullable == false)
-    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(1).nullable == true)
-  }
-
   Seq(true, false).foreach { eager =>
     def testCheckpointing(testName: String)(f: => Unit): Unit = {
       test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 54efae3fb4..09613ef9e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -86,7 +86,7 @@ class FileStreamSinkSuite extends StreamTest {

     val outputDf = spark.read.parquet(outputDir)
     val expectedSchema = new StructType()
-      .add(StructField("value", IntegerType, nullable = false))
+      .add(StructField("value", IntegerType))
       .add(StructField("id", IntegerType))
     assert(outputDf.schema === expectedSchema)

--
GitLab
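Post-script (below the "--" signature delimiter, so mail-based git tooling
ignores it): the FileStreamSinkSuite change above tracks the same semantics
from the user side. Once serializer expressions report nullable = true again,
the parquet file written by the file sink carries a plain nullable "value"
column. A minimal sketch of the schema the test now expects, using only the
types the test itself imports:

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // With this revert both fields default to nullable = true; under
    // SPARK-18284 the "value" field was asserted nullable = false.
    val expectedSchema = new StructType()
      .add(StructField("value", IntegerType))
      .add(StructField("id", IntegerType))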