From 721ced28b522cc00b45ca7fa32a99e80ad3de2f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Thu, 28 Jan 2016 22:42:43 -0800 Subject: [PATCH] [SPARK-13067] [SQL] workaround for a weird scala reflection problem A simple workaround to avoid getting parameter types when convert a logical plan to json. Author: Wenchen Fan <wenchen@databricks.com> Closes #10970 from cloud-fan/reflection. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 ++++++++++++++++--- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 643228d0eb..e5811efb43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -601,6 +601,20 @@ object ScalaReflection extends ScalaReflection { getConstructorParameters(t) } + /** + * Returns the parameter names for the primary constructor of this class. + * + * Logically we should call `getConstructorParameters` and throw away the parameter types to get + * parameter names, however there are some weird scala reflection problems and this method is a + * workaround to avoid getting parameter types. + */ + def getConstructorParameterNames(cls: Class[_]): Seq[String] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + constructParams(t).map(_.name.toString) + } + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } @@ -745,6 +759,12 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe + constructParams(tpe).map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } + + protected def constructParams(tpe: Type): Seq[Symbol] = { val constructorSymbol = tpe.member(nme.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramss @@ -758,9 +778,6 @@ trait ScalaReflection { primaryConstructorSymbol.get.asMethod.paramss } } - - params.flatten.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - } + params.flatten } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 57e1a3c9eb..2df0683f9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } protected def jsonFields: List[JField] = { - val fieldNames = getConstructorParameters(getClass).map(_._1) + val fieldNames = getConstructorParameterNames(getClass) val fieldValues = productIterator.toSeq ++ otherCopyArgs assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) @@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper case p: Product => try { - val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldNames = getConstructorParameterNames(p.getClass) val fieldValues = p.productIterator.toSeq assert(fieldNames.length == fieldValues.length) ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { -- GitLab