diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 289c16aef47aa23963b0ba50611e1e66253c31db..63d87bfb6d24d6573453c192db7b2c0abd44c784 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -57,7 +57,9 @@ class TypedFilterOptimizationSuite extends PlanTest { comparePlans(optimized, expected) } - test("embed deserializer in filter condition if there is only one filter") { + // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules + // for typed filters. + ignore("embed deserializer in typed filter condition if there is only one filter") { val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 96c871d0343551a71c0b13f48de62703ced71cdf..6cbc27d91c1e852cef47f8ce8205040d24d3ae32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1944,11 +1944,11 @@ class Dataset[T] private[sql]( */ @Experimental def filter(func: T => Boolean): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[T => Boolean])) - val condition = Invoke(function, "apply", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** @@ -1961,11 +1961,11 @@ class Dataset[T] private[sql]( */ @Experimental def filter(func: FilterFunction[T]): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) - val condition = Invoke(function, "call", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "call", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 8354a5bdac68f9db62352019776563a7ac45f315..37577accfda210b2a0c97d8c7f8702c918f04e24 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -92,6 +92,19 @@ public class JavaDatasetSuite implements Serializable { Assert.assertFalse(iter.hasNext()); } + // SPARK-15632: typed filter should preserve the underlying logical schema + @Test + public void testTypedFilterPreservingSchema() { + Dataset<Long> ds = spark.range(10); + Dataset<Long> ds2 = ds.filter(new FilterFunction<Long>() { + @Override + public boolean call(Long value) throws Exception { + return value > 3; + } + }); + Assert.assertEquals(ds.schema(), ds2.schema()); + } + @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bf2b0a2c7c1b7be7d796f8ff7770278f82a8f351..11b52bdead7b740b55218e9f95f1268f6c075e63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -225,6 +225,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "b") } + test("SPARK-15632: typed filter should preserve the underlying logical schema") { + val ds = spark.range(10) + val ds2 = ds.filter(_ > 3) + assert(ds.schema.equals(ds2.schema)) + } + test("foreach") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.longAccumulator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 68f0ee864f47ffbe0c4e46e0f2a8e2990eb63850..f26e5e7b6990d7c3a4ad7ba07c184c1e01b4d08d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -97,7 +97,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) assert(ds.collect() === Array(0, 6)) }