diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 893af5146c5b3f2b4e187b858fdeffb88cb83431..83cb3755258324f2e9193a7c6fd8be70dec3f1f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // this child in all children. case (name, value: TreeNode[_]) if containsChild(value) => name -> JInt(children.indexOf(value)) - case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) => + case (name, value: Seq[BaseType]) if value.forall(containsChild) => name -> JArray( value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList ) @@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming // it to JSON may trigger OutOfMemoryError. case m: Metadata => Metadata.empty.jsonValue + case clazz: Class[_] => JString(clazz.getName) case s: StorageLevel => ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) case n: TreeNode[_] => n.jsonValue case o: Option[_] => o.map(parseToJson) - case t: Seq[_] => JArray(t.map(parseToJson).toList) - case m: Map[_, _] => - val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) } - JObject(fields) - case r: RDD[_] => JNothing + // Recursive scan Seq[TreeNode], Seq[Partitioning], Seq[DataType] + case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) || + t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => + JArray(t.map(parseToJson).toList) + case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => + JString(Utils.truncatedString(t, "[", ", ", "]")) + case t: Seq[_] => JNull + case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. // TODO: currently if the class name ends with "$", we think it's a scala object, there is // probably a better way to check it. case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName - // returns null if the product type doesn't have a primary constructor, e.g. 
HiveFunctionWrapper - case p: Product => try { - val fieldNames = getConstructorParameterNames(p.getClass) - val fieldValues = p.productIterator.toSeq - assert(fieldNames.length == fieldValues.length) - ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { - case (name, value) => name -> parseToJson(value) - }.toList - } catch { - case _: RuntimeException => null - } - case _ => JNull - } -} - -object TreeNode { - def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = { - val jsonAST = parse(json) - assert(jsonAST.isInstanceOf[JArray]) - reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType] - } - - private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = { - assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject])) - val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*) - - def parseNextNode(): TreeNode[_] = { - val nextNode = jsonNodes.pop() - - val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s) - if (cls == classOf[Literal]) { - Literal.fromJSON(nextNode) - } else if (cls.getName.endsWith("$")) { - cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]] - } else { - val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt - - val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode()) - val fields = getConstructorParameters(cls) - - val parameters: Array[AnyRef] = fields.map { - case (fieldName, fieldType) => - parseFromJson(nextNode \ fieldName, fieldType, children, sc) - }.toArray - - val maybeCtor = cls.getConstructors.find { p => - val expectedTypes = p.getParameterTypes - expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall { - case (cls, tpe) => cls == getClassFromType(tpe) - } - } - if (maybeCtor.isEmpty) { - sys.error(s"No valid constructor for ${cls.getName}") - } else { - try { - maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]] - } catch { - case e: java.lang.IllegalArgumentException => - throw new RuntimeException( - s""" - |Failed to construct tree node: ${cls.getName} - |ctor: ${maybeCtor.get} - |types: ${parameters.map(_.getClass).mkString(", ")} - |args: ${parameters.mkString(", ")} - """.stripMargin, e) - } - } - } - } - - parseNextNode() - } - - import universe._ - - private def parseFromJson( - value: JValue, - expectedType: Type, - children: Seq[TreeNode[_]], - sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized { - if (value == JNull) return null - - expectedType match { - case t if t <:< definitions.BooleanTpe => - value.asInstanceOf[JBool].value: java.lang.Boolean - case t if t <:< definitions.ByteTpe => - value.asInstanceOf[JInt].num.toByte: java.lang.Byte - case t if t <:< definitions.ShortTpe => - value.asInstanceOf[JInt].num.toShort: java.lang.Short - case t if t <:< definitions.IntTpe => - value.asInstanceOf[JInt].num.toInt: java.lang.Integer - case t if t <:< definitions.LongTpe => - value.asInstanceOf[JInt].num.toLong: java.lang.Long - case t if t <:< definitions.FloatTpe => - value.asInstanceOf[JDouble].num.toFloat: java.lang.Float - case t if t <:< definitions.DoubleTpe => - value.asInstanceOf[JDouble].num: java.lang.Double - - case t if t <:< localTypeOf[java.lang.Boolean] => - value.asInstanceOf[JBool].value: java.lang.Boolean - case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num - case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s - case t if t <:< localTypeOf[UUID] => 
UUID.fromString(value.asInstanceOf[JString].s) - case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value) - case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject]) - case t if t <:< localTypeOf[StorageLevel] => - val JBool(useDisk) = value \ "useDisk" - val JBool(useMemory) = value \ "useMemory" - val JBool(useOffHeap) = value \ "useOffHeap" - val JBool(deserialized) = value \ "deserialized" - val JInt(replication) = value \ "replication" - StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt) - case t if t <:< localTypeOf[TreeNode[_]] => value match { - case JInt(i) => children(i.toInt) - case arr: JArray => reconstruct(arr, sc) - case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.") + case p: Product if shouldConvertToJson(p) => + try { + val fieldNames = getConstructorParameterNames(p.getClass) + val fieldValues = p.productIterator.toSeq + assert(fieldNames.length == fieldValues.length) + ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { + case (name, value) => name -> parseToJson(value) + }.toList + } catch { + case _: RuntimeException => null } - case t if t <:< localTypeOf[Option[_]] => - if (value == JNothing) { - None - } else { - val TypeRef(_, _, Seq(optType)) = t - Option(parseFromJson(value, optType, children, sc)) - } - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val JArray(elements) = value - elements.map(parseFromJson(_, elementType, children, sc)).toSeq - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val JObject(fields) = value - fields.map { - case (name, value) => name -> parseFromJson(value, valueType, children, sc) - }.toMap - case t if t <:< localTypeOf[RDD[_]] => - new EmptyRDD[Any](sc) - case _ if isScalaObject(value) => - val JString(clsName) = value \ "object" - val cls = Utils.classForName(clsName) - cls.getField("MODULE$").get(cls) - case t if t <:< localTypeOf[Product] => - val fields = getConstructorParameters(t) - val clsName = getClassNameFromType(t) - parseToProduct(clsName, fields, value, children, sc) - // There maybe some cases that the parameter type signature is not Product but the value is, - // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here. 
- case _ if isScalaProduct(value) => - val JString(clsName) = value \ "product-class" - val fields = getConstructorParameters(Utils.classForName(clsName)) - parseToProduct(clsName, fields, value, children, sc) - case _ => sys.error(s"Do not support type $expectedType with json $value.") - } - } - - private def parseToProduct( - clsName: String, - fields: Seq[(String, Type)], - value: JValue, - children: Seq[TreeNode[_]], - sc: SparkContext): AnyRef = { - val parameters: Array[AnyRef] = fields.map { - case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc) - }.toArray - val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size) - ctor.newInstance(parameters: _*).asInstanceOf[AnyRef] - } - - private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match { - case JString(str) if str.endsWith("$") => true - case _ => false + case _ => JNull } - private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match { - case _: JString => true + private def shouldConvertToJson(product: Product): Boolean = product match { + case exprId: ExprId => true + case field: StructField => true + case id: TableIdentifier => true + case join: JoinType => true + case id: FunctionIdentifier => true + case spec: BucketSpec => true + case catalog: CatalogTable => true + case boundary: FrameBoundary => true + case frame: WindowFrame => true + case partition: Partitioning => true + case resource: FunctionResource => true + case broadcast: BroadcastMode => true + case table: CatalogTableType => true + case storage: CatalogStorageFormat => true case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6246380dbeb9b9cc8b436984ccd22a5dcea79c2a..cb0426c7a98a18fbe7d01efa6d11cb32267e2be7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -17,13 +17,29 @@ package org.apache.spark.sql.catalyst.trees +import java.math.BigInteger +import java.util.UUID + import scala.collection.mutable.ArrayBuffer +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{IntegerType, NullType, StringType} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} +import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.storage.StorageLevel case class Dummy(optKey: 
Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq @@ -45,6 +61,20 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Expression with override lazy val resolved = true } +case class JsonTestTreeNode(arg: Any) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty[Attribute] +} + +case class NameValue(name: String, value: Any) + +case object DummyObject + +case class SelfReferenceUDF( + var config: Map[String, Any] = Map.empty[String, Any]) extends Function1[String, Boolean] { + config += "self" -> this + def apply(key: String): Boolean = config.contains(key) +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -261,4 +291,264 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === expected) } } + + test("toJSON") { + def assertJSON(input: Any, json: JValue): Unit = { + val expected = + s""" + |[{ + | "class": "${classOf[JsonTestTreeNode].getName}", + | "num-children": 0, + | "arg": ${compact(render(json))} + |}] + """.stripMargin + compareJSON(JsonTestTreeNode(input).toJSON, expected) + } + + // Converts simple types to JSON + assertJSON(true, true) + assertJSON(33.toByte, 33) + assertJSON(44, 44) + assertJSON(55L, 55L) + assertJSON(3.0, 3.0) + assertJSON(4.0D, 4.0D) + assertJSON(BigInt(BigInteger.valueOf(88L)), 88L) + assertJSON(null, JNull) + assertJSON("text", "text") + assertJSON(Some("text"), "text") + compareJSON(JsonTestTreeNode(None).toJSON, + s"""[ + | { + | "class": "${classOf[JsonTestTreeNode].getName}", + | "num-children": 0 + | } + |] + """.stripMargin) + + val uuid = UUID.randomUUID() + assertJSON(uuid, uuid.toString) + + // Converts Spark Sql DataType to JSON + assertJSON(IntegerType, "integer") + assertJSON(Metadata.empty, JObject(Nil)) + assertJSON( + StorageLevel.NONE, + JObject( + "useDisk" -> false, + "useMemory" -> false, + "useOffHeap" -> false, + "deserialized" -> false, + "replication" -> 1) + ) + + // Converts TreeNode argument to JSON + assertJSON( + Literal(333), + List( + JObject( + "class" -> classOf[Literal].getName, + "num-children" -> 0, + "value" -> "333", + "dataType" -> "integer"))) + + // Converts Seq[String] to JSON + assertJSON(Seq("1", "2", "3"), "[1, 2, 3]") + + // Converts Seq[DataType] to JSON + assertJSON(Seq(IntegerType, DoubleType, FloatType), List("integer", "double", "float")) + + // Converts Seq[Partitioning] to JSON + assertJSON( + Seq(SinglePartition, RoundRobinPartitioning(numPartitions = 3)), + List( + JObject("object" -> JString(SinglePartition.getClass.getName)), + JObject( + "product-class" -> classOf[RoundRobinPartitioning].getName, + "numPartitions" -> 3))) + + // Converts case object to JSON + assertJSON(DummyObject, JObject("object" -> JString(DummyObject.getClass.getName))) + + // Converts ExprId to JSON + assertJSON( + ExprId(0, uuid), + JObject( + "product-class" -> classOf[ExprId].getName, + "id" -> 0, + "jvmId" -> uuid.toString)) + + // Converts StructField to JSON + assertJSON( + StructField("field", IntegerType), + JObject( + "product-class" -> classOf[StructField].getName, + "name" -> "field", + "dataType" -> "integer", + "nullable" -> true, + "metadata" -> JObject(Nil))) + + // Converts TableIdentifier to JSON + assertJSON( + TableIdentifier("table"), + JObject( + "product-class" -> classOf[TableIdentifier].getName, + "table" -> "table")) + + // Converts JoinType to JSON + assertJSON( + NaturalJoin(LeftOuter), + JObject( + 
"product-class" -> classOf[NaturalJoin].getName, + "tpe" -> JObject("object" -> JString(LeftOuter.getClass.getName)))) + + // Converts FunctionIdentifier to JSON + assertJSON( + FunctionIdentifier("function", None), + JObject( + "product-class" -> JString(classOf[FunctionIdentifier].getName), + "funcName" -> "function")) + + // Converts BucketSpec to JSON + assertJSON( + BucketSpec(1, Seq("bucket"), Seq("sort")), + JObject( + "product-class" -> classOf[BucketSpec].getName, + "numBuckets" -> 1, + "bucketColumnNames" -> "[bucket]", + "sortColumnNames" -> "[sort]")) + + // Converts FrameBoundary to JSON + assertJSON( + ValueFollowing(3), + JObject( + "product-class" -> classOf[ValueFollowing].getName, + "value" -> 3)) + + // Converts WindowFrame to JSON + assertJSON( + SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow), + JObject( + "product-class" -> classOf[SpecifiedWindowFrame].getName, + "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), + "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)), + "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName)))) + + // Converts Partitioning to JSON + assertJSON( + RoundRobinPartitioning(numPartitions = 3), + JObject( + "product-class" -> classOf[RoundRobinPartitioning].getName, + "numPartitions" -> 3)) + + // Converts FunctionResource to JSON + assertJSON( + FunctionResource(JarResource, "file:///"), + JObject( + "product-class" -> JString(classOf[FunctionResource].getName), + "resourceType" -> JObject("object" -> JString(JarResource.getClass.getName)), + "uri" -> "file:///")) + + // Converts BroadcastMode to JSON + assertJSON( + IdentityBroadcastMode, + JObject("object" -> JString(IdentityBroadcastMode.getClass.getName))) + + // Converts CatalogTable to JSON + assertJSON( + CatalogTable( + TableIdentifier("table"), + CatalogTableType.MANAGED, + CatalogStorageFormat.empty, + StructType(StructField("a", IntegerType, true) :: Nil), + createTime = 0L), + + JObject( + "product-class" -> classOf[CatalogTable].getName, + "identifier" -> JObject( + "product-class" -> classOf[TableIdentifier].getName, + "table" -> "table" + ), + "tableType" -> JObject( + "product-class" -> classOf[CatalogTableType].getName, + "name" -> "MANAGED" + ), + "storage" -> JObject( + "product-class" -> classOf[CatalogStorageFormat].getName, + "compressed" -> false, + "properties" -> JNull + ), + "schema" -> JObject( + "type" -> "struct", + "fields" -> List( + JObject( + "name" -> "a", + "type" -> "integer", + "nullable" -> true, + "metadata" -> JObject(Nil)))), + "partitionColumnNames" -> List.empty[String], + "owner" -> "", + "createTime" -> 0, + "lastAccessTime" -> -1, + "properties" -> JNull, + "unsupportedFeatures" -> List.empty[String])) + + // For unknown case class, returns JNull. 
+ val bigValue = new Array[Int](10000) + assertJSON(NameValue("name", bigValue), JNull) + + // Converts Seq[TreeNode] to JSON recursively + assertJSON( + Seq(Literal(1), Literal(2)), + List( + List( + JObject( + "class" -> JString(classOf[Literal].getName), + "num-children" -> 0, + "value" -> "1", + "dataType" -> "integer")), + List( + JObject( + "class" -> JString(classOf[Literal].getName), + "num-children" -> 0, + "value" -> "2", + "dataType" -> "integer")))) + + // Other Seq is converted to JNull, to reduce the risk of out of memory + assertJSON(Seq(1, 2, 3), JNull) + + // All Map type is converted to JNull, to reduce the risk of out of memory + assertJSON(Map("key" -> "value"), JNull) + + // Unknown type is converted to JNull, to reduce the risk of out of memory + assertJSON(new Object {}, JNull) + + // Convert all TreeNode children to JSON + assertJSON( + Union(Seq(JsonTestTreeNode("0"), JsonTestTreeNode("1"))), + List( + JObject( + "class" -> classOf[Union].getName, + "num-children" -> 2, + "children" -> List(0, 1)), + JObject( + "class" -> classOf[JsonTestTreeNode].getName, + "num-children" -> 0, + "arg" -> "0"), + JObject( + "class" -> classOf[JsonTestTreeNode].getName, + "num-children" -> 0, + "arg" -> "1"))) + } + + test("toJSON should not throws java.lang.StackOverflowError") { + val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr)) + // Should not throw java.lang.StackOverflowError + udf.toJSON + } + + private def compareJSON(leftJson: String, rightJson: String): Unit = { + val left = JsonMethods.parse(leftJson) + val right = JsonMethods.parse(rightJson) + assert(left == right) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d361f61764d1fad32a609389def30000cf36d8a6..34fa626e00e319dba644fe390dd2ce9d6cb0e840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -120,7 +120,6 @@ abstract class QueryTest extends PlanTest { throw ae } } - checkJsonFormat(analyzedDS) assertEmptyMissingInput(analyzedDS) try ds.collect() catch { @@ -168,8 +167,6 @@ abstract class QueryTest extends PlanTest { } } - checkJsonFormat(analyzedDF) - assertEmptyMissingInput(analyzedDF) QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { @@ -228,139 +225,6 @@ abstract class QueryTest extends PlanTest { planWithCaching) } - private def checkJsonFormat(ds: Dataset[_]): Unit = { - // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that - // RDD and Data resolution does not break. - val logicalPlan = ds.queryExecution.analyzed - - // bypass some cases that we can't handle currently. - logicalPlan.transform { - case _: ObjectConsumer => return - case _: ObjectProducer => return - case _: AppendColumns => return - case _: TypedFilter => return - case _: LogicalRelation => return - case p if p.getClass.getSimpleName == "MetastoreRelation" => return - case _: MemoryPlan => return - case p: InMemoryRelation => - p.child.transform { - case _: ObjectConsumerExec => return - case _: ObjectProducerExec => return - } - p - }.transformAllExpressions { - case _: ImperativeAggregate => return - case _: TypedAggregateExpression => return - case Literal(_, _: ObjectType) => return - case _: UserDefinedGenerator => return - } - - // bypass hive tests before we fix all corner cases in hive module. 
- if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return - - val jsonString = try { - logicalPlan.toJSON - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to parse logical plan to JSON: - |${logicalPlan.treeString} - """.stripMargin, e) - } - - // scala function is not serializable to JSON, use null to replace them so that we can compare - // the plans later. - val normalized1 = logicalPlan.transformAllExpressions { - case udf: ScalaUDF => udf.copy(function = null) - case gen: UserDefinedGenerator => gen.copy(function = null) - // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove - // the Metadata from the normalized plan so that we can compare this plan with the - // JSON-deserialzed plan. - case a @ Alias(child, name) if a.explicitMetadata.isDefined => - Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated) - case a: AttributeReference if a.metadata != Metadata.empty => - AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier, - a.isGenerated) - } - - // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains - // these non-serializable stuff, and use these original ones to replace the null-placeholders - // in the logical plans parsed from JSON. - val logicalRDDs = new ArrayDeque[LogicalRDD]() - val localRelations = new ArrayDeque[LocalRelation]() - val inMemoryRelations = new ArrayDeque[InMemoryRelation]() - def collectData: (LogicalPlan => Unit) = { - case l: LogicalRDD => - logicalRDDs.offer(l) - case l: LocalRelation => - localRelations.offer(l) - case i: InMemoryRelation => - inMemoryRelations.offer(i) - case p => - p.expressions.foreach { - _.foreach { - case s: SubqueryExpression => - s.plan.foreach(collectData) - case _ => - } - } - } - logicalPlan.foreach(collectData) - - - val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to rebuild the logical plan from JSON: - |${logicalPlan.treeString} - | - |${logicalPlan.prettyJson} - """.stripMargin, e) - } - - def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { - case l: LogicalRDD => - val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(spark) - case l: LocalRelation => - val origin = localRelations.pop() - l.copy(data = origin.data) - case l: InMemoryRelation => - val origin = inMemoryRelations.pop() - InMemoryRelation( - l.output, - l.useCompression, - l.batchSize, - l.storageLevel, - origin.child, - l.tableName)( - origin.cachedColumnBuffers, - origin.batchStats) - case p => - p.transformExpressions { - case s: SubqueryExpression => - s.withNewPlan(s.plan.transformDown(renormalize)) - } - } - val normalized2 = jsonBackPlan.transformDown(renormalize) - - assert(logicalRDDs.isEmpty) - assert(localRelations.isEmpty) - assert(inMemoryRelations.isEmpty) - - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: the logical plan parsed from json does not match the original one === - |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - /** * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. */
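
The patch above replaces TreeNode's best-effort JSON serialization with a whitelist: only the Product types enumerated in shouldConvertToJson are expanded field by field, a Seq argument is only walked when all of its elements are TreeNode, Partitioning or DataType (a Seq[String] is truncated to a single string), and every other Seq, Map or unknown value collapses to JNull; the fromJSON reconstruction path and QueryTest's checkJsonFormat round-trip are removed rather than adapted. As a rough illustration of that strategy, here is a minimal stand-alone sketch built on json4s; Node, Meta, SafeJson and the reflection-based field naming are hypothetical names introduced only for this example, not Spark APIs (Spark uses getConstructorParameterNames from ScalaReflection instead).

```scala
// Minimal stand-alone sketch (not Spark code) of the whitelist strategy used above:
// recurse only into argument types we explicitly trust, and map everything else
// (large collections, arbitrary Maps, unknown classes) to JNull, so that building
// the JSON can trigger neither OutOfMemoryError nor StackOverflowError.
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods.{compact, render}

case class Meta(name: String)                     // a whitelisted product type
case class Node(value: Int, children: Seq[Node])  // a tree we are willing to expand fully

object SafeJson {
  // Only these product classes are expanded field by field; everything else is dropped.
  private def shouldConvertToJson(p: Product): Boolean = p match {
    case _: Meta => true
    case _: Node => true
    case _       => false
  }

  def toJson(value: Any): JValue = value match {
    case null          => JNull
    case b: Boolean    => JBool(b)
    case i: Int        => JInt(BigInt(i))
    case l: Long       => JInt(BigInt(l))
    case d: Double     => JDouble(d)
    case s: String     => JString(s)
    case o: Option[_]  => o.map(toJson).getOrElse(JNothing)
    // Recurse into a Seq only when every element is itself a whitelisted product.
    case seq: Seq[_] if seq.forall {
        case p: Product => shouldConvertToJson(p)
        case _          => false
      } => JArray(seq.map(toJson).toList)
    case _: Seq[_]     => JNull   // e.g. a huge numeric buffer: never materialize it as JSON
    case _: Map[_, _]  => JNull   // maps may hold self-references (see the SelfReferenceUDF test)
    case p: Product if shouldConvertToJson(p) =>
      // Rough stand-in for Spark's getConstructorParameterNames: for simple case
      // classes the declared fields line up with the primary constructor parameters.
      val names  = p.getClass.getDeclaredFields.map(_.getName).take(p.productArity).toList
      val fields = names.zip(p.productIterator.toList).map { case (n, v) => (n, toJson(v)) }
      JObject(("product-class", JString(p.getClass.getName)) :: fields)
    case _             => JNull   // unknown type: dropping it is safer than guessing
  }
}

object SafeJsonDemo extends App {
  val tree = Node(1, Seq(Node(2, Nil), Node(3, Nil)))
  println(compact(render(SafeJson.toJson(tree))))
  // A self-referential or oversized argument degrades to null instead of failing:
  println(compact(render(SafeJson.toJson(Map("self" -> SafeJsonDemo)))))
}
```

The trade-off mirrors the one in the patch: the JSON output becomes lossy for non-whitelisted arguments (they serialize as null), which is why the fromJSON reconstruction code in TreeNode and the checkJsonFormat round-trip in QueryTest are deleted instead of being taught about every possible argument type.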