From acdf45fb52e29a0308cccdbef0ec0dca0815d300 Mon Sep 17 00:00:00 2001
From: Jose Torres <joseph.torres@databricks.com>
Date: Wed, 6 Sep 2017 11:19:46 -0700
Subject: [PATCH] [SPARK-21765] Check that optimization doesn't affect isStreaming bit.

## What changes were proposed in this pull request?

Add an assert in logical plan optimization that the isStreaming bit stays the same, and fix empty relation rules where that wasn't happening.
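For reviewers, the invariant being enforced can be pictured as a check of the following shape. This is only an illustrative sketch: the object and method names (`StreamingBitCheck.assertUnchanged`) and the idea of calling it from a rule executor after each rule are assumptions for illustration, not the assertion code added by this patch (which is not reproduced in the hunks below).

```scala
// Illustrative sketch only, not the assertion added by this patch.
// A rule executor could run a check like this after each rule application to
// verify that optimization never flips a logical plan's isStreaming bit.
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object StreamingBitCheck {
  def assertUnchanged(
      rule: Rule[LogicalPlan],
      before: LogicalPlan,
      after: LogicalPlan): Unit = {
    assert(before.isStreaming == after.isStreaming,
      s"Rule ${rule.ruleName} changed isStreaming from ${before.isStreaming} " +
        s"to ${after.isStreaming}")
  }
}
```

The empty-relation fixes below preserve the same invariant from the other side: wherever a rule replaces a subtree with an empty `LocalRelation`, the new relation now carries the original plan's `isStreaming` flag.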
## How was this patch tested?

new and existing unit tests

Author: Jose Torres <joseph.torres@databricks.com>
Author: Jose Torres <joseph-torres@databricks.com>

Closes #19056 from joseph-torres/SPARK-21765-followup.
---
 .../sql/catalyst/optimizer/Optimizer.scala    |  6 +-
 .../optimizer/PropagateEmptyRelation.scala    | 11 ++-
 .../plans/logical/LocalRelation.scala         |  2 +-
 .../sql/catalyst/rules/RuleExecutor.scala     |  1 -
 .../PropagateEmptyRelationSuite.scala         | 44 ++++++++++++
 .../execution/streaming/StreamExecution.scala |  2 +-
 .../sql/execution/streaming/socket.scala      | 17 ++---
 .../streaming/TextSocketStreamSuite.scala     | 72 ++++++++++---------
 8 files changed, 103 insertions(+), 52 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d7e5906f67..02d6778c08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -724,8 +724,10 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
     case Filter(Literal(true, BooleanType), child) => child
     // If the filter condition always evaluate to null or false,
     // replace the input with an empty relation.
-    case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty)
-    case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty)
+    case Filter(Literal(null, _), child) =>
+      LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
+    case Filter(Literal(false, BooleanType), child) =>
+      LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
     // If any deterministic condition is guaranteed to be true given the constraints on the child's
     // output, remove the condition
     case f @ Filter(fc, p: LogicalPlan) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 987cd7434b..cfffa6bc2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -38,7 +38,8 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
     case _ => false
   }

-  private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty)
+  private def empty(plan: LogicalPlan) =
+    LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)

   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p: Union if p.children.forall(isEmptyLocalRelation) =>
@@ -65,11 +66,15 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
       case _: RepartitionByExpression => empty(p)
       // An aggregate with non-empty group expression will return one output row per group when the
       // input to the aggregate is not empty. If the input to the aggregate is empty then all groups
-      // will be empty and thus the output will be empty.
+      // will be empty and thus the output will be empty. If we're working on batch data, we can
+      // then treat the aggregate as redundant.
+      //
+      // If the aggregate is over streaming data, we may need to update the state store even if no
+      // new rows are processed, so we can't eliminate the node.
       //
       // If the grouping expressions are empty, however, then the aggregate will always produce a
       // single output row and thus we cannot propagate the EmptyRelation.
-      case Aggregate(ge, _, _) if ge.nonEmpty => empty(p)
+      case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
       // Generators like Hive-style UDTF may return their records within `close`.
       case Generate(_: Explode, _, _, _, _, _) => empty(p)
       case _ => p
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 7a21183664..d73d7e73f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -58,7 +58,7 @@ case class LocalRelation(output: Seq[Attribute],
    * query.
   */
  override final def newInstance(): this.type = {
-    LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
+    LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
  }

  override protected def stringArgs: Iterator[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 85b368c862..0e89d1c8f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -63,7 +63,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
   /** Defines a sequence of rule batches, to be overridden by the implementation. */
   protected def batches: Seq[Batch]

-
   /**
    * Executes the batches of rules defined by the subclass. The batches are executed serially
    * using the defined execution strategy. Within each batch, rules are also executed serially.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 2285be1693..bc1c48b99c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -18,11 +18,13 @@ package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.StructType

 class PropagateEmptyRelationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -124,6 +126,48 @@ class PropagateEmptyRelationSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

+  test("propagate empty streaming relation through multiple UnaryNode") {
+    val output = Seq('a.int)
+    val data = Seq(Row(1))
+    val schema = StructType.fromAttributes(output)
+    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+    val relation = LocalRelation(
+      output,
+      data.map(converter(_).asInstanceOf[InternalRow]),
+      isStreaming = true)
+
+    val query = relation
+      .where(false)
+      .select('a)
+      .where('a > 1)
+      .where('a != 200)
+      .orderBy('a.asc)
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = LocalRelation(output, isStreaming = true)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("don't propagate empty streaming relation through agg") {
+    val output = Seq('a.int)
+    val data = Seq(Row(1))
+    val schema = StructType.fromAttributes(output)
+    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+    val relation = LocalRelation(
+      output,
+      data.map(converter(_).asInstanceOf[InternalRow]),
+      isStreaming = true)
+
+    val query = relation
+      .groupBy('a)('a)
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = query.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("don't propagate non-empty local relation") {
     val query = testRelation1
       .where(true)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index c224f2f9f1..71088ff638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -659,7 +659,7 @@ class StreamExecution(
           replacements ++= output.zip(newPlan.output)
           newPlan
         }.getOrElse {
-          LocalRelation(output)
+          LocalRelation(output, isStreaming = true)
         }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
index 8e63207959..0b22cbc46e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
@@ -29,8 +29,10 @@ import scala.util.{Failure, Success, Try}

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
 import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.unsafe.types.UTF8String

 object TextSocketSource {
@@ -126,17 +128,10 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
       batches.slice(sliceStart, sliceEnd)
     }

-    import sqlContext.implicits._
-    val rawBatch = sqlContext.createDataset(rawList)
-
-    // Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp
-    // if requested.
-    if (includeTimestamp) {
-      rawBatch.toDF("value", "timestamp")
-    } else {
-      // Strip out timestamp
-      rawBatch.select("_1").toDF("value")
-    }
+    val rdd = sqlContext.sparkContext
+      .parallelize(rawList)
+      .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
+    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
   }

   override def commit(end: Offset): Unit = synchronized {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
index 9ebf4d2835..ec11549073 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
@@ -65,20 +65,22 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     while (source.getOffset.isEmpty) {
       Thread.sleep(10)
     }
-    val offset1 = source.getOffset.get
-    val batch1 = source.getBatch(None, offset1)
-    assert(batch1.as[String].collect().toSeq === Seq("hello"))
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val offset1 = source.getOffset.get
+      val batch1 = source.getBatch(None, offset1)
+      assert(batch1.as[String].collect().toSeq === Seq("hello"))
+
+      serverThread.enqueue("world")
+      while (source.getOffset.get === offset1) {
+        Thread.sleep(10)
+      }
+      val offset2 = source.getOffset.get
+      val batch2 = source.getBatch(Some(offset1), offset2)
+      assert(batch2.as[String].collect().toSeq === Seq("world"))
-    serverThread.enqueue("world")
-    while (source.getOffset.get === offset1) {
-      Thread.sleep(10)
+      val both = source.getBatch(None, offset2)
+      assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
     }
-    val offset2 = source.getOffset.get
-    val batch2 = source.getBatch(Some(offset1), offset2)
-    assert(batch2.as[String].collect().toSeq === Seq("world"))
-
-    val both = source.getBatch(None, offset2)
-    assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))

     // Try stopping the source to make sure this does not block forever.
     source.stop()
@@ -104,22 +106,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     while (source.getOffset.isEmpty) {
       Thread.sleep(10)
     }
-    val offset1 = source.getOffset.get
-    val batch1 = source.getBatch(None, offset1)
-    val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
-    assert(batch1Seq.map(_._1) === Seq("hello"))
-    val batch1Stamp = batch1Seq(0)._2
-
-    serverThread.enqueue("world")
-    while (source.getOffset.get === offset1) {
-      Thread.sleep(10)
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val offset1 = source.getOffset.get
+      val batch1 = source.getBatch(None, offset1)
+      val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
+      assert(batch1Seq.map(_._1) === Seq("hello"))
+      val batch1Stamp = batch1Seq(0)._2
+
+      serverThread.enqueue("world")
+      while (source.getOffset.get === offset1) {
+        Thread.sleep(10)
+      }
+      val offset2 = source.getOffset.get
+      val batch2 = source.getBatch(Some(offset1), offset2)
+      val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
+      assert(batch2Seq.map(_._1) === Seq("world"))
+      val batch2Stamp = batch2Seq(0)._2
+      assert(!batch2Stamp.before(batch1Stamp))
     }
-    val offset2 = source.getOffset.get
-    val batch2 = source.getBatch(Some(offset1), offset2)
-    val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
-    assert(batch2Seq.map(_._1) === Seq("world"))
-    val batch2Stamp = batch2Seq(0)._2
-    assert(!batch2Stamp.before(batch1Stamp))

     // Try stopping the source to make sure this does not block forever.
     source.stop()
@@ -184,12 +188,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     while (source.getOffset.isEmpty) {
       Thread.sleep(10)
     }
-    val batch = source.getBatch(None, source.getOffset.get).as[String]
-    batch.collect()
-    val numRowsMetric =
-      batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
-    assert(numRowsMetric.nonEmpty)
-    assert(numRowsMetric.get.value === 1)
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val batch = source.getBatch(None, source.getOffset.get).as[String]
+      batch.collect()
+      val numRowsMetric =
+        batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
+      assert(numRowsMetric.nonEmpty)
+      assert(numRowsMetric.get.value === 1)
+    }
     source.stop()
     source = null
   }
--
GitLab