Commit 8f932fb8 authored by Andrew Or, committed by Reynold Xin

[SPARK-15234][SQL] Fix spark.catalog.listDatabases.show()

## What changes were proposed in this pull request?

Before:
```
scala> spark.catalog.listDatabases.show()
+--------------------+-----------+-----------+
|                name|description|locationUri|
+--------------------+-----------+-----------+
|Database[name='de...|
|Database[name='my...|
|Database[name='so...|
+--------------------+-----------+-----------+
```

After:
```
+-------+--------------------+--------------------+
|   name|         description|         locationUri|
+-------+--------------------+--------------------+
|default|Default Hive data...|file:/user/hive/w...|
|  my_db|  This is a database|file:/Users/andre...|
|some_db|                    |file:/private/var...|
+-------+--------------------+--------------------+
```
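
The root cause: the catalog classes returned by these APIs (`Database`, `Table`, `Function`, `Column`) extend `DefinedByConstructorParams` rather than `Product`, so the row-building match in `Dataset`'s show path fell through to the generic `case o => Row(o)` branch and each object was rendered as a single `toString` cell. The fix converts such an object into one column per constructor parameter via reflection. A minimal standalone sketch of the idea in plain Scala (the `Database` class and the explicit parameter-name list here are illustration-only stand-ins; the real code derives the names with `ScalaReflection.getConstructorParameterNames`):

```scala
// Illustration only: a stand-in for org.apache.spark.sql.catalog.Database.
class Database(val name: String, val description: String, val locationUri: String)

object ShowFixSketch {
  // Same idea as the new getConstructorParameterValues helper: for each
  // constructor-parameter name, invoke the accessor of the same name.
  def constructorParameterValues(obj: AnyRef, paramNames: Seq[String]): Seq[AnyRef] =
    paramNames.map(name => obj.getClass.getMethod(name).invoke(obj))

  def main(args: Array[String]): Unit = {
    val db = new Database("default", "Default Hive database", "file:/user/hive/warehouse")
    // Before the fix, show() effectively displayed db.toString in a single column;
    // after the fix, these three values become three separate columns.
    println(constructorParameterValues(db, Seq("name", "description", "locationUri")))
    // -> List(default, Default Hive database, file:/user/hive/warehouse)
  }
}
```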

## How was this patch tested?

New test in `CatalogSuite`

Author: Andrew Or <andrew@databricks.com>

Closes #13015 from andrewor14/catalog-show.
Parent commit: 980bba0d
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -659,6 +658,15 @@ object ScalaReflection extends ScalaReflection {
constructParams(t).map(_.name.toString)
}
+  /**
+   * Returns the parameter values for the primary constructor of this class.
+   */
+  def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = {
+    getConstructorParameterNames(obj.getClass).map { name =>
+      obj.getClass.getMethod(name).invoke(obj)
+    }
+  }
/*
* Retrieves the runtime class corresponding to the provided type.
*/
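
The new `getConstructorParameterValues` helper pairs with the existing `getConstructorParameterNames`: for each primary-constructor parameter it invokes the accessor of the same name on the instance, which assumes every constructor parameter is exposed as a no-argument accessor (true for the catalog classes and for case classes). A hedged usage sketch with a made-up class, assuming spark-catalyst is on the classpath:

```scala
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, ScalaReflection}

// Hypothetical class, used only to illustrate the helper.
case class Endpoint(host: String, port: Int) extends DefinedByConstructorParams

val values: Seq[AnyRef] =
  ScalaReflection.getConstructorParameterValues(Endpoint("localhost", 8080))
// values == Seq("localhost", 8080)  (the Int comes back boxed as java.lang.Integer)
```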
@@ -245,6 +245,8 @@ class Dataset[T] private[sql](
val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map {
case r: Row => r
case tuple: Product => Row.fromTuple(tuple)
+      case definedByCtor: DefinedByConstructorParams =>
+        Row.fromSeq(ScalaReflection.getConstructorParameterValues(definedByCtor))
case o => Row(o)
}.map { row =>
row.toSeq.map { cell =>
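
In `Dataset`'s show path, the new case is matched before the generic `case o => Row(o)` fallback, so a `DefinedByConstructorParams` object now becomes one cell per field instead of one `toString` cell. Roughly, the difference between the two branches, sketched with the public `Row` API only (the string literal stands in for the old `toString` output):

```scala
import org.apache.spark.sql.Row

val fields: Seq[Any] = Seq("default", "Default Hive database", "file:/user/hive/warehouse")

// Old behaviour for catalog objects: the whole object ended up in one cell.
val oneCell = Row("Database[name='default', ...]")   // 1 field
// New behaviour: one cell per constructor parameter.
val perField = Row.fromSeq(fields)                    // 3 fields

assert(oneCell.length == 1 && perField.length == 3)
```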
@@ -50,14 +50,6 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}
-  private def makeDataset[T <: DefinedByConstructorParams: TypeTag](data: Seq[T]): Dataset[T] = {
-    val enc = ExpressionEncoder[T]()
-    val encoded = data.map(d => enc.toRow(d).copy())
-    val plan = new LocalRelation(enc.schema.toAttributes, encoded)
-    val queryExecution = sparkSession.executePlan(plan)
-    new Dataset[T](sparkSession, queryExecution, enc)
-  }
/**
* Returns the current default database in this session.
*/
@@ -83,7 +75,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
description = metadata.description,
locationUri = metadata.locationUri)
}
-    makeDataset(databases)
+    CatalogImpl.makeDataset(databases, sparkSession)
}
/**
@@ -111,7 +103,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"),
isTemporary = isTemp)
}
-    makeDataset(tables)
+    CatalogImpl.makeDataset(tables, sparkSession)
}
/**
@@ -137,7 +129,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
className = metadata.getClassName,
isTemporary = funcIdent.database.isEmpty)
}
-    makeDataset(functions)
+    CatalogImpl.makeDataset(functions, sparkSession)
}
/**
@@ -166,7 +158,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
isPartition = partitionColumnNames.contains(c.name),
isBucket = bucketColumnNames.contains(c.name))
}
-    makeDataset(columns)
+    CatalogImpl.makeDataset(columns, sparkSession)
}
/**
@@ -350,3 +342,18 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}
+private[sql] object CatalogImpl {
+  def makeDataset[T <: DefinedByConstructorParams: TypeTag](
+      data: Seq[T],
+      sparkSession: SparkSession): Dataset[T] = {
+    val enc = ExpressionEncoder[T]()
+    val encoded = data.map(d => enc.toRow(d).copy())
+    val plan = new LocalRelation(enc.schema.toAttributes, encoded)
+    val queryExecution = sparkSession.executePlan(plan)
+    new Dataset[T](sparkSession, queryExecution, enc)
+  }
+}
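
Moving `makeDataset` from a private method on `CatalogImpl` into its companion object, with the `SparkSession` passed explicitly, is what lets the new test below build exactly the Datasets that `listDatabases`, `listTables`, `listFunctions`, and `listColumns` return. A hedged sketch of a direct call from code inside the `org.apache.spark.sql` package (needed because `CatalogImpl` is `private[sql]`; `spark` is assumed to be an existing `SparkSession`):

```scala
import org.apache.spark.sql.catalog.Database
import org.apache.spark.sql.internal.CatalogImpl

val dbs = Seq(new Database("default", "Default Hive database", "file:/user/hive/warehouse"))
CatalogImpl.makeDataset(dbs, spark).show()
// +-------+--------------------+--------------------+
// |   name|         description|         locationUri|
// +-------+--------------------+--------------------+
// |default|Default Hive data...|file:/user/hive/w...|
// +-------+--------------------+--------------------+
```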
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalog.{Column, Database, Function, Table}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -266,6 +266,30 @@ class CatalogSuite
"nullable='false', isPartition='true', isBucket='true']")
}
test("catalog classes format in Dataset.show") {
val db = new Database("nama", "descripta", "locata")
val table = new Table("nama", "databasa", "descripta", "typa", isTemporary = false)
val function = new Function("nama", "descripta", "classa", isTemporary = false)
val column = new Column(
"nama", "descripta", "typa", nullable = false, isPartition = true, isBucket = true)
val dbFields = ScalaReflection.getConstructorParameterValues(db)
val tableFields = ScalaReflection.getConstructorParameterValues(table)
val functionFields = ScalaReflection.getConstructorParameterValues(function)
val columnFields = ScalaReflection.getConstructorParameterValues(column)
assert(dbFields == Seq("nama", "descripta", "locata"))
assert(tableFields == Seq("nama", "databasa", "descripta", "typa", false))
assert(functionFields == Seq("nama", "descripta", "classa", false))
assert(columnFields == Seq("nama", "descripta", "typa", false, true, true))
val dbString = CatalogImpl.makeDataset(Seq(db), sparkSession).showString(10)
val tableString = CatalogImpl.makeDataset(Seq(table), sparkSession).showString(10)
val functionString = CatalogImpl.makeDataset(Seq(function), sparkSession).showString(10)
val columnString = CatalogImpl.makeDataset(Seq(column), sparkSession).showString(10)
dbFields.foreach { f => assert(dbString.contains(f.toString)) }
tableFields.foreach { f => assert(tableString.contains(f.toString)) }
functionFields.foreach { f => assert(functionString.contains(f.toString)) }
columnFields.foreach { f => assert(columnString.contains(f.toString)) }
}
// TODO: add tests for the rest of them
}