Commit acdf45fb authored by Jose Torres, committed by Tathagata Das

[SPARK-21765] Check that optimization doesn't affect isStreaming bit.

## What changes were proposed in this pull request?

Add an assert in logical plan optimization that the isStreaming bit stays the same, and fix the empty relation rules where that invariant wasn't being preserved.
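
For context, here is a minimal sketch of the invariant the new assertion enforces; `StreamingBitCheck` and `applyChecked` are hypothetical names used for illustration only, not the actual Spark code:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical helper, not part of Spark: applies a rule and asserts that the rewrite
// did not flip the plan's isStreaming bit (e.g. by substituting a batch LocalRelation
// for a streaming subtree).
object StreamingBitCheck {
  def applyChecked(rule: Rule[LogicalPlan], before: LogicalPlan): LogicalPlan = {
    val after = rule(before)
    assert(after.isStreaming == before.isStreaming,
      s"Rule ${rule.ruleName} changed isStreaming from ${before.isStreaming} to ${after.isStreaming}")
    after
  }
}
```

The empty-relation fixes below exist to satisfy exactly this invariant: every `LocalRelation` that replaces part of a streaming plan now carries `isStreaming = true`.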

## How was this patch tested?

New and existing unit tests.

Author: Jose Torres <joseph.torres@databricks.com>
Author: Jose Torres <joseph-torres@databricks.com>

Closes #19056 from joseph-torres/SPARK-21765-followup.
parent 36b48ee6
Showing 103 additions and 52 deletions
@@ -724,8 +724,10 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
case Filter(Literal(true, BooleanType), child) => child
// If the filter condition always evaluate to null or false,
// replace the input with an empty relation.
case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty)
case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty)
case Filter(Literal(null, _), child) =>
LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
case Filter(Literal(false, BooleanType), child) =>
LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
// If any deterministic condition is guaranteed to be true given the constraints on the child's
// output, remove the condition
case f @ Filter(fc, p: LogicalPlan) =>
......
@@ -38,7 +38,8 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty)
private def empty(plan: LogicalPlan) =
LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
@@ -65,11 +66,15 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _: RepartitionByExpression => empty(p)
// An aggregate with non-empty group expression will return one output row per group when the
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
// will be empty and thus the output will be empty.
// will be empty and thus the output will be empty. If we're working on batch data, we can
// then treat the aggregate as redundant.
//
// If the aggregate is over streaming data, we may need to update the state store even if no
// new rows are processed, so we can't eliminate the node.
//
// If the grouping expressions are empty, however, then the aggregate will always produce a
// single output row and thus we cannot propagate the EmptyRelation.
case Aggregate(ge, _, _) if ge.nonEmpty => empty(p)
case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
// Generators like Hive-style UDTF may return their records within `close`.
case Generate(_: Explode, _, _, _, _, _) => empty(p)
case _ => p
......
@@ -58,7 +58,7 @@ case class LocalRelation(output: Seq[Attribute],
* query.
*/
override final def newInstance(): this.type = {
LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
}
override protected def stringArgs: Iterator[Any] = {
......
@@ -63,7 +63,6 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/** Defines a sequence of rule batches, to be overridden by the implementation. */
protected def batches: Seq[Batch]
/**
* Executes the batches of rules defined by the subclass. The batches are executed serially
* using the defined execution strategy. Within each batch, rules are also executed serially.
......
@@ -18,11 +18,13 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StructType
class PropagateEmptyRelationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -124,6 +126,48 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
test("propagate empty streaming relation through multiple UnaryNode") {
val output = Seq('a.int)
val data = Seq(Row(1))
val schema = StructType.fromAttributes(output)
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
val relation = LocalRelation(
output,
data.map(converter(_).asInstanceOf[InternalRow]),
isStreaming = true)
val query = relation
.where(false)
.select('a)
.where('a > 1)
.where('a != 200)
.orderBy('a.asc)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = LocalRelation(output, isStreaming = true)
comparePlans(optimized, correctAnswer)
}
test("don't propagate empty streaming relation through agg") {
val output = Seq('a.int)
val data = Seq(Row(1))
val schema = StructType.fromAttributes(output)
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
val relation = LocalRelation(
output,
data.map(converter(_).asInstanceOf[InternalRow]),
isStreaming = true)
val query = relation
.groupBy('a)('a)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = query.analyze
comparePlans(optimized, correctAnswer)
}
test("don't propagate non-empty local relation") {
val query = testRelation1
.where(true)
......
@@ -659,7 +659,7 @@ class StreamExecution(
replacements ++= output.zip(newPlan.output)
newPlan
}.getOrElse {
LocalRelation(output)
LocalRelation(output, isStreaming = true)
}
}
......
@@ -29,8 +29,10 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
object TextSocketSource {
@@ -126,17 +128,10 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
batches.slice(sliceStart, sliceEnd)
}
import sqlContext.implicits._
val rawBatch = sqlContext.createDataset(rawList)
// Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp
// if requested.
if (includeTimestamp) {
rawBatch.toDF("value", "timestamp")
} else {
// Strip out timestamp
rawBatch.select("_1").toDF("value")
}
val rdd = sqlContext.sparkContext
.parallelize(rawList)
.map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
override def commit(end: Offset): Unit = synchronized {
......
@@ -65,20 +65,22 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
val offset1 = source.getOffset.get
val batch1 = source.getBatch(None, offset1)
assert(batch1.as[String].collect().toSeq === Seq("hello"))
withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
val offset1 = source.getOffset.get
val batch1 = source.getBatch(None, offset1)
assert(batch1.as[String].collect().toSeq === Seq("hello"))
serverThread.enqueue("world")
while (source.getOffset.get === offset1) {
Thread.sleep(10)
}
val offset2 = source.getOffset.get
val batch2 = source.getBatch(Some(offset1), offset2)
assert(batch2.as[String].collect().toSeq === Seq("world"))
serverThread.enqueue("world")
while (source.getOffset.get === offset1) {
Thread.sleep(10)
val both = source.getBatch(None, offset2)
assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
}
val offset2 = source.getOffset.get
val batch2 = source.getBatch(Some(offset1), offset2)
assert(batch2.as[String].collect().toSeq === Seq("world"))
val both = source.getBatch(None, offset2)
assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
// Try stopping the source to make sure this does not block forever.
source.stop()
@@ -104,22 +106,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
val offset1 = source.getOffset.get
val batch1 = source.getBatch(None, offset1)
val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
assert(batch1Seq.map(_._1) === Seq("hello"))
val batch1Stamp = batch1Seq(0)._2
serverThread.enqueue("world")
while (source.getOffset.get === offset1) {
Thread.sleep(10)
withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
val offset1 = source.getOffset.get
val batch1 = source.getBatch(None, offset1)
val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
assert(batch1Seq.map(_._1) === Seq("hello"))
val batch1Stamp = batch1Seq(0)._2
serverThread.enqueue("world")
while (source.getOffset.get === offset1) {
Thread.sleep(10)
}
val offset2 = source.getOffset.get
val batch2 = source.getBatch(Some(offset1), offset2)
val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
assert(batch2Seq.map(_._1) === Seq("world"))
val batch2Stamp = batch2Seq(0)._2
assert(!batch2Stamp.before(batch1Stamp))
}
val offset2 = source.getOffset.get
val batch2 = source.getBatch(Some(offset1), offset2)
val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
assert(batch2Seq.map(_._1) === Seq("world"))
val batch2Stamp = batch2Seq(0)._2
assert(!batch2Stamp.before(batch1Stamp))
// Try stopping the source to make sure this does not block forever.
source.stop()
@@ -184,12 +188,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
val batch = source.getBatch(None, source.getOffset.get).as[String]
batch.collect()
val numRowsMetric =
batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
assert(numRowsMetric.nonEmpty)
assert(numRowsMetric.get.value === 1)
withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
val batch = source.getBatch(None, source.getOffset.get).as[String]
batch.collect()
val numRowsMetric =
batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
assert(numRowsMetric.nonEmpty)
assert(numRowsMetric.get.value === 1)
}
source.stop()
source = null
}
......