diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 3860c4bba9a99f29dfffc6d7abac43fa48e5b7d5..5c18558f9bde787d00954fe985ae9e24a9e94a0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -108,6 +108,8 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
         new Iterator[(Int, UnsafeRow)] {
           private[this] var rowSize: Int = dIn.readInt()
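+          // Close the stream right away if it is empty (the first value read is the EOF marker)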
+          if (rowSize == EOF) dIn.close()
 
           override def hasNext: Boolean = rowSize != EOF
 
@@ -119,6 +121,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
             row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
             rowSize = dIn.readInt() // read the next row's size
             if (rowSize == EOF) { // We are returning the last row in this stream
+              dIn.close()
               val _rowTuple = rowTuple
               // Null these out so that the byte array can be garbage collected once the entire
               // iterator has been consumed
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 40b47ae18d648f1897d1ea26ea1a0e8d190d0b98..bd02c73a26acec8367731e05ba954ad2e8cc9caf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 
+
+/**
+ * Used to check that UnsafeRowSerializer closes the underlying InputStream after reading EOF.
+ */
+class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) {
+  var closed: Boolean = false
+  override def close(): Unit = {
+    closed = true
+    super.close()
+  }
+}
+
 class UnsafeRowSerializerSuite extends SparkFunSuite {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
@@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       serializerStream.writeValue(unsafeRow)
     }
     serializerStream.close()
-    val deserializerIter = serializer.deserializeStream(
-      new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
+    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     for (expectedRow <- unsafeRows) {
       val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2
       assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
@@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       assert(expectedRow.getInt(1) === actualRow.getInt(1))
     }
     assert(!deserializerIter.hasNext)
+    assert(input.closed)
+  }
+
+  test("close empty input stream") {
+    val baos = new ByteArrayOutputStream()
+    val dout = new DataOutputStream(baos)
+    dout.writeInt(-1)  // EOF
+    dout.flush()
+    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
+    val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
+    assert(!deserializerIter.hasNext)
+    assert(input.closed)
   }
 }