diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 96d2dfc2658b9966fc9dfc1478b0bc8377a00d1e..9262e938c2a6004e65316818f95bb9882588961f 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -473,21 +473,4 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("AssertionError", output)
     assertDoesNotContain("Exception", output)
   }
-
-  test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
-    val resultValue = 12345
-    val output = runInterpreter("local",
-      s"""
-         |val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
-         |val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1))
-         |val broadcasted = sc.broadcast($resultValue)
-         |
-         |// Using broadcast triggers serialization issue in KeyValueGroupedDataset
-         |val dataset = mapGroups.map(_ => broadcasted.value)
-         |dataset.collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-    assertContains(s": Array[Int] = Array($resultValue, $resultValue)", output)
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 55f04878052aab114ded96848d27608a01a6a411..6fa7b0487732ebecd10e8a2e46defd25b992e46e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -923,6 +923,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
 
+  test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
+    val resultValue = 12345
+    val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
+    val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1))
+    val broadcasted = spark.sparkContext.broadcast(resultValue)
+
+    // Using broadcast triggers serialization issue in KeyValueGroupedDataset
+    val dataset = mapGroups.map(_ => broadcasted.value)
+
+    assert(dataset.collect() sameElements Array(resultValue, resultValue))
+  }
+
   Seq(true, false).foreach { eager =>
     def testCheckpointing(testName: String)(f: => Unit): Unit = {
       test(s"Dataset.checkpoint() - $testName (eager = $eager)") {