diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 081aa03002cc61075ea1f5a3ca934a000ab636f0..cbcccb11f14ae2b4545de4ef989a635b1996adb9 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
+  test("Datasets agg type-inference") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+        |  val numeric = implicitly[Numeric[N]]
+        |  override def zero: N = numeric.zero
+        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+        |  override def finish(reduction: N): N = reduction
+        |}
+        |
+        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+        |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 6de3dd626576a4215e175a635f712d9c91f9bcfd..263f049104762d4ca0770b214b536af39d239450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
     reduce(f.call _)
   }
 
-  /**
-   * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
-   * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
-   *
-   * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
-   *
-   * {{{
-   *   // Selects the age of the oldest employee and the aggregate expense for each department
-   *
-   *   // Scala:
-   *   import org.apache.spark.sql.functions._
-   *   df.groupBy("department").agg(max("age"), sum("expense"))
-   *
-   *   // Java:
-   *   import static org.apache.spark.sql.functions.*;
-   *   df.groupBy("department").agg(max("age"), sum("expense"));
-   * }}}
-   *
-   * We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
-   *
-   * @since 1.6.0
-   */
+  // This is here to prevent us from adding overloads that would be ambiguous.
   @scala.annotation.varargs
-  def agg(expr: Column, exprs: Column*): DataFrame =
-    groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+  private def agg(exprs: Column*): DataFrame =
+    groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
 
   private def withEncoder(c: Column): Column = c match {
     case tc: TypedColumn[_, _] =>
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index ce40dd856f679383dba2dc4521bc88d7b70829fd..f7249b8945c490cdf58e415bcab263e20608136b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -404,11 +404,9 @@ public class JavaDatasetSuite implements Serializable {
       grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
 
-    Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
-      expr("sum(_2)"),
-      count("*"))
-      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
         new Tuple4<>("a", 3, 3L, 2L),