diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d158a64a85bc080e93757d3eb7b0734fb238d965..79bb7a701baf8dc5f238701e7a5ed5a5da98812f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -659,6 +658,15 @@ object ScalaReflection extends ScalaReflection {
     constructParams(t).map(_.name.toString)
   }
 
+  /**
+   * Returns the parameter values for the primary constructor of this class.
+   */
+  def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = {
+    getConstructorParameterNames(obj.getClass).map { name =>
+      obj.getClass.getMethod(name).invoke(obj)
+    }
+  }
+
   /*
    * Retrieves the runtime class corresponding to the provided type.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index dd73fb8dad69518fdba4c80e2a1a344cabf9be9b..45a69cacd18c173c4455984eec5f03cf9115d910 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -245,6 +245,8 @@ class Dataset[T] private[sql](
     val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map {
       case r: Row => r
       case tuple: Product => Row.fromTuple(tuple)
+      case definedByCtor: DefinedByConstructorParams =>
+        Row.fromSeq(ScalaReflection.getConstructorParameterValues(definedByCtor))
       case o => Row(o)
     }.map { row =>
       row.toSeq.map { cell =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 976c9c53de13996ac7ac345e0b148b79d0ee62d2..d08dca32c043de3068d162326cb2932184e400e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -50,14 +50,6 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
     }
   }
 
-  private def makeDataset[T <: DefinedByConstructorParams: TypeTag](data: Seq[T]): Dataset[T] = {
-    val enc = ExpressionEncoder[T]()
-    val encoded = data.map(d => enc.toRow(d).copy())
-    val plan = new LocalRelation(enc.schema.toAttributes, encoded)
-    val queryExecution = sparkSession.executePlan(plan)
-    new Dataset[T](sparkSession, queryExecution, enc)
-  }
-
   /**
    * Returns the current default database in this session.
    */
@@ -83,7 +75,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
         description = metadata.description,
         locationUri = metadata.locationUri)
     }
-    makeDataset(databases)
+    CatalogImpl.makeDataset(databases, sparkSession)
   }
 
   /**
@@ -111,7 +103,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
         tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"),
         isTemporary = isTemp)
     }
-    makeDataset(tables)
+    CatalogImpl.makeDataset(tables, sparkSession)
   }
 
   /**
@@ -137,7 +129,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
         className = metadata.getClassName,
         isTemporary = funcIdent.database.isEmpty)
     }
-    makeDataset(functions)
+    CatalogImpl.makeDataset(functions, sparkSession)
  }
 
   /**
@@ -166,7 +158,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
         isPartition = partitionColumnNames.contains(c.name),
         isBucket = bucketColumnNames.contains(c.name))
     }
-    makeDataset(columns)
+    CatalogImpl.makeDataset(columns, sparkSession)
   }
 
   /**
@@ -350,3 +342,18 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   }
 
 }
+
+
+private[sql] object CatalogImpl {
+
+  def makeDataset[T <: DefinedByConstructorParams: TypeTag](
+      data: Seq[T],
+      sparkSession: SparkSession): Dataset[T] = {
+    val enc = ExpressionEncoder[T]()
+    val encoded = data.map(d => enc.toRow(d).copy())
+    val plan = new LocalRelation(enc.schema.toAttributes, encoded)
+    val queryExecution = sparkSession.executePlan(plan)
+    new Dataset[T](sparkSession, queryExecution, enc)
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 986d8f514f2fb83b22fd397ea270462d15d6584a..73c2076a302b0506a444007937f92c062e34d911 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalog.{Column, Database, Function, Table}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -266,6 +266,30 @@ class CatalogSuite
       "nullable='false', isPartition='true', isBucket='true']")
   }
 
+  test("catalog classes format in Dataset.show") {
+    val db = new Database("nama", "descripta", "locata")
+    val table = new Table("nama", "databasa", "descripta", "typa", isTemporary = false)
+    val function = new Function("nama", "descripta", "classa", isTemporary = false)
+    val column = new Column(
+      "nama", "descripta", "typa", nullable = false, isPartition = true, isBucket = true)
+    val dbFields = ScalaReflection.getConstructorParameterValues(db)
+    val tableFields = ScalaReflection.getConstructorParameterValues(table)
+    val functionFields = ScalaReflection.getConstructorParameterValues(function)
+    val columnFields = ScalaReflection.getConstructorParameterValues(column)
+    assert(dbFields == Seq("nama", "descripta", "locata"))
+    assert(tableFields == Seq("nama", "databasa", "descripta", "typa", false))
+    assert(functionFields == Seq("nama", "descripta", "classa", false))
+    assert(columnFields == Seq("nama", "descripta", "typa", false, true, true))
+    val dbString = CatalogImpl.makeDataset(Seq(db), sparkSession).showString(10)
+    val tableString = CatalogImpl.makeDataset(Seq(table), sparkSession).showString(10)
+    val functionString = CatalogImpl.makeDataset(Seq(function), sparkSession).showString(10)
+    val columnString = CatalogImpl.makeDataset(Seq(column), sparkSession).showString(10)
+    dbFields.foreach { f => assert(dbString.contains(f.toString)) }
+    tableFields.foreach { f => assert(tableString.contains(f.toString)) }
+    functionFields.foreach { f => assert(functionString.contains(f.toString)) }
+    columnFields.foreach { f => assert(columnString.contains(f.toString)) }
+  }
+
   // TODO: add tests for the rest of them
 
 }