Commit c5bfe5d2 authored by Michael Armbrust

[SPARK-13440][SQL] ObjectType should accept any ObjectType, If should not care about nullability

The type checking functions of `If` and `UnwrapOption` are fixed to eliminate spurious failures. `UnwrapOption` was checking for an input of `ObjectType`, but `ObjectType`'s `acceptsType` function was hard-coded to return `false`. `If`'s type check was returning a false negative when its two branches differed only by nullability.
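For illustration only (not part of the commit), here is a minimal sketch of the kind of nullability-only mismatch the relaxed `If` check now tolerates. `allNullable` is a hypothetical stand-in for Catalyst's internal `asNullable` helper, which is not public API:

```scala
import org.apache.spark.sql.types._

object NullabilitySketch extends App {
  // Two struct types describing the same shape; they differ only in whether
  // the field may be null, which is exactly how an Option-typed column and a
  // plain column of the same underlying type can disagree.
  val optionLike = new StructType().add("value", IntegerType, nullable = true)
  val plainInt = new StructType().add("value", IntegerType, nullable = false)

  // Hypothetical stand-in for Catalyst's private asNullable: mark every
  // field nullable before comparing the two branch types.
  def allNullable(st: StructType): StructType =
    StructType(st.fields.map(_.copy(nullable = true)))

  assert(optionLike != plainInt) // old exact comparison: spurious type mismatch
  assert(allNullable(optionLike) == allNullable(plainInt)) // relaxed comparison: types agree
}
```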

Tests added:
 - An end-to-end regression test is added to `DatasetSuite` for the reported failure.
 - All the unit tests in `ExpressionEncoderSuite` are augmented to also confirm successful analysis. These tests are actually what pointed out the additional issues with `If` resolution.

Author: Michael Armbrust <michael@databricks.com>

Closes #11316 from marmbrus/datasetOptions.
parent 9f426339
@@ -34,7 +34,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
     if (predicate.dataType != BooleanType) {
       TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
-    } else if (trueValue.dataType != falseValue.dataType) {
+    } else if (trueValue.dataType.asNullable != falseValue.dataType.asNullable) {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
         s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {
@@ -45,6 +45,9 @@ object LocalRelation {
 case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
   extends LeafNode with analysis.MultiInstanceRelation {
 
+  // A local relation must have resolved output.
+  require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.")
+
   /**
    * Returns an identical copy of this relation with new exprIds for all attributes. Different
    * attributes are required when a relation is going to be included multiple times in the same
@@ -23,8 +23,10 @@ private[sql] object ObjectType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType =
     throw new UnsupportedOperationException("null literals can't be casted to ObjectType")
 
-  // No casting or comparison is supported.
-  override private[sql] def acceptsType(other: DataType): Boolean = false
+  override private[sql] def acceptsType(other: DataType): Boolean = other match {
+    case ObjectType(_) => true
+    case _ => false
+  }
 
   override private[sql] def simpleString: String = "Object"
 }
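As a hedged sketch of the behavioral change (the `AcceptsTypeSketch` object below is hypothetical, and must live under the `org.apache.spark.sql` package tree because `acceptsType` is `private[sql]`):

```scala
package org.apache.spark.sql.types

// Hypothetical sketch: the ObjectType companion now accepts any concrete
// ObjectType, so expressions such as UnwrapOption that declare ObjectType
// as an expected input type pass their input-type check.
object AcceptsTypeSketch extends App {
  assert(ObjectType.acceptsType(ObjectType(classOf[String]))) // was hard-coded to false before the fix
  assert(!ObjectType.acceptsType(IntegerType)) // non-object types are still rejected
}
```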
@@ -60,7 +60,18 @@ trait AnalysisTest extends PlanTest {
       inputPlan: LogicalPlan,
       caseSensitive: Boolean = true): Unit = {
     val analyzer = getAnalyzer(caseSensitive)
-    analyzer.checkAnalysis(analyzer.execute(inputPlan))
+    val analysisAttempt = analyzer.execute(inputPlan)
+    try analyzer.checkAnalysis(analysisAttempt) catch {
+      case a: AnalysisException =>
+        fail(
+          s"""
+             |Failed to Analyze Plan
+             |$inputPlan
+             |
+             |Partial Analysis
+             |$analysisAttempt
+           """.stripMargin, a)
+    }
   }
 
   protected def assertAnalysisError(
@@ -23,12 +23,14 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -74,7 +76,7 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
-class ExpressionEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
   implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
@@ -305,6 +307,15 @@ class ExpressionEncoderSuite extends SparkFunSuite {
            """.stripMargin, e)
       }
 
+      // Test the correct resolution of serialization / deserialization.
+      val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
+      val inputPlan = LocalRelation(attr)
+      val plan =
+        Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+          Project(encoder.namedExpressions,
+            inputPlan))
+      assertAnalysisSuccess(plan)
+
       val isCorrect = (input, convertedBack) match {
         case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
         case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)
@@ -613,6 +613,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       " - Input schema: struct<a:string,b:int>\n" +
       " - Target schema: struct<_1:string>")
   }
 
+  test("SPARK-13440: Resolving option fields") {
+    val df = Seq(1, 2, 3).toDS()
+    val ds = df.as[Option[Int]]
+    checkAnswer(
+      ds.filter(_ => true),
+      Some(1), Some(2), Some(3))
+  }
 }
 
 class OuterClass extends Serializable {