diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 59af5b7095a77b620b8479ca4686ce8efc734c83..527d5b635a7f97eae20fef78bf99fd6dd3aa438b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ /** @@ -91,6 +91,7 @@ class Analyzer( ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: ResolveMissingReferences :: + ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: @@ -178,8 +179,8 @@ class Analyzer( case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)() - case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))() + case e: ExtractValue => Alias(e, toPrettySQL(e))() + case e => Alias(e, optionalAliasName.getOrElse(toPrettySQL(e)))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -1278,20 +1279,54 @@ class Analyzer( } /** - * Rewrites table generating expressions that either need one or more of the following in order - * to be resolved: - * - concrete attribute references for their output. - * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * operator under [[Project]]. * - * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions - * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an - * [[AnalysisException]] is throw. + * This rule will throw [[AnalysisException]] for following cases: + * 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl` + * 2. more than one [[Generator]] is found in projectList, + * e.g. `SELECT explode(list), explode(list) FROM tbl` + * 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]], + * e.g. `SELECT * FROM tbl SORT BY explode(list)` */ - object ResolveGenerate extends Rule[LogicalPlan] { + object ExtractGenerator extends Rule[LogicalPlan] { + private def hasGenerator(expr: Expression): Boolean = { + expr.find(_.isInstanceOf[Generator]).isDefined + } + + private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match { + case UnresolvedAlias(_: Generator, _) => false + case Alias(_: Generator, _) => false + case MultiAlias(_: Generator, _) => false + case other => hasGenerator(other) + } + + private def trimAlias(expr: NamedExpression): Expression = expr match { + case UnresolvedAlias(child, _) => child + case Alias(child, _) => child + case MultiAlias(child, _) => child + case _ => expr + } + + /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ + private object AliasedGenerator { + def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) + case _ => None + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p: Generate if !p.child.resolved || !p.generator.resolved => p - case g: Generate if !g.resolved => - g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + case Project(projectList, _) if projectList.exists(hasNestedGenerator) => + val nestedGenerator = projectList.find(hasNestedGenerator).get + throw new AnalysisException("Generators are not supported when it's nested in " + + "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + + case Project(projectList, _) if projectList.count(hasGenerator) > 1 => + val generators = projectList.filter(hasGenerator).map(trimAlias) + throw new AnalysisException("Only one generator allowed per select clause but found " + + generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) case p @ Project(projectList, child) => // Holds the resolved generator, if one exists in the project list. @@ -1299,11 +1334,9 @@ class Analyzer( val newProjectList = projectList.flatMap { case AliasedGenerator(generator, names) if generator.childrenResolved => - if (resolvedGenerator != null) { - failAnalysis( - s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " + - s"and ${generator.nodeName} found.") - } + // It's a sanity check, this should not happen as the previous case will throw + // exception earlier. + assert(resolvedGenerator == null, "More than one generator found in SELECT.") resolvedGenerator = Generate( @@ -1311,7 +1344,7 @@ class Analyzer( join = projectList.size > 1, // Only join if there are other expressions in SELECT. outer = false, qualifier = None, - generatorOutput = makeGeneratorOutput(generator, names), + generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) resolvedGenerator.generatorOutput @@ -1323,44 +1356,50 @@ class Analyzer( } else { p } + + case g: Generate => g + + case p if p.expressions.exists(hasGenerator) => + throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + + "got: " + p.simpleString) } + } - /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ - private object AliasedGenerator { - def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => - // If not given the default names, and the TGF with multiple output columns - failAnalysis( - s"""Expect multiple names given for ${g.getClass.getName}, - |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) - case _ => None - } + /** + * Rewrites table generating expressions that either need one or more of the following in order + * to be resolved: + * - concrete attribute references for their output. + * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions + * that wrap the [[Generator]]. + */ + object ResolveGenerate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case g: Generate if !g.child.resolved || !g.generator.resolved => g + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) } /** * Construct the output attributes for a [[Generator]], given a list of names. If the list of * names is empty names are assigned from field names in generator. */ - private def makeGeneratorOutput( + private[sql] def makeGeneratorOutput( generator: Generator, names: Seq[String]): Seq[Attribute] = { - val elementTypes = generator.elementTypes + val elementAttrs = generator.elementSchema.toAttributes - if (names.length == elementTypes.length) { - names.zip(elementTypes).map { - case (name, (t, nullable, _)) => - AttributeReference(name, t, nullable)() + if (names.length == elementAttrs.length) { + names.zip(elementAttrs).map { + case (name, attr) => attr.withName(name) } } else if (names.isEmpty) { - elementTypes.map { - case (t, nullable, name) => AttributeReference(name, t, nullable)() - } + elementAttrs } else { failAnalysis( "The number of aliases supplied in the AS clause does not match the number of columns " + - s"output by the UDTF expected ${elementTypes.size} aliases but got " + + s"output by the UDTF expected ${elementAttrs.size} aliases but got " + s"${names.mkString(",")} ") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f82b63ad96764247f9f23bb47b3542ab5f9b8202..1f1897dc36df212d23a2f7dd1f3dfab035d09424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -142,8 +142,7 @@ object UnresolvedAttribute { case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expression]) extends Generator { - override def elementTypes: Seq[(DataType, Boolean, String)] = - throw new UnresolvedException(this, "elementTypes") + override def elementSchema: StructType = throw new UnresolvedException(this, "elementTypes") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 65d7a1d5a090443a18abf5e435b73a100c4edd98..12c35644e564c6ed05d7160641fb86a72dcbb9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -41,19 +41,16 @@ import org.apache.spark.sql.types._ */ trait Generator extends Expression { - // TODO ideally we should return the type of ArrayType(StructType), - // however, we don't keep the output field names in the Generator. - override def dataType: DataType = throw new UnsupportedOperationException + override def dataType: DataType = ArrayType(elementSchema) override def foldable: Boolean = false override def nullable: Boolean = false /** - * The output element data types in structure of Seq[(DataType, Nullable)] - * TODO we probably need to add more information like metadata etc. + * The output element schema. */ - def elementTypes: Seq[(DataType, Boolean, String)] + def elementSchema: StructType /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: InternalRow): TraversableOnce[InternalRow] @@ -69,7 +66,7 @@ trait Generator extends Expression { * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - elementTypes: Seq[(DataType, Boolean, String)], + elementSchema: StructType, function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator with CodegenFallback { @@ -117,10 +114,10 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) - override def elementTypes: Seq[(DataType, Boolean, String)] = child.dataType match { - case ArrayType(et, containsNull) => (et, containsNull, "col") :: Nil + override def elementSchema: StructType = child.dataType match { + case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull) case MapType(kt, vt, valueContainsNull) => - (kt, false, "key") :: (vt, valueContainsNull, "value") :: Nil + new StructType().add("key", kt, false).add("value", vt, valueContainsNull) } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index ecd09b7083f2e6bc6f0a88542a751b2132e51018..c14a2fb122618613f9f49fd406411ffa590a0ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -26,7 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -356,9 +356,9 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { - case (_, idx) => (StringType, true, s"c$idx") - } + override def elementSchema: StructType = StructType(fieldExpressions.zipWithIndex.map { + case (_, idx) => StructField(s"c$idx", StringType, nullable = true) + }) override def prettyName: String = "json_tuple" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7b4615db0661db7892d7b3813e710fd36af5029a..8b438e40e6af91a2b032a04538b74e329d08efc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -85,7 +85,7 @@ case class Generate( override lazy val resolved: Boolean = { generator.resolved && childrenResolved && - generator.elementTypes.length == generatorOutput.length && + generator.elementSchema.length == generatorOutput.length && generatorOutput.forall(_.resolved) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index f879b34358a9bb75200d4a2ff78643b94e89c5cb..3d2a624ba3b308407177e183c072f3dcac6a2ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -153,6 +153,8 @@ package object util { "`" + name.replace("`", "``") + "`" } + def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql + /** * Returns the string representation of this expression that is safe to be put in * code comments of generated code. The length is capped at 128 characters. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 2e88f61d491cd486efca1cd50d261a507fe8dbbc..a41383fbf65625735a275e769bbb94f58da6f76c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -330,6 +329,25 @@ class AnalysisErrorSuite extends AnalysisTest { "The start time" :: "must be greater than or equal to 0." :: Nil ) + errorTest( + "generator nested in expressions", + listRelation.select(Explode('list) + 1), + "Generators are not supported when it's nested in expressions, but got: (explode(list) + 1)" + :: Nil + ) + + errorTest( + "generator appears in operator which is not Project", + listRelation.sortBy(Explode('list).asc), + "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil + ) + + errorTest( + "more than one generators in SELECT", + listRelation.select(Explode('list), Explode('list)), + "Only one generator allowed per select clause but found 2: explode(list), explode(list)" :: Nil + ) + test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1bea72c4711f18a4da5a7ac7018c13b975dac934..31dd64e909bb8fcbfbac544454f6faab9a080c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1578,16 +1578,13 @@ class Dataset[T] private[sql]( */ @Experimental def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { - val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val elementTypes = schema.toAttributes.map { - attr => (attr.dataType, attr.nullable, attr.name) } - val names = schema.toAttributes.map(_.name) - val convert = CatalystTypeConverters.createToCatalystConverter(schema) + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) val rowFunction = f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) withPlan { Generate(generator, join = true, outer = false, @@ -1614,13 +1611,13 @@ class Dataset[T] private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? - val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } + val elementSchema = attributes.toStructType def rowFunction(row: Row): TraversableOnce[InternalRow] = { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } - val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) withPlan { Generate(generator, join = true, outer = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 934bc38dc47cbaedb7471277eaee8b5ca52e834e..8b62c5507c0c818e190262b5320e28d33830a46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -66,7 +66,7 @@ case class GenerateExec( // boundGenerator.terminate() should be triggered after all of the rows in the partition val rows = if (join) { child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementTypes.size) + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) val joinedRow = new JoinedRow iter.flatMap { row => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index f023edbd96dbe62f438e48e9a2591d81194a50e2..3220f143aa23f0a5c0cc89217c78880b9d184bdb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -145,7 +145,7 @@ private[sql] class HiveSessionCatalog( udaf } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) - udtf.elementTypes // Force it to check input data types. + udtf.elementSchema // Force it to check input data types. udtf } else { throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 5aab4132bc4ce0b3a4c71c5f3a2f067eeb1133bc..c53675694f62018ca9f9f6bd1d67bf8a02315f84 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -202,9 +202,10 @@ private[hive] case class HiveGenericUDTF( @transient protected lazy val collector = new UDTFCollector - override lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { - field => (inspectorToDataType(field.getFieldObjectInspector), true, field.getFieldName) - } + override lazy val elementSchema = StructType(outputInspector.getAllStructFieldRefs.asScala.map { + field => StructField(field.getFieldName, inspectorToDataType(field.getFieldObjectInspector), + nullable = true) + }) @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray