diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ada48eaf5dc0f5e031d6534b71114956b44c553a..5a55be1e515587794af64f4cb681929d32f3278e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -47,10 +47,13 @@ object ScalaReflection { val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< typeOf[Product] => - val params = t.member("<init>": TermName).asMethod.paramss + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss Schema(StructType( params.head.map { p => - val Schema(dataType, nullable) = schemaFor(p.typeSignature) + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) StructField(p.name.toString, dataType, nullable) }), nullable = true) // Need to decide if we actually need a special type here. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 489d7e9c2437fcfa7ed21cede46af37eef77cc98..c0438dbe52a479994ec1d7afc639e9d556220649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -60,6 +60,9 @@ case class ComplexData( mapField: Map[Int, String], structField: PrimitiveData) +case class GenericData[A]( + genericField: A) + class ScalaReflectionSuite extends FunSuite { import ScalaReflection._ @@ -128,4 +131,21 @@ class ScalaReflectionSuite extends FunSuite { nullable = true))), nullable = true)) } + + test("generic data") { + val schema = schemaFor[GenericData[Int]] + assert(schema === Schema( + StructType(Seq( + StructField("genericField", IntegerType, nullable = false))), + nullable = true)) + } + + test("tuple data") { + val schema = schemaFor[(Int, String)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true))), + nullable = true)) + } }