Skip to content
Snippets Groups Projects
Commit 86ea64dd authored by Nong Li's avatar Nong Li Committed by Michael Armbrust
Browse files

[SPARK-12271][SQL] Improve error message when Dataset.as[ ] has incompatible schemas.

Author: Nong Li <nong@databricks.com>

Closes #10260 from nongli/spark-11271.
parent b24c12d7
No related branches found
No related tags found
No related merge requests found
...@@ -184,7 +184,7 @@ object ScalaReflection extends ScalaReflection {
      val TypeRef(_, _, Seq(optType)) = t
      val className = getClassNameFromType(optType)
      val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
      WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
    case t if t <:< localTypeOf[java.lang.Integer] =>
      val boxedType = classOf[java.lang.Integer]
......
...@@ -251,6 +251,7 @@ case class ExpressionEncoder[T](
    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
    val analyzedPlan = SimpleAnalyzer.execute(plan)
    SimpleAnalyzer.checkAnalysis(analyzedPlan)
    val optimizedPlan = SimplifyCasts(analyzedPlan)
    // In order to construct instances of inner classes (for example those declared in a REPL cell),
......
...@@ -23,11 +23,9 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkConf
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
...@@ -295,13 +293,17 @@ case class UnwrapOption(
 * Converts the result of evaluating `child` into an option, checking both the isNull bit and
 * (in the case of reference types) equality with null.
 * @param child The expression to evaluate and wrap.
 * @param optType The type of this option.
 */
case class WrapOption(child: Expression, optType: DataType)
  extends UnaryExpression with ExpectsInputTypes {

  override def dataType: DataType = ObjectType(classOf[Option[_]])

  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = optType :: Nil

  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
......
...@@ -481,10 +481,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
    val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
    assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
  }
}
// Converting a Dataset[ClassData] to ClassData2 must fail analysis with a
// message that names the unresolvable column, not an opaque internal error.
test("verify mismatching field names fail with a good error") {
  val ds = Seq(ClassData("a", 1)).toDS()
  val caught = intercept[AnalysisException] {
    ds.as[ClassData2].collect()
  }
  // Include the full message in the failure output for easier debugging.
  assert(caught.getMessage.contains("cannot resolve 'c' given input columns a, b"), caught.getMessage)
}
}
case class ClassData(a: String, b: Int)
case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)

/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment