diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f240c30427b45ec22012dc658b8326676e8d6955..290de794dc3bbf7903e278dec9fdaf9b18d51e32 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -329,6 +329,11 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
 
+        // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
+
         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c8b78bc14a39c2d77f5a536f41e3bd469be9d2b1..547da8f713ac77f97d8cc767495991bc881cdc55 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -285,7 +285,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
         |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.{Encoder, Encoders}
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -293,6 +293,8 @@ class ReplSuite extends SparkFunSuite {
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
         |  def finish(b: Int) = b                // Return the final result.
+        |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+        |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
         |}.toColumn
         |
         |val ds = Seq(1, 2, 3, 4).toDS()
@@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-        |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
-        |import org.apache.spark.sql.expressions.Aggregator
-        |import org.apache.spark.sql.TypedColumn
-        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-        |  val numeric = implicitly[Numeric[N]]
-        |  override def zero: N = numeric.zero
-        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-        |  override def finish(reduction: N): N = reduction
-        |}
-        |
-        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index dbfacba34637d27694d979a3155da041f7a03f94..7e10f15226b7abf77c73defc84f18ce36df411df 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
         |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.{Encoder, Encoders}
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite {
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
         |  def finish(b: Int) = b                // Return the final result.
+        |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+        |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
         |}.toColumn
         |
         |val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-        |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
-        |import org.apache.spark.sql.expressions.Aggregator
-        |import org.apache.spark.sql.TypedColumn
-        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends
-        |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
-        |  val numeric = implicitly[Numeric[N]]
-        |  override def zero: N = numeric.zero
-        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-        |  override def finish(reduction: N): N = reduction
-        |}
-        |
-        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 7a18d0afce6bafdd1130a9cdac54a9428bca2aa3..c39a78da6f9be7bbab752f090f7b181e81c542d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
@@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
-class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
-  val numeric = implicitly[Numeric[OUT]]
-  override def zero: OUT = numeric.zero
-  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
-  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
-  override def finish(reduction: OUT): OUT = reduction
-
-  // TODO(ekl) java api support once this is exposed in scala
-}
-
-
 class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
   override def finish(reduction: Double): Double = reduction
 
+  override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
 
@@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
     (b1._1 + b2._1, b1._2 + b2._2)
   }
 
+  override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 9cb356f1ca375b146f6af3cf3bd4229975938b01..7da8379c9aa9aae740d5aa6b99498841553a3349 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  *
  * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
  *
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
  * @since 1.6.0
  */
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
 
   /**
    * A zero value for this aggregation. Should satisfy the property that any b + zero = b.
    * @since 1.6.0
    */
-  def zero: B
+  def zero: BUF
 
   /**
    * Combine two values to produce a new value.  For performance, the function may modify `b` and
    * return it instead of constructing new object for b.
    * @since 1.6.0
    */
-  def reduce(b: B, a: I): B
+  def reduce(b: BUF, a: IN): BUF
 
   /**
    * Merge two intermediate values.
    * @since 1.6.0
    */
-  def merge(b1: B, b2: B): B
+  def merge(b1: BUF, b2: BUF): BUF
 
   /**
    * Transform the output of the reduction.
    * @since 1.6.0
    */
-  def finish(reduction: B): O
+  def finish(reduction: BUF): OUT
 
   /**
-   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+   * Specifies the [[Encoder]] for the intermediate value type.
+   * @since 2.0.0
+   */
+  def bufferEncoder: Encoder[BUF]
+
+  /**
+   * Specifies the [[Encoder]] for the final ouput value type.
+   * @since 2.0.0
+   */
+  def outputEncoder: Encoder[OUT]
+
+  /**
+   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]].
    * operations.
    * @since 1.6.0
    */
-  def toColumn(
-      implicit bEncoder: Encoder[B],
-      cEncoder: Encoder[O]): TypedColumn[I, O] = {
+  def toColumn: TypedColumn[IN, OUT] = {
+    implicit val bEncoder = bufferEncoder
+    implicit val cEncoder = outputEncoder
+
     val expr =
       AggregateExpression(
         TypedAggregateExpression(this),
         Complete,
         isDistinct = false)
 
-    new TypedColumn[I, O](expr, encoderFor[O])
+    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
   }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index 8cb174b9067f34f6b816e09a25a2d36cbddf0e82..0e49f871de5c4e4c06a33f5e9b7c4a20899a7148 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -26,6 +26,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
@@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   public void testTypedAggregationAnonClass() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
 
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
 
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
@@ -73,6 +72,16 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
     public Integer finish(Integer reduction) {
       return reduction;
     }
+
+    @Override
+    public Encoder<Integer> bufferEncoder() {
+      return Encoders.INT();
+    }
+
+    @Override
+    public Encoder<Integer> outputEncoder() {
+      return Encoders.INT();
+    }
   }
 
   @Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 08b3389ad918abcf32a368f48c1e1b56ae53c60f..3a7215ee39728345ad8215eb86dfe21ce1b1d6a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -26,74 +27,65 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
   override def zero: (Long, Long) = (0, 0)
-
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
     (countAndSum._1 + 1, countAndSum._2 + input._2)
   }
-
   override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
-
   override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+  override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
+  override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
 }
 
+
 case class AggData(a: Int, b: String)
+
 object ClassInputAgg extends Aggregator[AggData, Int, Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: Int = 0
-
-  /**
-   * Combine two values to produce a new value.  For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: Int, a: AggData): Int = b + a.a
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: Int): Int = reduction
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: Int, b2: Int): Int = b1 + b2
+  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
-  /**
-   * Combine two values to produce a new value.  For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: (Int, AggData)): Int = reduction._1
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
     (b1._1 + b2._1, b1._2)
+  override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)]
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
-
   def reduce(b: String, a: AggData): String = a.b + b
-
   def merge(b1: String, b2: String): String = b1 + b2
-
   def finish(r: String): String = r
+  override def bufferEncoder: Encoder[String] = Encoders.STRING
+  override def outputEncoder: Encoder[String] = Encoders.STRING
+}
+
+
+class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
+  extends Aggregator[IN, OUT, OUT] {
+
+  private val numeric = implicitly[Numeric[OUT]]
+  override def zero: OUT = numeric.zero
+  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+  override def finish(reduction: OUT): OUT = reduction
+  override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
+  override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
 }
 
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
   }
 
+  test("generic typed sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn),
+      ("a", 4.0), ("b", 3.0))
+
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn),
+      ("a", 4), ("b", 3))
+  }
+
   test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
     val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]