diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799534497e3870faadbcaa0b0bc944192e23..792fb3e795ed123a60a0fc7e17b711fcb88b7125 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   override val output: Seq[Attribute] = range.output
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val numGenerated = metricTerm(ctx, "numGeneratedRows")
 
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
-    ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
     val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
-    ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
-    val checkEnd = if (step > 0) {
-      s"$number < $partitionEnd"
-    } else {
-      s"$number > $partitionEnd"
-    }
+
+    // In order to periodically update the metrics without inflicting performance penalty, this
+    // operator produces elements in batches. After a batch is complete, the metrics are updated
+    // and a new batch is started.
+    // In the implementation below, the code in the inner loop is producing all the values
+    // within a batch, while the code in the outer loop is setting batch parameters and updating
+    // the metrics.
+
+    // Once number == batchEnd, it's time to progress to the next batch.
+    val batchEnd = ctx.freshName("batchEnd")
+    ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+    // How many values should still be generated by this range operator.
+    val numElementsTodo = ctx.freshName("numElementsTodo")
+    ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+    // How many values should be generated in the next batch.
+    val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+    // The default size of a batch.
+    val batchSize = 1000L
 
     ctx.addNewFunction("initRange",
       s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
         |   $BigInt step = $BigInt.valueOf(${step}L);
         |   $BigInt start = $BigInt.valueOf(${start}L);
+        |   long partitionEnd;
         |
         |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
         |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   } else {
         |     $number = st.longValue();
         |   }
+        |   $batchEnd = $number;
         |
         |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
         |     .multiply(step).add(start);
         |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $partitionEnd = Long.MAX_VALUE;
+        |     partitionEnd = Long.MAX_VALUE;
         |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $partitionEnd = Long.MIN_VALUE;
+        |     partitionEnd = Long.MIN_VALUE;
         |   } else {
-        |     $partitionEnd = end.longValue();
+        |     partitionEnd = end.longValue();
         |   }
         |
-        |   $numOutput.add(($partitionEnd - $number) / ${step}L);
+        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+        |     $BigInt.valueOf($number));
+        |   $numElementsTodo  = startToEnd.divide(step).longValue();
+        |   if ($numElementsTodo < 0) {
+        |     $numElementsTodo = 0;
+        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+        |     $numElementsTodo++;
+        |   }
         | }
        """.stripMargin)
 
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   initRange(partitionIndex);
       | }
       |
-      | while (!$overflow && $checkEnd) {
-      |  long $value = $number;
-      |  $number += ${step}L;
-      |  if ($number < $value ^ ${step}L < 0) {
-      |    $overflow = true;
-      |  }
-      |  ${consume(ctx, Seq(ev))}
-      |  if (shouldStop()) return;
+      | while (true) {
+      |   while ($number != $batchEnd) {
+      |     long $value = $number;
+      |     $number += ${step}L;
+      |     ${consume(ctx, Seq(ev))}
+      |     if (shouldStop()) return;
+      |   }
+      |
+      |   long $nextBatchTodo;
+      |   if ($numElementsTodo > ${batchSize}L) {
+      |     $nextBatchTodo = ${batchSize}L;
+      |     $numElementsTodo -= ${batchSize}L;
+      |   } else {
+      |     $nextBatchTodo = $numElementsTodo;
+      |     $numElementsTodo = 0;
+      |     if ($nextBatchTodo == 0) break;
+      |   }
+      |   $numOutput.add($nextBatchTodo);
+      |   $numGenerated.add($nextBatchTodo);
+      |
+      |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
      """.stripMargin
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
             }
 
             numOutputRows += 1
+            numGeneratedRows += 1
             unsafeRow.setLong(0, ret)
             unsafeRow
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6d2d776c92b377da8ae706b0dc5fe357cd07bd81
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+  test("SPARK-7150 range api") {
+    // numSlice is greater than length
+    val res1 = spark.range(0, 10, 1, 15).select("id")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res2 = spark.range(3, 15, 3, 2).select("id")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    val res3 = spark.range(1, -2).select("id")
+    assert(res3.count == 0)
+
+    // start is positive, end is negative, step is negative
+    val res4 = spark.range(1, -2, -2, 6).select("id")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+    // start, end, step are negative
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+    // start, end are negative, step is positive
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
+    assert(res7.count == 0)
+
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    assert(res8.count == 3)
+    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    assert(res9.count == 2)
+    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = spark.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = spark.range(-1).select("id")
+    assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = spark.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    // difference between range start and end does not fit in a 64-bit integer
+    val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+    val res13 = spark.range(-n, n, n / 9).select("id")
+    assert(res13.count == 18)
+  }
+
+  test("Range with randomized parameters") {
+    val MAX_NUM_STEPS = 10L * 1000
+
+    val seed = System.currentTimeMillis()
+    val random = new Random(seed)
+
+    def randomBound(): Long = {
+      val n = if (random.nextBoolean()) {
+        random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+      } else {
+        random.nextLong() / 2
+      }
+      if (random.nextBoolean()) n else -n
+    }
+
+    for (l <- 1 to 10) {
+      val start = randomBound()
+      val end = randomBound()
+      val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+      val stepAbs = (abs(end - start) / numSteps) + 1
+      val step = if (start < end) stepAbs else -stepAbs
+      val partitions = random.nextInt(20) + 1
+
+      val expCount = (start until end by step).size
+      val expSum = (start until end by step).sum
+
+      for (codegen <- List(false, true)) {
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+          val res = spark.range(start, end, step, partitions).toDF("id").
+            agg(count("id"), sum("id")).collect()
+
+          withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+              s"$partitions codegen = $codegen") {
+            assert(!res.isEmpty)
+            assert(res.head.getLong(0) == expCount)
+            if (expCount > 0) {
+              assert(res.head.getLong(1) == expSum)
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b98ea983fda4b4f71bb61d8a232b2a841bf..e6338ab7cd800db70e0b8dd2a81e08136a89b1de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
   }
 
-  test("SPARK-7150 range api") {
-    // numSlice is greater than length
-    val res1 = spark.range(0, 10, 1, 15).select("id")
-    assert(res1.count == 10)
-    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res2 = spark.range(3, 15, 3, 2).select("id")
-    assert(res2.count == 4)
-    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
-    val res3 = spark.range(1, -2).select("id")
-    assert(res3.count == 0)
-
-    // start is positive, end is negative, step is negative
-    val res4 = spark.range(1, -2, -2, 6).select("id")
-    assert(res4.count == 2)
-    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
-    // start, end, step are negative
-    val res5 = spark.range(-3, -8, -2, 1).select("id")
-    assert(res5.count == 3)
-    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
-    // start, end are negative, step is positive
-    val res6 = spark.range(-8, -4, 2, 1).select("id")
-    assert(res6.count == 2)
-    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
-    val res7 = spark.range(-10, -9, -20, 1).select("id")
-    assert(res7.count == 0)
-
-    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
-    assert(res8.count == 3)
-    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
-    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
-    assert(res9.count == 2)
-    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
-    // only end provided as argument
-    val res10 = spark.range(10).select("id")
-    assert(res10.count == 10)
-    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res11 = spark.range(-1).select("id")
-    assert(res11.count == 0)
-
-    // using the default slice number
-    val res12 = spark.range(3, 15, 3).select("id")
-    assert(res12.count == 4)
-    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-  }
-
   test("SPARK-8621: support empty string column name") {
     val df = Seq(Tuple1(1)).toDF("").as("t")
     // We should allow empty string as column name
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ddd7a03e8003a81bc8a743423aca30c91b5280a6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+  test("Range query input/output/generated metrics") {
+    val numRows = 150L
+    val numSelectedRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+      filter(x => x < numSelectedRows).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with repartitioning") {
+    val numRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(
+      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === numRows)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === 20 :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with more repartitioning") {
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res = MetricsTestHelper.runAndGetMetrics(
+        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+            .toDF()
+      )
+
+      assert(res.recordsRead.sum == 10)
+      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+      assert(res.generatedRows == 30 :: Nil)
+      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+    }
+  }
+}
+
+object MetricsTestHelper {
+  case class AggregatedMetricsResult(
+      recordsRead: List[Long],
+      shuffleRecordsRead: List[Long],
+      generatedRows: List[Long],
+      outputRows: List[Long])
+
+  private[this] def extractMetricValues(
+      df: DataFrame,
+      metricValues: Map[Long, String],
+      metricName: String): List[Long] = {
+    df.queryExecution.executedPlan.collect {
+      case plan if plan.metrics.contains(metricName) =>
+        metricValues(plan.metrics(metricName).id).toLong
+    }.toList.sorted
+  }
+
+  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+      AggregatedMetricsResult = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+
+    var recordsRead = List[Long]()
+    var shuffleRecordsRead = List[Long]()
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        if (taskEnd.taskMetrics != null) {
+          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+            recordsRead
+          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+            shuffleRecordsRead
+        }
+      }
+    }
+
+    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+    val prevUseWholeStageCodeGen =
+      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+    try {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      sparkContext.addSparkListener(listener)
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+    } finally {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+    }
+
+    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 039625421e4ad07afdf20c6a05902a84a020bdf4..14fbe9f44313654bdea90b5032d0469820c95ad6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
+
+  test("Input/generated/output metrics on JDBC") {
+    val foobarCnt = spark.table("foobar").count()
+    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+    assert(res.recordsRead === foobarCnt :: Nil)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows.isEmpty)
+    assert(res.outputRows === foobarCnt :: Nil)
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2403e3c8fe7ec992a9aa11524a80bb8ace..35c41b531c36541e5e5e641649d7c23ff34f432b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
   createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+  test("Test input/generated/output metrics") {
+    import TestHive._
+
+    val episodesCnt = sql("select * from episodes").count()
+    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+    assert(episodesRes.recordsRead === episodesCnt :: Nil)
+    assert(episodesRes.shuffleRecordsRead.sum === 0)
+    assert(episodesRes.generatedRows.isEmpty)
+    assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+    val serdeinsCnt = sql("select * from serdeins").count()
+    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+    assert(serdeinsRes.generatedRows.isEmpty)
+    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+  }
 }