From ccba622e35741d8344ec8d74b6750529b2c7219b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro <yamamuro@apache.org> Date: Sat, 18 Mar 2017 14:40:16 +0800 Subject: [PATCH] [SPARK-19896][SQL] Throw an exception if case classes have circular references in toDS ## What changes were proposed in this pull request? If case classes have circular references below, it throws StackOverflowError; ``` scala> :pasge case class classA(i: Int, cls: classB) case class classB(cls: classA) scala> Seq(classA(0, null)).toDS() java.lang.StackOverflowError at scala.reflect.internal.Symbols$Symbol.info(Symbols.scala:1494) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.scala$reflect$runtime$SynchronizedSymbols$SynchronizedSymbol$$super$info(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.Gil$class.gilSynchronized(Gil.scala:19) at scala.reflect.runtime.JavaUniverse.gilSynchronized(JavaUniverse.scala:16) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.gilSynchronizedIfNotThreadsafe(SynchronizedSymbols.scala:123) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.gilSynchronizedIfNotThreadsafe(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.info(SynchronizedSymbols.scala:127) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.info(JavaMirrors.scala:66) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:48) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) ``` This pr added code to throw UnsupportedOperationException in that case as follows; ``` scala> :paste case class A(cls: B) case class B(cls: A) scala> Seq(A(null)).toDS() java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class B at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:627) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:644) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:632) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241) ``` ## How was this patch tested? Added tests in `DatasetSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #17318 from maropu/SPARK-19896. --- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f7dd51aa2..c4af284f73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -470,14 +470,15 @@ object ScalaReflection extends ScalaReflection { private def serializerFor( inputObject: Expression, tpe: `Type`, - walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + walkedTypePath: Seq[String], + seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, dt) + MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt) case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => @@ -511,7 +512,7 @@ object ScalaReflection extends ScalaReflection { val className = getClassNameFromType(optType) val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) + serializerFor(unwrapped, optType, newPath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -534,9 +535,9 @@ object ScalaReflection extends ScalaReflection { ExternalMapToCatalyst( inputObject, dataTypeFor(keyType), - serializerFor(_, keyType, keyPath), + serializerFor(_, keyType, keyPath, seenTypeSet), dataTypeFor(valueType), - serializerFor(_, valueType, valuePath), + serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) case t if t <:< localTypeOf[String] => @@ -622,6 +623,11 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => + if (seenTypeSet.contains(t)) { + throw new UnsupportedOperationException( + s"cannot have circular references in class, but got the circular reference of class $t") + } + val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { @@ -634,7 +640,8 @@ object ScalaReflection extends ScalaReflection { returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: + serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) @@ -643,7 +650,6 @@ object ScalaReflection extends ScalaReflection { throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } - } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b37bf131e8..6417e7a8b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1136,6 +1136,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) } + + test("SPARK-19896: cannot have circular references in in case class") { + val errMsg1 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassA(null)).toDS + } + assert(errMsg1.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg2 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassC(null)).toDS + } + assert(errMsg2.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg3 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassD(null)).toDS + } + assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) @@ -1214,3 +1232,9 @@ object DatasetTransform { case class Route(src: String, dest: String, cost: Int) case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) + +case class CircularReferenceClassA(cls: CircularReferenceClassB) +case class CircularReferenceClassB(cls: CircularReferenceClassA) +case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) +case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) +case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) -- GitLab