From ec2b807212e568c9e98cd80746bcb61e02c7a98e Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 11 Nov 2015 10:52:23 -0800
Subject: [PATCH] [SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder

We need to support custom classes like Java beans and combine them into
tuples, and that is very hard to do with the TypeTag-based approach. We
should keep only the compose-based way to create tuple encoders. This PR
also moves `Encoder` to `org.apache.spark.sql`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9567 from cloud-fan/java.
---
 .../sql/{catalyst/encoders => }/Encoder.scala      | 65 ++----------
 .../catalyst/encoders/ExpressionEncoder.scala      | 10 +--
 .../spark/sql/catalyst/encoders/package.scala      |  3 +-
 .../plans/logical/basicOperators.scala             |  1 +
 .../scala/org/apache/spark/sql/Column.scala        |  2 +-
 .../org/apache/spark/sql/DataFrame.scala           |  2 -
 .../org/apache/spark/sql/GroupedDataset.scala      |  2 +-
 .../org/apache/spark/sql/SQLContext.scala          |  2 +-
 .../aggregate/TypedAggregateExpression.scala       |  3 +-
 .../spark/sql/expressions/Aggregator.scala         |  3 +-
 .../org/apache/spark/sql/functions.scala           |  2 +-
 .../apache/spark/sql/JavaDatasetSuite.java         | 78 ++++++++---------
 .../spark/sql/DatasetAggregatorSuite.scala         |  4 +-
 .../org/apache/spark/sql/QueryTest.scala           |  1 -
 14 files changed, 65 insertions(+), 113 deletions(-)
 rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst/encoders => }/Encoder.scala (71%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
similarity index 71%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6569b900fe..1ff7340557 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -15,13 +15,14 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql

-import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag

 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -38,9 +39,7 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }

-object Encoder {
-  import scala.reflect.runtime.universe._
-
+object Encoders {
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
@@ -129,54 +128,4 @@ object Encoder {
       constructExpression,
       ClassTag.apply(cls))
   }
-
-  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
-  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
-    import scala.reflect.api
-
-    // val mirror = runtimeMirror(c.getClassLoader)
-    val mirror = rootMirror
-    val sym = mirror.staticClass(c.getName)
-    val tpe = sym.selfType
-    TypeTag(mirror, new api.TypeCreator {
-      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
-        if (m eq mirror) tpe.asInstanceOf[U # Type]
-        else throw new IllegalArgumentException(
-          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
-    })
-  }
-
-  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    ExpressionEncoder[(T1, T2)]()
-  }
-
-  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    ExpressionEncoder[(T1, T2, T3)]()
-  }
-
-  def forTuple[T1, T2, T3, T4](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    ExpressionEncoder[(T1, T2, T3, T4)]()
-  }
-
-  def forTuple[T1, T2, T3, T4, T5](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
-    : Encoder[(T1, T2, T3, T4, T5)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    implicit val typeTag5 = getTypeTag(c5)
-    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c0627f5..294afde534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,18 +17,18 @@

 package org.apache.spark.sql.catalyst.encoders

-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}

+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}

 /**
  * A factory for constructing encoders that convert objects and primitves to and from the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index d4642a5006..2c35adca9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -17,10 +17,11 @@

 package org.apache.spark.sql.catalyst

+import org.apache.spark.sql.Encoder
+
 package object encoders {
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
     case e: ExpressionEncoder[A] => e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aaebd..597f03e752 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.plans.logical

+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d26b6c3579..f0f275e91f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 691b476fff..a492099b93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal

 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.commons.lang3.StringUtils
@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index db61499229..61e2a95450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1cf1e30f96..cd1fdc4edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index b5a87c56e6..dfcbac8687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials

 import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 2aa5a7d540..360c9a5bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,8 @@

 package org.apache.spark.sql.expressions

-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a59d738010..ab49ed4b5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 2da63d1b96..33d8388f61 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
 import org.apache.spark.sql.test.TestSQLContext;
@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
   private transient TestSQLContext context;
-  private transient Encoder$ e = Encoder$.MODULE$;

   @Before
   public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }
@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }
@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());

     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
       public Integer call(String v) throws Exception {
        return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());

     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());

     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
@@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
       flatMapped.collectAsList());
@@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());

    ds.foreach(new ForeachFunction<String>() {
       @Override
@@ -147,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testReduce() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());

     int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
@@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());

     Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
       @Override
@@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return sb.toString();
       }
-    }, e.STRING());
+    }, Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());

@@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
         return Collections.singletonList(sb.toString());
       }
     },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());

     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
       }
-    }, e.INT());
+    }, Encoders.INT());

     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
       new CoGroupFunction<Integer, String, Integer, String>() {
         @Override
         public Iterable<String> call(
-          Integer key,
-          Iterator<String> left,
-          Iterator<Integer> right) throws Exception {
+            Integer key,
+            Iterator<String> left,
+            Iterator<Integer> right) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (left.hasNext()) {
             sb.append(left.next());
@@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
   }
@@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    GroupedDataset<Integer, String> grouped =
+      ds.groupBy(length(col("value"))).asKey(Encoders.INT());

     Dataset<String> mapped = grouped.map(
       new MapGroupFunction<Integer, String, String>() {
@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
           return sb.toString();
         }
       },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
   }
@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());

     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
-      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));

     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());

     Assert.assertEquals(
       Arrays.asList("abc", "xyz"),
       sort(ds.distinct().collectAsList().toArray(new String[0])));

     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());

     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");

     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {

   @Test
   public void testTupleEncoder() {
-    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
     Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());

-    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    Encoder<Tuple3<Integer, Long, String>> encoder3 =
+      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
       Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
     Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());

     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
       Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
     Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());

     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+        Encoders.BOOLEAN());
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
       Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
@@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
-      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+      Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
       Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
     Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
@@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {

     // test (int, (string, string, long))
     Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
-      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
       Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
@@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {

     // test (int, ((string, long), string))
     Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
-      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
       Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index d4f0ab76cf..378cd36527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,13 +17,11 @@

 package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._

 import scala.language.postfixOps

 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator

 /** An `Aggregator` that adds up any numeric type returned by the given function. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174efe73..7a8b7ae5bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._

 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder

 abstract class QueryTest extends PlanTest {
--
GitLab
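
Usage sketch (not part of the patch): how the compose-based tuple encoder is used
from Java after this change, in place of the removed TypeTag-based `forTuple`
overloads. The sketch mirrors `testTupleEncoder` in `JavaDatasetSuite` above and
assumes the same JUnit harness, in particular the `TestSQLContext` field named
`context`; the method name is hypothetical.

    import java.util.Arrays;
    import java.util.List;

    import scala.Tuple2;

    import org.junit.Assert;
    import org.junit.Test;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.Encoders;

    @Test
    public void tupleEncoderUsageSketch() {
      // Compose primitive encoders into a tuple encoder. No TypeTag has to be
      // manufactured for the element types, which (per the commit message) is
      // what makes this approach workable from Java, where TypeTags are hard
      // to construct for custom classes such as Java beans.
      Encoder<Tuple2<Integer, String>> encoder =
        Encoders.tuple(Encoders.INT(), Encoders.STRING());

      List<Tuple2<Integer, String>> data = Arrays.asList(
        new Tuple2<Integer, String>(1, "a"),
        new Tuple2<Integer, String>(2, "b"));

      // The composed encoder tells createDataset how to convert each Tuple2
      // to and from Spark SQL's internal representation.
      Dataset<Tuple2<Integer, String>> ds = context.createDataset(data, encoder);
      Assert.assertEquals(data, ds.collectAsList());
    }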