diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 7b0c8ebdfa3b90f48ba3d1182ca7f97d0482cd1f..17eae88b49dec539b3542de7b0ce39e3b9f9588f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -77,9 +77,9 @@ package object debug {
   }
 
   /**
-   * Augments [[DataFrame]]s with debug methods.
+   * Augments [[Dataset]]s with debug methods.
    */
-  implicit class DebugQuery(query: DataFrame) extends Logging {
+  implicit class DebugQuery(query: Dataset[_]) extends Logging {
     def debug(): Unit = {
       val plan = query.queryExecution.executedPlan
       val visited = new collection.mutable.HashSet[TreeNodeRef]()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index c0fce4b96ac14eefa925f1e00da3214088cb9913..8aa0114d98d74e720c48da2f4b16dc4c2c5a15c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.debug
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData.TestData
 
 class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
 
@@ -26,6 +27,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
     testData.debug()
   }
 
+  test("Dataset.debug()") {
+    import testImplicits._
+    testData.as[TestData].debug()
+  }
+
   test("debugCodegen") {
     val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan)
     assert(res.contains("Subtree 1 / 2"))