diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 893af5146c5b3f2b4e187b858fdeffb88cb83431..83cb3755258324f2e9193a7c6fd8be70dec3f1f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       // this child in all children.
       case (name, value: TreeNode[_]) if containsChild(value) =>
         name -> JInt(children.indexOf(value))
-      case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+      case (name, value: Seq[BaseType]) if value.forall(containsChild) =>
         name -> JArray(
           value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
         )
@@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     // SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming
     // it to JSON may trigger OutOfMemoryError.
     case m: Metadata => Metadata.empty.jsonValue
+    case clazz: Class[_] => JString(clazz.getName)
     case s: StorageLevel =>
       ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
         ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
     case n: TreeNode[_] => n.jsonValue
     case o: Option[_] => o.map(parseToJson)
-    case t: Seq[_] => JArray(t.map(parseToJson).toList)
-    case m: Map[_, _] =>
-      val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
-      JObject(fields)
-    case r: RDD[_] => JNothing
+    // Recursive scan Seq[TreeNode], Seq[Partitioning], Seq[DataType]
+    case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) ||
+      t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) =>
+      JArray(t.map(parseToJson).toList)
+    case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] =>
+      JString(Utils.truncatedString(t, "[", ", ", "]"))
+    case t: Seq[_] => JNull
+    case m: Map[_, _] => JNull
     // if it's a scala object, we can simply keep the full class path.
     // TODO: currently if the class name ends with "$", we think it's a scala object, there is
     // probably a better way to check it.
     case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
-    // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
-    case p: Product => try {
-      val fieldNames = getConstructorParameterNames(p.getClass)
-      val fieldValues = p.productIterator.toSeq
-      assert(fieldNames.length == fieldValues.length)
-      ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
-        case (name, value) => name -> parseToJson(value)
-      }.toList
-    } catch {
-      case _: RuntimeException => null
-    }
-    case _ => JNull
-  }
-}
-
-object TreeNode {
-  def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
-    val jsonAST = parse(json)
-    assert(jsonAST.isInstanceOf[JArray])
-    reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
-  }
-
-  private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
-    assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
-    val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
-
-    def parseNextNode(): TreeNode[_] = {
-      val nextNode = jsonNodes.pop()
-
-      val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
-      if (cls == classOf[Literal]) {
-        Literal.fromJSON(nextNode)
-      } else if (cls.getName.endsWith("$")) {
-        cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
-      } else {
-        val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
-
-        val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
-        val fields = getConstructorParameters(cls)
-
-        val parameters: Array[AnyRef] = fields.map {
-          case (fieldName, fieldType) =>
-            parseFromJson(nextNode \ fieldName, fieldType, children, sc)
-        }.toArray
-
-        val maybeCtor = cls.getConstructors.find { p =>
-          val expectedTypes = p.getParameterTypes
-          expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
-            case (cls, tpe) => cls == getClassFromType(tpe)
-          }
-        }
-        if (maybeCtor.isEmpty) {
-          sys.error(s"No valid constructor for ${cls.getName}")
-        } else {
-          try {
-            maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
-          } catch {
-            case e: java.lang.IllegalArgumentException =>
-              throw new RuntimeException(
-                s"""
-                  |Failed to construct tree node: ${cls.getName}
-                  |ctor: ${maybeCtor.get}
-                  |types: ${parameters.map(_.getClass).mkString(", ")}
-                  |args: ${parameters.mkString(", ")}
-                """.stripMargin, e)
-          }
-        }
-      }
-    }
-
-    parseNextNode()
-  }
-
-  import universe._
-
-  private def parseFromJson(
-      value: JValue,
-      expectedType: Type,
-      children: Seq[TreeNode[_]],
-      sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
-    if (value == JNull) return null
-
-    expectedType match {
-      case t if t <:< definitions.BooleanTpe =>
-        value.asInstanceOf[JBool].value: java.lang.Boolean
-      case t if t <:< definitions.ByteTpe =>
-        value.asInstanceOf[JInt].num.toByte: java.lang.Byte
-      case t if t <:< definitions.ShortTpe =>
-        value.asInstanceOf[JInt].num.toShort: java.lang.Short
-      case t if t <:< definitions.IntTpe =>
-        value.asInstanceOf[JInt].num.toInt: java.lang.Integer
-      case t if t <:< definitions.LongTpe =>
-        value.asInstanceOf[JInt].num.toLong: java.lang.Long
-      case t if t <:< definitions.FloatTpe =>
-        value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
-      case t if t <:< definitions.DoubleTpe =>
-        value.asInstanceOf[JDouble].num: java.lang.Double
-
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
-        value.asInstanceOf[JBool].value: java.lang.Boolean
-      case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
-      case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
-      case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
-      case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
-      case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
-      case t if t <:< localTypeOf[StorageLevel] =>
-        val JBool(useDisk) = value \ "useDisk"
-        val JBool(useMemory) = value \ "useMemory"
-        val JBool(useOffHeap) = value \ "useOffHeap"
-        val JBool(deserialized) = value \ "deserialized"
-        val JInt(replication) = value \ "replication"
-        StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
-      case t if t <:< localTypeOf[TreeNode[_]] => value match {
-        case JInt(i) => children(i.toInt)
-        case arr: JArray => reconstruct(arr, sc)
-        case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+    case p: Product if shouldConvertToJson(p) =>
+      try {
+        val fieldNames = getConstructorParameterNames(p.getClass)
+        val fieldValues = p.productIterator.toSeq
+        assert(fieldNames.length == fieldValues.length)
+        ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+          case (name, value) => name -> parseToJson(value)
+        }.toList
+      } catch {
+        case _: RuntimeException => null
       }
-      case t if t <:< localTypeOf[Option[_]] =>
-        if (value == JNothing) {
-          None
-        } else {
-          val TypeRef(_, _, Seq(optType)) = t
-          Option(parseFromJson(value, optType, children, sc))
-        }
-      case t if t <:< localTypeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
-        val JArray(elements) = value
-        elements.map(parseFromJson(_, elementType, children, sc)).toSeq
-      case t if t <:< localTypeOf[Map[_, _]] =>
-        val TypeRef(_, _, Seq(keyType, valueType)) = t
-        val JObject(fields) = value
-        fields.map {
-          case (name, value) => name -> parseFromJson(value, valueType, children, sc)
-        }.toMap
-      case t if t <:< localTypeOf[RDD[_]] =>
-        new EmptyRDD[Any](sc)
-      case _ if isScalaObject(value) =>
-        val JString(clsName) = value \ "object"
-        val cls = Utils.classForName(clsName)
-        cls.getField("MODULE$").get(cls)
-      case t if t <:< localTypeOf[Product] =>
-        val fields = getConstructorParameters(t)
-        val clsName = getClassNameFromType(t)
-        parseToProduct(clsName, fields, value, children, sc)
-      // There maybe some cases that the parameter type signature is not Product but the value is,
-      // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here.
-      case _ if isScalaProduct(value) =>
-        val JString(clsName) = value \ "product-class"
-        val fields = getConstructorParameters(Utils.classForName(clsName))
-        parseToProduct(clsName, fields, value, children, sc)
-      case _ => sys.error(s"Do not support type $expectedType with json $value.")
-    }
-  }
-
-  private def parseToProduct(
-      clsName: String,
-      fields: Seq[(String, Type)],
-      value: JValue,
-      children: Seq[TreeNode[_]],
-      sc: SparkContext): AnyRef = {
-    val parameters: Array[AnyRef] = fields.map {
-      case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
-    }.toArray
-    val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
-    ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
-  }
-
-  private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
-    case JString(str) if str.endsWith("$") => true
-    case _ => false
+    case _ => JNull
   }
 
-  private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
-    case _: JString => true
+  private def shouldConvertToJson(product: Product): Boolean = product match {
+    case exprId: ExprId => true
+    case field: StructField => true
+    case id: TableIdentifier => true
+    case join: JoinType => true
+    case id: FunctionIdentifier => true
+    case spec: BucketSpec => true
+    case catalog: CatalogTable => true
+    case boundary: FrameBoundary => true
+    case frame: WindowFrame => true
+    case partition: Partitioning => true
+    case resource: FunctionResource => true
+    case broadcast: BroadcastMode => true
+    case table: CatalogTableType => true
+    case storage: CatalogStorageFormat => true
     case _ => false
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6246380dbeb9b9cc8b436984ccd22a5dcea79c2a..cb0426c7a98a18fbe7d01efa6d11cb32267e2be7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -17,13 +17,29 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import java.math.BigInteger
+import java.util.UUID
+
 import scala.collection.mutable.ArrayBuffer
 
+import org.json4s.jackson.JsonMethods
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{IntegerType, NullType, StringType}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
+import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
 
 case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
   override def children: Seq[Expression] = optKey.toSeq
@@ -45,6 +61,20 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Expression with
   override lazy val resolved = true
 }
 
+case class JsonTestTreeNode(arg: Any) extends LeafNode {
+  override def output: Seq[Attribute] = Seq.empty[Attribute]
+}
+
+case class NameValue(name: String, value: Any)
+
+case object DummyObject
+
+case class SelfReferenceUDF(
+    var config: Map[String, Any] = Map.empty[String, Any]) extends Function1[String, Boolean] {
+  config += "self" -> this
+  def apply(key: String): Boolean = config.contains(key)
+}
+
 class TreeNodeSuite extends SparkFunSuite {
   test("top node changed") {
     val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -261,4 +291,264 @@ class TreeNodeSuite extends SparkFunSuite {
       assert(actual === expected)
     }
   }
+
+  test("toJSON") {
+    def assertJSON(input: Any, json: JValue): Unit = {
+      val expected =
+        s"""
+           |[{
+           |  "class": "${classOf[JsonTestTreeNode].getName}",
+           |  "num-children": 0,
+           |  "arg": ${compact(render(json))}
+           |}]
+         """.stripMargin
+      compareJSON(JsonTestTreeNode(input).toJSON, expected)
+    }
+
+    // Converts simple types to JSON
+    assertJSON(true, true)
+    assertJSON(33.toByte, 33)
+    assertJSON(44, 44)
+    assertJSON(55L, 55L)
+    assertJSON(3.0, 3.0)
+    assertJSON(4.0D, 4.0D)
+    assertJSON(BigInt(BigInteger.valueOf(88L)), 88L)
+    assertJSON(null, JNull)
+    assertJSON("text", "text")
+    assertJSON(Some("text"), "text")
+    compareJSON(JsonTestTreeNode(None).toJSON,
+      s"""[
+         |  {
+         |    "class": "${classOf[JsonTestTreeNode].getName}",
+         |    "num-children": 0
+         |  }
+         |]
+       """.stripMargin)
+
+    val uuid = UUID.randomUUID()
+    assertJSON(uuid, uuid.toString)
+
+    // Converts Spark Sql DataType to JSON
+    assertJSON(IntegerType, "integer")
+    assertJSON(Metadata.empty, JObject(Nil))
+    assertJSON(
+      StorageLevel.NONE,
+      JObject(
+        "useDisk" -> false,
+        "useMemory" -> false,
+        "useOffHeap" -> false,
+        "deserialized" -> false,
+        "replication" -> 1)
+    )
+
+    // Converts TreeNode argument to JSON
+    assertJSON(
+      Literal(333),
+      List(
+        JObject(
+          "class" -> classOf[Literal].getName,
+          "num-children" -> 0,
+          "value" -> "333",
+          "dataType" -> "integer")))
+
+    // Converts Seq[String] to JSON
+    assertJSON(Seq("1", "2", "3"), "[1, 2, 3]")
+
+    // Converts Seq[DataType] to JSON
+    assertJSON(Seq(IntegerType, DoubleType, FloatType), List("integer", "double", "float"))
+
+    // Converts Seq[Partitioning] to JSON
+    assertJSON(
+      Seq(SinglePartition, RoundRobinPartitioning(numPartitions = 3)),
+      List(
+        JObject("object" -> JString(SinglePartition.getClass.getName)),
+        JObject(
+          "product-class" -> classOf[RoundRobinPartitioning].getName,
+          "numPartitions" -> 3)))
+
+    // Converts case object to JSON
+    assertJSON(DummyObject, JObject("object" -> JString(DummyObject.getClass.getName)))
+
+    // Converts ExprId to JSON
+    assertJSON(
+      ExprId(0, uuid),
+      JObject(
+        "product-class" -> classOf[ExprId].getName,
+        "id" -> 0,
+        "jvmId" -> uuid.toString))
+
+    // Converts StructField to JSON
+    assertJSON(
+      StructField("field", IntegerType),
+      JObject(
+        "product-class" -> classOf[StructField].getName,
+        "name" -> "field",
+        "dataType" -> "integer",
+        "nullable" -> true,
+        "metadata" -> JObject(Nil)))
+
+    // Converts TableIdentifier to JSON
+    assertJSON(
+      TableIdentifier("table"),
+      JObject(
+        "product-class" -> classOf[TableIdentifier].getName,
+        "table" -> "table"))
+
+    // Converts JoinType to JSON
+    assertJSON(
+      NaturalJoin(LeftOuter),
+      JObject(
+        "product-class" -> classOf[NaturalJoin].getName,
+        "tpe" -> JObject("object" -> JString(LeftOuter.getClass.getName))))
+
+    // Converts FunctionIdentifier to JSON
+    assertJSON(
+      FunctionIdentifier("function", None),
+      JObject(
+        "product-class" -> JString(classOf[FunctionIdentifier].getName),
+          "funcName" -> "function"))
+
+    // Converts BucketSpec to JSON
+    assertJSON(
+      BucketSpec(1, Seq("bucket"), Seq("sort")),
+      JObject(
+        "product-class" -> classOf[BucketSpec].getName,
+        "numBuckets" -> 1,
+        "bucketColumnNames" -> "[bucket]",
+        "sortColumnNames" -> "[sort]"))
+
+    // Converts FrameBoundary to JSON
+    assertJSON(
+      ValueFollowing(3),
+      JObject(
+        "product-class" -> classOf[ValueFollowing].getName,
+        "value" -> 3))
+
+    // Converts WindowFrame to JSON
+    assertJSON(
+      SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow),
+      JObject(
+        "product-class" -> classOf[SpecifiedWindowFrame].getName,
+        "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)),
+        "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)),
+        "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName))))
+
+    // Converts Partitioning to JSON
+    assertJSON(
+      RoundRobinPartitioning(numPartitions = 3),
+      JObject(
+        "product-class" -> classOf[RoundRobinPartitioning].getName,
+        "numPartitions" -> 3))
+
+    // Converts FunctionResource to JSON
+    assertJSON(
+      FunctionResource(JarResource, "file:///"),
+      JObject(
+        "product-class" -> JString(classOf[FunctionResource].getName),
+        "resourceType" -> JObject("object" -> JString(JarResource.getClass.getName)),
+        "uri" -> "file:///"))
+
+    // Converts BroadcastMode to JSON
+    assertJSON(
+      IdentityBroadcastMode,
+      JObject("object" -> JString(IdentityBroadcastMode.getClass.getName)))
+
+    // Converts CatalogTable to JSON
+    assertJSON(
+      CatalogTable(
+        TableIdentifier("table"),
+        CatalogTableType.MANAGED,
+        CatalogStorageFormat.empty,
+        StructType(StructField("a", IntegerType, true) :: Nil),
+        createTime = 0L),
+
+      JObject(
+        "product-class" -> classOf[CatalogTable].getName,
+        "identifier" -> JObject(
+          "product-class" -> classOf[TableIdentifier].getName,
+          "table" -> "table"
+        ),
+        "tableType" -> JObject(
+          "product-class" -> classOf[CatalogTableType].getName,
+          "name" -> "MANAGED"
+        ),
+        "storage" -> JObject(
+          "product-class" -> classOf[CatalogStorageFormat].getName,
+          "compressed" -> false,
+          "properties" -> JNull
+        ),
+        "schema" -> JObject(
+          "type" -> "struct",
+          "fields" -> List(
+            JObject(
+              "name" -> "a",
+              "type" -> "integer",
+              "nullable" -> true,
+              "metadata" -> JObject(Nil)))),
+        "partitionColumnNames" -> List.empty[String],
+        "owner" -> "",
+        "createTime" -> 0,
+        "lastAccessTime" -> -1,
+        "properties" -> JNull,
+        "unsupportedFeatures" -> List.empty[String]))
+
+    // For unknown case class, returns JNull.
+    val bigValue = new Array[Int](10000)
+    assertJSON(NameValue("name", bigValue), JNull)
+
+    // Converts Seq[TreeNode] to JSON recursively
+    assertJSON(
+      Seq(Literal(1), Literal(2)),
+      List(
+        List(
+          JObject(
+            "class" -> JString(classOf[Literal].getName),
+            "num-children" -> 0,
+            "value" -> "1",
+            "dataType" -> "integer")),
+        List(
+          JObject(
+            "class" -> JString(classOf[Literal].getName),
+            "num-children" -> 0,
+            "value" -> "2",
+            "dataType" -> "integer"))))
+
+    // Other Seq is converted to JNull, to reduce the risk of out of memory
+    assertJSON(Seq(1, 2, 3), JNull)
+
+    // All Map type is converted to JNull, to reduce the risk of out of memory
+    assertJSON(Map("key" -> "value"), JNull)
+
+    // Unknown type is converted to JNull, to reduce the risk of out of memory
+    assertJSON(new Object {}, JNull)
+
+    // Convert all TreeNode children to JSON
+    assertJSON(
+      Union(Seq(JsonTestTreeNode("0"), JsonTestTreeNode("1"))),
+      List(
+        JObject(
+          "class" -> classOf[Union].getName,
+          "num-children" -> 2,
+          "children" -> List(0, 1)),
+        JObject(
+          "class" -> classOf[JsonTestTreeNode].getName,
+          "num-children" -> 0,
+          "arg" -> "0"),
+        JObject(
+          "class" -> classOf[JsonTestTreeNode].getName,
+          "num-children" -> 0,
+          "arg" -> "1")))
+  }
+
+  test("toJSON should not throws java.lang.StackOverflowError") {
+    val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr))
+    // Should not throw java.lang.StackOverflowError
+    udf.toJSON
+  }
+
+  private def compareJSON(leftJson: String, rightJson: String): Unit = {
+    val left = JsonMethods.parse(leftJson)
+    val right = JsonMethods.parse(rightJson)
+    assert(left == right)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d361f61764d1fad32a609389def30000cf36d8a6..34fa626e00e319dba644fe390dd2ce9d6cb0e840 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -120,7 +120,6 @@ abstract class QueryTest extends PlanTest {
           throw ae
         }
     }
-    checkJsonFormat(analyzedDS)
     assertEmptyMissingInput(analyzedDS)
 
     try ds.collect() catch {
@@ -168,8 +167,6 @@ abstract class QueryTest extends PlanTest {
         }
     }
 
-    checkJsonFormat(analyzedDF)
-
     assertEmptyMissingInput(analyzedDF)
 
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
@@ -228,139 +225,6 @@ abstract class QueryTest extends PlanTest {
         planWithCaching)
   }
 
-  private def checkJsonFormat(ds: Dataset[_]): Unit = {
-    // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
-    // RDD and Data resolution does not break.
-    val logicalPlan = ds.queryExecution.analyzed
-
-    // bypass some cases that we can't handle currently.
-    logicalPlan.transform {
-      case _: ObjectConsumer => return
-      case _: ObjectProducer => return
-      case _: AppendColumns => return
-      case _: TypedFilter => return
-      case _: LogicalRelation => return
-      case p if p.getClass.getSimpleName == "MetastoreRelation" => return
-      case _: MemoryPlan => return
-      case p: InMemoryRelation =>
-        p.child.transform {
-          case _: ObjectConsumerExec => return
-          case _: ObjectProducerExec => return
-        }
-        p
-    }.transformAllExpressions {
-      case _: ImperativeAggregate => return
-      case _: TypedAggregateExpression => return
-      case Literal(_, _: ObjectType) => return
-      case _: UserDefinedGenerator => return
-    }
-
-    // bypass hive tests before we fix all corner cases in hive module.
-    if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
-
-    val jsonString = try {
-      logicalPlan.toJSON
-    } catch {
-      case NonFatal(e) =>
-        fail(
-          s"""
-             |Failed to parse logical plan to JSON:
-             |${logicalPlan.treeString}
-           """.stripMargin, e)
-    }
-
-    // scala function is not serializable to JSON, use null to replace them so that we can compare
-    // the plans later.
-    val normalized1 = logicalPlan.transformAllExpressions {
-      case udf: ScalaUDF => udf.copy(function = null)
-      case gen: UserDefinedGenerator => gen.copy(function = null)
-      // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove
-      // the Metadata from the normalized plan so that we can compare this plan with the
-      // JSON-deserialzed plan.
-      case a @ Alias(child, name) if a.explicitMetadata.isDefined =>
-        Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated)
-      case a: AttributeReference if a.metadata != Metadata.empty =>
-        AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier,
-          a.isGenerated)
-    }
-
-    // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains
-    // these non-serializable stuff, and use these original ones to replace the null-placeholders
-    // in the logical plans parsed from JSON.
-    val logicalRDDs = new ArrayDeque[LogicalRDD]()
-    val localRelations = new ArrayDeque[LocalRelation]()
-    val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
-    def collectData: (LogicalPlan => Unit) = {
-      case l: LogicalRDD =>
-        logicalRDDs.offer(l)
-      case l: LocalRelation =>
-        localRelations.offer(l)
-      case i: InMemoryRelation =>
-        inMemoryRelations.offer(i)
-      case p =>
-        p.expressions.foreach {
-          _.foreach {
-            case s: SubqueryExpression =>
-              s.plan.foreach(collectData)
-            case _ =>
-          }
-        }
-    }
-    logicalPlan.foreach(collectData)
-
-
-    val jsonBackPlan = try {
-      TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
-    } catch {
-      case NonFatal(e) =>
-        fail(
-          s"""
-             |Failed to rebuild the logical plan from JSON:
-             |${logicalPlan.treeString}
-             |
-             |${logicalPlan.prettyJson}
-           """.stripMargin, e)
-    }
-
-    def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
-      case l: LogicalRDD =>
-        val origin = logicalRDDs.pop()
-        LogicalRDD(l.output, origin.rdd)(spark)
-      case l: LocalRelation =>
-        val origin = localRelations.pop()
-        l.copy(data = origin.data)
-      case l: InMemoryRelation =>
-        val origin = inMemoryRelations.pop()
-        InMemoryRelation(
-          l.output,
-          l.useCompression,
-          l.batchSize,
-          l.storageLevel,
-          origin.child,
-          l.tableName)(
-          origin.cachedColumnBuffers,
-          origin.batchStats)
-      case p =>
-        p.transformExpressions {
-          case s: SubqueryExpression =>
-            s.withNewPlan(s.plan.transformDown(renormalize))
-        }
-    }
-    val normalized2 = jsonBackPlan.transformDown(renormalize)
-
-    assert(logicalRDDs.isEmpty)
-    assert(localRelations.isEmpty)
-    assert(inMemoryRelations.isEmpty)
-
-    if (normalized1 != normalized2) {
-      fail(
-        s"""
-           |== FAIL: the logical plan parsed from json does not match the original one ===
-           |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
-          """.stripMargin)
-    }
-  }
-
   /**
    * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
    */