Skip to content
Snippets Groups Projects
Commit 3b9d2a34 authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-11819][SQL] nice error message for missing encoder

Before this PR, when users tried to get an encoder for an unsupported class, they only got a very simple error message like `Encoder for type xxx is not supported`.

After this PR, the error message becomes more friendly, for example:
```
No Encoder found for abc.xyz.NonEncodable
- array element class: "abc.xyz.NonEncodable"
- field (class: "scala.Array", name: "arrayField")
- root class: "abc.xyz.AnotherClass"
```

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9810 from cloud-fan/error-message.
parent 60bfb113
No related branches found
No related tags found
No related merge requests found
......@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
case _ =>
val className: String = tpe.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(tpe)
className match {
case "scala.Array" =>
val TypeRef(_, _, Seq(elementType)) = tpe
......@@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection {
}
}
/** Returns expressions for extracting all the fields from the given type. */
/**
* Returns expressions for extracting all the fields from the given type.
*
* If the given type is not supported, i.e. no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with a detailed error message explaining
* the type path walked so far and which class is not supported.
* There are 4 kinds of type path:
* * the root type: `root class: "abc.xyz.MyClass"`
* * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
extractorFor(inputObject, localTypeOf[T]) match {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
extractorFor(inputObject, tpe, walkedTypePath) match {
case s: CreateNamedStruct => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
......@@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection {
/** Helper for extracting internal fields from a case class. */
private def extractorFor(
inputObject: Expression,
tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
tpe: `Type`,
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
// Builds the expression that turns `input` into a catalyst array whose elements are
// described by `elementType`. Defined locally so it can extend `walkedTypePath` with an
// "array element class" entry before recursing into `extractorFor`.
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
// `silentSchemaFor` falls back to `NullType` instead of throwing for an un-supported element
// type; presumably this lets the failure surface from `extractorFor` below, where the walked
// type path is included in the message -- TODO confirm `isNativeType(NullType)` is false.
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
if (isNativeType(catalystType)) {
// Native element types: wrap the input directly in a `GenericArrayData` instance.
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(catalystType, nullable))
} else {
val clsName = getClassNameFromType(elementType)
// Prepend this step so the most recently walked type comes first in the path.
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
// `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
// to trigger the type check.
// The result of this eager call is discarded; only its potential exception matters.
extractorFor(inputObject, elementType, newPath)
MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
}
}
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
......@@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection {
// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
val className: String = optType.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(optType)
val classObj = Utils.classForName(className)
val optionObjectType = ObjectType(classObj)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath
val unwrapped = UnwrapOption(optionObjectType, inputObject)
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, schemaFor(optType).dataType),
extractorFor(unwrapped, optType))
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
extractorFor(unwrapped, optType, newPath))
}
case t if t <:< localTypeOf[Product] =>
......@@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection {
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
})
case t if t <:< localTypeOf[Array[_]] =>
......@@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection {
Invoke(inputObject, "booleanValue", BooleanType)
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
}
}
// Builds the expression that turns `input` into a catalyst array whose elements are
// described by `elementType`.
private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
// `schemaFor` supplies both the catalyst element type and whether elements may be null.
val Schema(catalystType, nullable) = schemaFor(elementType)
if (isNativeType(catalystType)) {
// Native element types: wrap the input directly in a `GenericArrayData` instance.
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(catalystType, nullable))
} else {
// Otherwise each element is converted through `extractorFor` via `MapObjects`.
MapObjects(extractorFor(_, elementType), input, externalDataType)
}
}
}
/**
......@@ -561,7 +588,7 @@ trait ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(tpe)
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
......@@ -637,6 +664,23 @@ trait ScalaReflection {
}
}
/**
 * Non-throwing variant of `schemaFor`.
 *
 * Delegates to `schemaFor` and returns its result. If `schemaFor` rejects the type with an
 * [[UnsupportedOperationException]], a nullable `NullType` schema is returned instead of
 * propagating the exception. Any other exception is left untouched.
 */
private def silentSchemaFor(tpe: `Type`): Schema = {
  try {
    schemaFor(tpe)
  } catch {
    case _: UnsupportedOperationException =>
      Schema(NullType, nullable = true)
  }
}
/**
 * Looks up the fully-qualified name of the class behind `tpe`. The type is erased first,
 * so type arguments do not appear in the result (e.g. an `Option[...]` yields `scala.Option`).
 */
private def getClassNameFromType(tpe: `Type`): String = {
  val erasedSymbol = tpe.erasure.typeSymbol
  erasedSymbol.asClass.fullName
}
/**
* Returns classes of input parameters of scala function object.
*/
......
......@@ -17,9 +17,22 @@
package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
// Fixtures for the "nice error message for missing encoder" test below.

// A plain (non-case) class; the test expects encoder creation to fail wherever it appears.
class NonEncodable(i: Int)
// Un-encodable type as a direct field.
case class ComplexNonEncodable1(name1: NonEncodable)
// Un-encodable type nested one level deeper, inside another case class.
case class ComplexNonEncodable2(name2: ComplexNonEncodable1)
// Un-encodable type as the value of an Option field.
case class ComplexNonEncodable3(name3: Option[NonEncodable])
// Un-encodable type as the element of an Array field.
case class ComplexNonEncodable4(name4: Array[NonEncodable])
// Un-encodable type as the element of an Array wrapped in an Option.
case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])
class EncoderErrorMessageSuite extends SparkFunSuite {
......@@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] }
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
}
test("nice error message for missing encoder") {
  // Each case builds an encoder for a class that (transitively) contains `NonEncodable` and
  // checks that the exception message lists every step of the walked type path.

  // Direct field of an un-encodable type.
  val errorMsg1 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
  assert(errorMsg1.contains(
    s"""root class: "${clsName[ComplexNonEncodable1]}""""))
  assert(errorMsg1.contains(
    s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

  // Un-encodable type nested inside another case class: both field steps must be reported.
  val errorMsg2 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
  assert(errorMsg2.contains(
    s"""root class: "${clsName[ComplexNonEncodable2]}""""))
  assert(errorMsg2.contains(
    s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
  // This assertion previously inspected `errorMsg1`, so the inner field step of the
  // two-level case was never actually verified.
  assert(errorMsg2.contains(
    s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

  // Un-encodable type as an Option value.
  val errorMsg3 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
  assert(errorMsg3.contains(
    s"""root class: "${clsName[ComplexNonEncodable3]}""""))
  assert(errorMsg3.contains(
    s"""field (class: "scala.Option", name: "name3")"""))
  assert(errorMsg3.contains(
    s"""option value class: "${clsName[NonEncodable]}""""))

  // Un-encodable type as an Array element.
  val errorMsg4 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
  assert(errorMsg4.contains(
    s"""root class: "${clsName[ComplexNonEncodable4]}""""))
  assert(errorMsg4.contains(
    s"""field (class: "scala.Array", name: "name4")"""))
  assert(errorMsg4.contains(
    s"""array element class: "${clsName[NonEncodable]}""""))

  // Un-encodable type as the element of an Array wrapped in an Option.
  val errorMsg5 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
  assert(errorMsg5.contains(
    s"""root class: "${clsName[ComplexNonEncodable5]}""""))
  assert(errorMsg5.contains(
    s"""field (class: "scala.Option", name: "name5")"""))
  assert(errorMsg5.contains(
    s"""option value class: "scala.Array""""))
  assert(errorMsg5.contains(
    s"""array element class: "${clsName[NonEncodable]}""""))
}
/** Returns the JVM runtime class name for `T`, as reported by its implicit [[ClassTag]]. */
private def clsName[T : ClassTag]: String = {
  val tag = implicitly[ClassTag[T]]
  tag.runtimeClass.getName
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment