Commit c4d5ad80 authored by Reynold Xin

[SPARK-13282][SQL] LogicalPlan toSql should just return a String

Previously we were using Option[String], with None indicating the case when Spark fails to generate SQL. It is easier to just use exceptions to propagate error cases, rather than having for-comprehensions everywhere. I also introduced a "build" function that simplifies string concatenation (i.e., no need to reason about whether there is an extra space or not).

Author: Reynold Xin <rxin@databricks.com>

Closes #11171 from rxin/SPARK-13282.
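To illustrate the API change described above, here is a minimal sketch of how a call site moves from Option-handling to exception handling. This is not part of the patch: the toSQLOrElse helper and the fallbackSQL value are hypothetical, and only the SQLBuilder constructor and toSQL signature come from the diff below.

    import scala.util.control.NonFatal

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.hive.SQLBuilder

    // Hypothetical helper: generate SQL for a DataFrame's analyzed plan, falling back
    // to a caller-supplied string when the plan cannot be converted.
    def toSQLOrElse(df: DataFrame, fallbackSQL: String): String = {
      // Before this patch, toSQL returned Option[String], so the equivalent was:
      //   new SQLBuilder(df).toSQL.getOrElse(fallbackSQL)
      // After this patch, toSQL returns the SQL string directly and throws
      // (e.g. UnsupportedOperationException) when the plan is unsupported.
      try new SQLBuilder(df).toSQL
      catch { case NonFatal(e) => fallbackSQL }
    }

    // The new build(...) helper trims each segment and joins them with single spaces, e.g.
    //   build("SELECT", " a, b ", "FROM", " t")  ==  "SELECT a, b FROM t"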
parent 5b805df2
SQLBuilder.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.hive
 
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.Logging
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -37,16 +39,10 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
   def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
 
-  def toSQL: Option[String] = {
+  def toSQL: String = {
     val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
-    val maybeSQL = try {
-      toSQL(canonicalizedPlan)
-    } catch { case cause: UnsupportedOperationException =>
-      logInfo(s"Failed to build SQL query string because: ${cause.getMessage}")
-      None
-    }
-
-    if (maybeSQL.isDefined) {
+    try {
+      val generatedSQL = toSQL(canonicalizedPlan)
       logDebug(
         s"""Built SQL query string successfully from given logical plan:
            |
@@ -54,10 +50,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
            |${logicalPlan.treeString}
            |# Canonicalized logical plan:
            |${canonicalizedPlan.treeString}
-           |# Built SQL query string:
-           |${maybeSQL.get}
+           |# Generated SQL:
+           |$generatedSQL
          """.stripMargin)
-    } else {
+      generatedSQL
+    } catch { case NonFatal(e) =>
       logDebug(
         s"""Failed to build SQL query string from given logical plan:
            |
@@ -66,128 +63,113 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
            |# Canonicalized logical plan:
            |${canonicalizedPlan.treeString}
          """.stripMargin)
+      throw e
     }
-
-    maybeSQL
   }
 
-  private def projectToSQL(
-      projectList: Seq[NamedExpression],
-      child: LogicalPlan,
-      isDistinct: Boolean): Option[String] = {
-    for {
-      childSQL <- toSQL(child)
-      listSQL = projectList.map(_.sql).mkString(", ")
-      maybeFrom = child match {
-        case OneRowRelation => " "
-        case _ => " FROM "
-      }
-      distinct = if (isDistinct) " DISTINCT " else " "
-    } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL"
-  }
-
-  private def aggregateToSQL(
-      groupingExprs: Seq[Expression],
-      aggExprs: Seq[Expression],
-      child: LogicalPlan): Option[String] = {
-    val aggSQL = aggExprs.map(_.sql).mkString(", ")
-    val groupingSQL = groupingExprs.map(_.sql).mkString(", ")
-    val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY "
-    val maybeFrom = child match {
-      case OneRowRelation => " "
-      case _ => " FROM "
-    }
-
-    toSQL(child).map { childSQL =>
-      s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL"
-    }
-  }
-
-  private def toSQL(node: LogicalPlan): Option[String] = node match {
-    case Distinct(Project(list, child)) =>
-      projectToSQL(list, child, isDistinct = true)
-
-    case Project(list, child) =>
-      projectToSQL(list, child, isDistinct = false)
-
-    case Aggregate(groupingExprs, aggExprs, child) =>
-      aggregateToSQL(groupingExprs, aggExprs, child)
-
-    case Limit(limit, child) =>
-      for {
-        childSQL <- toSQL(child)
-        limitSQL = limit.sql
-      } yield s"$childSQL LIMIT $limitSQL"
-
-    case Filter(condition, child) =>
-      for {
-        childSQL <- toSQL(child)
-        whereOrHaving = child match {
-          case _: Aggregate => "HAVING"
-          case _ => "WHERE"
-        }
-        conditionSQL = condition.sql
-      } yield s"$childSQL $whereOrHaving $conditionSQL"
-
-    case Union(children) if children.length > 1 =>
-      val childrenSql = children.map(toSQL(_))
-      if (childrenSql.exists(_.isEmpty)) {
-        None
-      } else {
-        Some(childrenSql.map(_.get).mkString(" UNION ALL "))
-      }
-
-    // Persisted data source relation
-    case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
-      Some(s"`$database`.`$table`")
-
-    case Subquery(alias, child) =>
-      toSQL(child).map( childSQL =>
-        child match {
-          // Parentheses is not used for persisted data source relations
-          // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
-          case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
-            s"$childSQL AS $alias"
-          case _ =>
-            s"($childSQL) AS $alias"
-        })
-
-    case Join(left, right, joinType, condition) =>
-      for {
-        leftSQL <- toSQL(left)
-        rightSQL <- toSQL(right)
-        joinTypeSQL = joinType.sql
-        conditionSQL = condition.map(" ON " + _.sql).getOrElse("")
-      } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL"
-
-    case MetastoreRelation(database, table, alias) =>
-      val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("")
-      Some(s"`$database`.`$table`$aliasSQL")
-
-    case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
-        if orders.map(_.child) == partitionExprs =>
-      for {
-        childSQL <- toSQL(child)
-        partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
-      } yield s"$childSQL CLUSTER BY $partitionExprsSQL"
-
-    case Sort(orders, global, child) =>
-      for {
-        childSQL <- toSQL(child)
-        ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
-        orderOrSort = if (global) "ORDER" else "SORT"
-      } yield s"$childSQL $orderOrSort BY $ordersSQL"
-
-    case RepartitionByExpression(partitionExprs, child, _) =>
-      for {
-        childSQL <- toSQL(child)
-        partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
-      } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL"
-
-    case OneRowRelation =>
-      Some("")
-
-    case _ => None
-  }
+  private def toSQL(node: LogicalPlan): String = node match {
+    case Distinct(p: Project) =>
+      projectToSQL(p, isDistinct = true)
+
+    case p: Project =>
+      projectToSQL(p, isDistinct = false)
+
+    case p: Aggregate =>
+      aggregateToSQL(p)
+
+    case p: Limit =>
+      s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}"
+
+    case p: Filter =>
+      val whereOrHaving = p.child match {
+        case _: Aggregate => "HAVING"
+        case _ => "WHERE"
+      }
+      build(toSQL(p.child), whereOrHaving, p.condition.sql)
+
+    case p: Union if p.children.length > 1 =>
+      val childrenSql = p.children.map(toSQL(_))
+      childrenSql.mkString(" UNION ALL ")
+
+    case p: Subquery =>
+      p.child match {
+        // Persisted data source relation
+        case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
+          s"`$database`.`$table`"
+        // Parentheses is not used for persisted data source relations
+        // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
+        case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
+          build(toSQL(p.child), "AS", p.alias)
+        case _ =>
+          build("(" + toSQL(p.child) + ")", "AS", p.alias)
+      }
+
+    case p: Join =>
+      build(
+        toSQL(p.left),
+        p.joinType.sql,
+        "JOIN",
+        toSQL(p.right),
+        p.condition.map(" ON " + _.sql).getOrElse(""))
+
+    case p: MetastoreRelation =>
+      build(
+        s"`${p.databaseName}`.`${p.tableName}`",
+        p.alias.map(a => s" AS `$a`").getOrElse("")
+      )
+
+    case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
+        if orders.map(_.child) == partitionExprs =>
+      build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
+
+    case p: Sort =>
+      build(
+        toSQL(p.child),
+        if (p.global) "ORDER BY" else "SORT BY",
+        p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+      )
+
+    case p: RepartitionByExpression =>
+      build(
+        toSQL(p.child),
+        "DISTRIBUTE BY",
+        p.partitionExpressions.map(_.sql).mkString(", ")
+      )
+
+    case OneRowRelation =>
+      ""
+
+    case _ =>
+      throw new UnsupportedOperationException(s"unsupported plan $node")
+  }
+
+  /**
+   * Turns a bunch of string segments into a single string and separate each segment by a space.
+   * The segments are trimmed so only a single space appears in the separation.
+   * For example, `build("a", " b ", " c")` becomes "a b c".
+   */
+  private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
+
+  private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
+    build(
+      "SELECT",
+      if (isDistinct) "DISTINCT" else "",
+      plan.projectList.map(_.sql).mkString(", "),
+      if (plan.child == OneRowRelation) "" else "FROM",
+      toSQL(plan.child)
+    )
+  }
+
+  private def aggregateToSQL(plan: Aggregate): String = {
+    val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
+    build(
+      "SELECT",
+      plan.aggregateExpressions.map(_.sql).mkString(", "),
+      if (plan.child == OneRowRelation) "" else "FROM",
+      toSQL(plan.child),
+      if (groupingSQL.isEmpty) "" else "GROUP BY",
+      groupingSQL
+    )
+  }
 
   object Canonicalizer extends RuleExecutor[LogicalPlan] {
...
CreateViewAsSelect.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive.execution
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.Alias
@@ -72,7 +74,9 @@ private[hive] case class CreateViewAsSelect(
 
   private def prepareTable(sqlContext: SQLContext): HiveTable = {
     val expandedText = if (sqlContext.conf.canonicalView) {
-      rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect)
+      try rebuildViewQueryString(sqlContext) catch {
+        case NonFatal(e) => wrapViewTextWithSelect
+      }
     } else {
       wrapViewTextWithSelect
     }
@@ -112,7 +116,7 @@ private[hive] case class CreateViewAsSelect(
     s"SELECT $viewOutput FROM ($viewText) $viewName"
   }
 
-  private def rebuildViewQueryString(sqlContext: SQLContext): Option[String] = {
+  private def rebuildViewQueryString(sqlContext: SQLContext): String = {
     val logicalPlan = if (tableDesc.schema.isEmpty) {
       child
     } else {
...
ExpressionSQLBuilderSuite.scala
@@ -33,8 +33,7 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
     checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
     checkSQL(Literal(2.5D), "2.5")
     checkSQL(
-      Literal(Timestamp.valueOf("2016-01-01 00:00:00")),
-      "TIMESTAMP('2016-01-01 00:00:00.0')")
+      Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
 
     // TODO tests for decimals
   }
...
LogicalPlanToSQLSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestUtils
@@ -46,29 +48,28 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
   private def checkHiveQl(hiveQl: String): Unit = {
     val df = sql(hiveQl)
-    val convertedSQL = new SQLBuilder(df).toSQL
 
-    if (convertedSQL.isEmpty) {
-      fail(
-        s"""Cannot convert the following HiveQL query plan back to SQL query string:
-           |
-           |# Original HiveQL query string:
-           |$hiveQl
-           |
-           |# Resolved query plan:
-           |${df.queryExecution.analyzed.treeString}
-         """.stripMargin)
+    val convertedSQL = try new SQLBuilder(df).toSQL catch {
+      case NonFatal(e) =>
+        fail(
+          s"""Cannot convert the following HiveQL query plan back to SQL query string:
+             |
+             |# Original HiveQL query string:
+             |$hiveQl
+             |
+             |# Resolved query plan:
+             |${df.queryExecution.analyzed.treeString}
+           """.stripMargin)
     }
 
-    val sqlString = convertedSQL.get
     try {
-      checkAnswer(sql(sqlString), df)
+      checkAnswer(sql(convertedSQL), df)
     } catch { case cause: Throwable =>
       fail(
         s"""Failed to execute converted SQL string or got wrong answer:
            |
            |# Converted SQL query string:
-           |$sqlString
+           |$convertedSQL
            |
            |# Original HiveQL query string:
            |$hiveQl
...
SQLBuilderTest.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -40,9 +42,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
   }
 
   protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
-    val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL
-
-    if (maybeSQL.isEmpty) {
+    val generatedSQL = try new SQLBuilder(plan, hiveContext).toSQL catch { case NonFatal(e) =>
       fail(
         s"""Cannot convert the following logical query plan to SQL:
            |
@@ -50,10 +50,8 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
         """.stripMargin)
     }
 
-    val actualSQL = maybeSQL.get
-
     try {
-      assert(actualSQL === expectedSQL)
+      assert(generatedSQL === expectedSQL)
     } catch {
       case cause: Throwable =>
         fail(
@@ -65,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
         """.stripMargin)
     }
 
-    checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan))
+    checkAnswer(sqlContext.sql(generatedSQL), new DataFrame(sqlContext, plan))
   }
 
   protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
...
HiveComparisonTest.scala
@@ -412,21 +412,22 @@ abstract class HiveComparisonTest
               originalQuery
             } else {
               numTotalQueries += 1
-              new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql =>
+              try {
+                val sql = new SQLBuilder(originalQuery.analyzed, TestHive).toSQL
                 numConvertibleQueries += 1
                 logInfo(
                   s"""
                      |### Running SQL generation round-trip test {{{
                      |${originalQuery.analyzed.treeString}
                      |Original SQL:
                      |$queryString
                      |
                      |Generated SQL:
                      |$sql
                      |}}}
                   """.stripMargin.trim)
                 new TestHive.QueryExecution(sql)
-              }.getOrElse {
+              } catch { case NonFatal(e) =>
                 logInfo(
                   s"""
                      |### Cannot convert the following logical plan back to SQL {{{
...