From ec2b807212e568c9e98cd80746bcb61e02c7a98e Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 11 Nov 2015 10:52:23 -0800
Subject: [PATCH] [SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder

We need to support custom classes like java beans and combine them into tuple, and it's very hard to do it with the  TypeTag-based approach.
We should keep only the compose-based way to create tuple encoder.

This PR also move `Encoder` to `org.apache.spark.sql`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9567 from cloud-fan/java.
---
 .../sql/{catalyst/encoders => }/Encoder.scala | 65 ++--------------
 .../catalyst/encoders/ExpressionEncoder.scala | 10 +--
 .../spark/sql/catalyst/encoders/package.scala |  3 +-
 .../plans/logical/basicOperators.scala        |  1 +
 .../scala/org/apache/spark/sql/Column.scala   |  2 +-
 .../org/apache/spark/sql/DataFrame.scala      |  2 -
 .../org/apache/spark/sql/GroupedDataset.scala |  2 +-
 .../org/apache/spark/sql/SQLContext.scala     |  2 +-
 .../aggregate/TypedAggregateExpression.scala  |  3 +-
 .../spark/sql/expressions/Aggregator.scala    |  3 +-
 .../org/apache/spark/sql/functions.scala      |  2 +-
 .../apache/spark/sql/JavaDatasetSuite.java    | 78 ++++++++++---------
 .../spark/sql/DatasetAggregatorSuite.scala    |  4 +-
 .../org/apache/spark/sql/QueryTest.scala      |  1 -
 14 files changed, 65 insertions(+), 113 deletions(-)
 rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst/encoders => }/Encoder.scala (71%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
similarity index 71%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6569b900fe..1ff7340557 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -15,13 +15,14 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql
 
-import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -38,9 +39,7 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }
 
-object Encoder {
-  import scala.reflect.runtime.universe._
-
+object Encoders {
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
@@ -129,54 +128,4 @@ object Encoder {
       constructExpression,
       ClassTag.apply(cls))
   }
-
-  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
-  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
-    import scala.reflect.api
-
-    // val mirror = runtimeMirror(c.getClassLoader)
-    val mirror = rootMirror
-    val sym = mirror.staticClass(c.getName)
-    val tpe = sym.selfType
-    TypeTag(mirror, new api.TypeCreator {
-      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
-        if (m eq mirror) tpe.asInstanceOf[U # Type]
-        else throw new IllegalArgumentException(
-          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
-    })
-  }
-
-  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    ExpressionEncoder[(T1, T2)]()
-  }
-
-  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    ExpressionEncoder[(T1, T2, T3)]()
-  }
-
-  def forTuple[T1, T2, T3, T4](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    ExpressionEncoder[(T1, T2, T3, T4)]()
-  }
-
-  def forTuple[T1, T2, T3, T4, T5](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
-    : Encoder[(T1, T2, T3, T4, T5)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    implicit val typeTag5 = getTypeTag(c5)
-    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c0627f5..294afde534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitves to and from the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index d4642a5006..2c35adca9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.Encoder
+
 package object encoders {
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
     case e: ExpressionEncoder[A] => e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aaebd..597f03e752 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d26b6c3579..f0f275e91f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 691b476fff..a492099b93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.commons.lang3.StringUtils
@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index db61499229..61e2a95450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1cf1e30f96..cd1fdc4edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index b5a87c56e6..dfcbac8687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 2aa5a7d540..360c9a5bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a59d738010..ab49ed4b5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 2da63d1b96..33d8388f61 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
 import org.apache.spark.sql.test.TestSQLContext;
@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
   private transient TestSQLContext context;
-  private transient Encoder$ e = Encoder$.MODULE$;
 
   @Before
   public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }
@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }
@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());
 
     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
 
     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
 
     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
@@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
       flatMapped.collectAsList());
@@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     ds.foreach(new ForeachFunction<String>() {
       @Override
@@ -147,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testReduce() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
@@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
       @Override
@@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return sb.toString();
       }
-    }, e.STRING());
+    }, Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
 
@@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
       new CoGroupFunction<Integer, String, Integer, String>() {
         @Override
         public Iterable<String> call(
-            Integer key,
-            Iterator<String> left,
-            Iterator<Integer> right) throws Exception {
+          Integer key,
+          Iterator<String> left,
+          Iterator<Integer> right) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (left.hasNext()) {
             sb.append(left.next());
@@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
   }
@@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    GroupedDataset<Integer, String> grouped =
+      ds.groupBy(length(col("value"))).asKey(Encoders.INT());
 
     Dataset<String> mapped = grouped.map(
       new MapGroupFunction<Integer, String, String>() {
@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
           return sb.toString();
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
   }
@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
-      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));
 
     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     Assert.assertEquals(
       Arrays.asList("abc", "xyz"),
       sort(ds.distinct().collectAsList().toArray(new String[0])));
 
     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
 
     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
 
     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
 
   @Test
   public void testTupleEncoder() {
-    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
     Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());
 
-    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    Encoder<Tuple3<Integer, Long, String>> encoder3 =
+      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
       Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
     Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
 
     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
       Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
     Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());
 
     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+        Encoders.BOOLEAN());
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
       Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
@@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
-      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+      Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
       Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
     Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
@@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {
 
     // test (int, (string, string, long))
     Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
-      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
       Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
@@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {
 
     // test (int, ((string, long), string))
     Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
-      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
       Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index d4f0ab76cf..378cd36527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._
 
 import scala.language.postfixOps
 
 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator
 
 /** An `Aggregator` that adds up any numeric type returned by the given function. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174efe73..7a8b7ae5bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
 
-- 
GitLab