diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 9dcac33b4107c43034c7321ce7b23eec80ceadc4..ab690fd5fbbcaab5c51ee2a707a2b5bc6c829ab2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -136,7 +136,7 @@ case class FlatMapGroupsWithStateExec(
           outputIterator,
           {
             store.commit()
-            longMetric("numTotalStateRows") += store.numKeys()
+            setStoreMetrics(store)
           }
         )
     }
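The per-operator bookkeeping above moves into the shared `setStoreMetrics` helper (added in statefulOperators.scala below), still invoked from the completion callback so it runs exactly once, after the store has committed. A minimal, self-contained sketch of that callback pattern, using stand-in types rather than Spark's internal `CompletionIterator`:

```scala
// Stand-in sketch of the commit-then-record pattern used above; FakeStore and
// completionIterator are illustrative, not Spark APIs.
object CompletionCallbackSketch {
  final class FakeStore {
    def commit(): Unit = println("store committed")
    def numKeys: Long = 2L
  }

  /** Wraps `iter` so `onCompletion` runs once, after the iterator is drained. */
  def completionIterator[A](iter: Iterator[A])(onCompletion: => Unit): Iterator[A] =
    new Iterator[A] {
      private var completed = false
      def hasNext: Boolean = {
        val has = iter.hasNext
        if (!has && !completed) { completed = true; onCompletion }
        has
      }
      def next(): A = iter.next()
    }

  def main(args: Array[String]): Unit = {
    val store = new FakeStore
    val output = completionIterator(Iterator(1, 2, 3)) {
      store.commit()                                    // commit first,
      println(s"numTotalStateRows += ${store.numKeys}") // then record metrics
    }
    output.foreach(println) // callback fires only after 3 is consumed
  }
}
```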
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index a4e4ca821374ca9fd6705bf95b244b3df9f5808e..1887b07c49b730c721b07b0e1801ce08d6782823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,18 +186,10 @@ trait ProgressReporter extends Logging {
     if (lastExecution == null) return Nil
     // lastExecution could belong to one of the previous triggers if `!hasNewData`.
     // Walking the plan again should be inexpensive.
-    val stateNodes = lastExecution.executedPlan.collect {
-      case p if p.isInstanceOf[StateStoreWriter] => p
-    }
-    stateNodes.map { node =>
-      val numRowsUpdated = if (hasNewData) {
-        node.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)
-      } else {
-        0L
-      }
-      new StateOperatorProgress(
-        numRowsTotal = node.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L),
-        numRowsUpdated = numRowsUpdated)
+    lastExecution.executedPlan.collect {
+      case p: StateStoreWriter =>
+        val progress = p.getProgress()
+        if (hasNewData) progress else progress.copy(newNumRowsUpdated = 0)
     }
   }
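Zeroing `numRowsUpdated` when `hasNewData` is false matters because `lastExecution` may describe an earlier trigger: its update counter reflects that trigger's work, while `numRowsTotal` and `memoryUsedBytes` are gauges that remain valid. A toy illustration of the copy semantics (a simplified stand-in, not the real `StateOperatorProgress`):

```scala
// Simplified stand-in for StateOperatorProgress; only the copy semantics are shown.
case class Progress(numRowsTotal: Long, numRowsUpdated: Long, memoryUsedBytes: Long)

val last = Progress(numRowsTotal = 100, numRowsUpdated = 7, memoryUsedBytes = 4096)
val hasNewData = false
val reported = if (hasNewData) last else last.copy(numRowsUpdated = 0)
// reported == Progress(100, 0, 4096): gauges stay valid, per-trigger updates are zeroed
```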
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index bae7a15165e4349023ba7574323611425f527bcf..fa4c99c01916fe447a99c50d60e6626d1356bd42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -35,7 +35,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.io.LZ4CompressionCodec
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SizeEstimator, Utils}
 
 
 /**
@@ -172,7 +172,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       }
     }
 
-    override def numKeys(): Long = mapToUpdate.size()
+    override def metrics: StateStoreMetrics = {
+      StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty)
+    }
 
     /**
      * Whether all updates have been committed
@@ -230,6 +232,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     loadedMaps.values.foreach(_.clear())
   }
 
+  override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
+    Nil
+  }
+
   override def toString(): String = {
     s"HDFSStateStoreProvider[" +
       s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 86886466c4f56da23a12306520d78498a298d7cd..9da610e359f90d1b3ac8ec7f46f59ef7ef6fdbd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -94,8 +94,8 @@ trait StateStore {
 
   def iterator(): Iterator[UnsafeRowPair]
 
-  /** Number of keys in the state store */
-  def numKeys(): Long
+  /** Current metrics of the state store */
+  def metrics: StateStoreMetrics
 
   /**
    * Whether all updates have been committed
@@ -103,6 +103,24 @@ trait StateStore {
   def hasCommitted: Boolean
 }
 
+/**
+ * Metrics reported by a state store
+ * @param numKeys         Number of keys in the state store
+ * @param memoryUsedBytes Memory used by the state store
+ * @param customMetrics   Custom implementation-specific metrics.
+ *                        The metrics reported through this must have the same `name` as those
+ *                        reported by `StateStoreProvider.supportedCustomMetrics`.
+ */
+case class StateStoreMetrics(
+    numKeys: Long,
+    memoryUsedBytes: Long,
+    customMetrics: Map[StateStoreCustomMetric, Long])
+
+/**
+ * Name and description of custom implementation-specific metrics that a
+ * state store may wish to expose.
+ */
+case class StateStoreCustomMetric(name: String, desc: String)
 
 /**
  * Trait representing a provider that provide [[StateStore]] instances representing
@@ -158,22 +176,36 @@ trait StateStoreProvider {
 
   /** Optional method for providers to allow for background maintenance (e.g. compactions) */
   def doMaintenance(): Unit = { }
+
+  /**
+   * Optional custom metrics that the implementation may want to report.
+   * @note The StateStore objects created by this provider must report the same custom metrics
+   * (specifically, same names) through `StateStore.metrics`.
+   */
+  def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
 }
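The contract between `supportedCustomMetrics` and `StateStore.metrics` is name-based: the operator pre-registers one `SQLMetric` per declared name on the driver, then looks each reported value up by name in the task (see `setStoreMetrics` in statefulOperators.scala below). A self-contained sketch with stand-in types; `mapLoadTimeMs` is a hypothetical metric, not one this patch defines:

```scala
// Stand-in types mirroring StateStoreCustomMetric / StateStoreMetrics.
case class CustomMetric(name: String, desc: String)
case class Metrics(numKeys: Long, memoryUsedBytes: Long, custom: Map[CustomMetric, Long])

object MetricNameContract {
  // Provider side: declared once, used to register SQLMetrics by name.
  val mapLoadTime = CustomMetric("mapLoadTimeMs", "time to load the state map")
  def supportedCustomMetrics: Seq[CustomMetric] = Seq(mapLoadTime)

  // Store side: reported per commit, keyed by the same metric objects.
  def metricsAfterCommit: Metrics =
    Metrics(numKeys = 42, memoryUsedBytes = 4096, custom = Map(mapLoadTime -> 17L))

  def main(args: Array[String]): Unit = {
    val registered = supportedCustomMetrics.map(_.name).toSet
    // The lookup in setStoreMetrics succeeds only if every reported name was declared.
    assert(metricsAfterCommit.custom.keys.forall(m => registered(m.name)))
    println("custom metric names line up")
  }
}
```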
 
 object StateStoreProvider {
+
+  /**
+   * Return an uninitialized instance of the provider identified by the given class name.
+   */
+  def create(providerClassName: String): StateStoreProvider = {
+    val providerClass = Utils.classForName(providerClassName)
+    providerClass.newInstance().asInstanceOf[StateStoreProvider]
+  }
+
   /**
-   * Return a provider instance of the given provider class.
-   * The instance will be already initialized.
+   * Return an instance of the required provider, initialized with the given configurations.
    */
-  def instantiate(
+  def createAndInit(
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
       indexOrdinal: Option[Int], // for sorting the data
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
-    val providerClass = Utils.classForName(storeConf.providerClass)
-    val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
+    val provider = create(storeConf.providerClass)
     provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
     provider
   }
@@ -298,7 +330,7 @@ object StateStore extends Logging {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
         storeProviderId,
-        StateStoreProvider.instantiate(
+        StateStoreProvider.createAndInit(
           storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
       )
       reportActiveStoreInstance(storeProviderId)
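The split gives two entry points: `create` returns an uninitialized provider, which is all the driver needs to read `supportedCustomMetrics` at planning time (no checkpoint location or schema exists yet), while `createAndInit` remains the executor path used by `StateStore.get` above. A sketch of the driver-side call, assuming Spark on the classpath:

```scala
import org.apache.spark.sql.execution.streaming.state.StateStoreProvider

// Uninitialized instance: only the declared custom metric names are needed here.
val provider = StateStoreProvider.create(
  "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
println(provider.supportedCustomMetrics.map(_.name)) // List() for this provider
```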
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index c5722466a33af3f8b019181be9d0f8ad345008ae..77b1160a063fbc23f3cdd4e57b0a187b6208a341 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
@@ -29,9 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{CompletionIterator, NextIterator}
 
@@ -73,8 +75,21 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
     "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
     "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"),
     "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"),
-    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes")
-  )
+    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
+    "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state")
+  ) ++ stateStoreCustomMetrics
+
+  /**
+   * Get the progress made by this stateful operator after execution. This should be called in
+   * the driver after this SparkPlan has been executed and metrics have been updated.
+   */
+  def getProgress(): StateOperatorProgress = {
+    new StateOperatorProgress(
+      numRowsTotal = longMetric("numTotalStateRows").value,
+      numRowsUpdated = longMetric("numUpdatedStateRows").value,
+      memoryUsedBytes = longMetric("stateMemory").value,
+      numPartitions = sqlContext.conf.numShufflePartitions)
+  }
 
   /** Records the duration of running `body` for the next query progress update. */
   protected def timeTakenMs(body: => Unit): Long = {
@@ -83,6 +98,26 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
     val endTime = System.nanoTime()
     math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
   }
+
+  /**
+   * Set the SQL metrics related to the state store.
+   * This should be called in the task after the store has been updated.
+   */
+  protected def setStoreMetrics(store: StateStore): Unit = {
+    val storeMetrics = store.metrics
+    longMetric("numTotalStateRows") += storeMetrics.numKeys
+    longMetric("stateMemory") += storeMetrics.memoryUsedBytes
+    storeMetrics.customMetrics.foreach { case (metric, value) =>
+      longMetric(metric.name) += value
+    }
+  }
+
+  private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
+    val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
+    provider.supportedCustomMetrics.map { m =>
+      m.name -> SQLMetrics.createTimingMetric(sparkContext, m.desc)
+    }.toMap
+  }
 }
 
 /** An operator that supports watermark. */
@@ -197,7 +232,6 @@ case class StateStoreSaveExec(
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
         val numOutputRows = longMetric("numOutputRows")
-        val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
         val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
@@ -218,7 +252,7 @@ case class StateStoreSaveExec(
             commitTimeMs += timeTakenMs {
               store.commit()
             }
-            numTotalStateRows += store.numKeys()
+            setStoreMetrics(store)
             store.iterator().map { rowPair =>
               numOutputRows += 1
               rowPair.value
@@ -261,7 +295,7 @@ case class StateStoreSaveExec(
               override protected def close(): Unit = {
                 allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
                 commitTimeMs += timeTakenMs { store.commit() }
-                numTotalStateRows += store.numKeys()
+                setStoreMetrics(store)
               }
             }
 
@@ -285,7 +319,7 @@ case class StateStoreSaveExec(
                   // Remove old aggregates if watermark specified
                   allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
                   commitTimeMs += timeTakenMs { store.commit() }
-                  numTotalStateRows += store.numKeys()
+                  setStoreMetrics(store)
                   false
                 } else {
                   true
@@ -368,7 +402,7 @@ case class StreamingDeduplicateExec(
         allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
         allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
         commitTimeMs += timeTakenMs { store.commit() }
-        numTotalStateRows += store.numKeys()
+        setStoreMetrics(store)
       })
     }
   }
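End to end, the values recorded by `setStoreMetrics` and assembled by `getProgress` surface through the existing progress API. A hedged usage sketch; the rate source and query shape are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("state-metrics").getOrCreate()
import spark.implicits._

// Any stateful query will do; a streaming aggregation keeps state per group.
val counts = spark.readStream.format("rate").load()
  .groupBy(($"value" % 10).as("bucket"))
  .count()

val query = counts.writeStream
  .outputMode("complete")
  .format("memory")
  .queryName("bucket_counts")
  .start()
query.processAllAvailable()

for (op <- query.lastProgress.stateOperators) {
  println(s"numRowsTotal=${op.numRowsTotal} numRowsUpdated=${op.numRowsUpdated} " +
    s"memoryUsedBytes=${op.memoryUsedBytes} numPartitions=${op.numPartitions}")
}
query.stop()
```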
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index fb590e7df996b501e28696ef33ce012febc26c2b..81a2387b803968d38ead93dd7c025a3b798fd6d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -37,7 +37,10 @@ import org.apache.spark.annotation.InterfaceStability
 @InterfaceStability.Evolving
 class StateOperatorProgress private[sql](
     val numRowsTotal: Long,
-    val numRowsUpdated: Long) extends Serializable {
+    val numRowsUpdated: Long,
+    val memoryUsedBytes: Long,
+    val numPartitions: Long
+  ) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -45,9 +48,14 @@ class StateOperatorProgress private[sql](
   /** The pretty (i.e. indented) JSON representation of this progress. */
   def prettyJson: String = pretty(render(jsonValue))
 
+  private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress =
+    new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, numPartitions)
+
   private[sql] def jsonValue: JValue = {
     ("numRowsTotal" -> JInt(numRowsTotal)) ~
-    ("numRowsUpdated" -> JInt(numRowsUpdated))
+    ("numRowsUpdated" -> JInt(numRowsUpdated)) ~
+    ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
+    ("numPartitions" -> JInt(numPartitions))
   }
 }
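The constructor stays `private[sql]`, so user code only reads these fields; the enlarged JSON shape renders as below (from within Spark's own scope, values arbitrary):

```scala
val progress = new StateOperatorProgress(
  numRowsTotal = 100, numRowsUpdated = 7, memoryUsedBytes = 4096, numPartitions = 4)
println(progress.prettyJson)
// {
//   "numRowsTotal" : 100,
//   "numRowsUpdated" : 7,
//   "memoryUsedBytes" : 4096,
//   "numPartitions" : 4
// }
```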
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 7cb86dc14384425eee662bfa8372a18ce3047474..c843b65020d8c7dfda7638e1932ca9f658f358a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.io.{File, IOException}
 import java.net.URI
 import java.util.UUID
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -184,6 +185,15 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
   }
 
+  test("reports memory usage") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    val noDataMemoryUsed = store.metrics.memoryUsedBytes
+    put(store, "a", 1)
+    store.commit()
+    assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
+  }
+
   test("StateStore.get") {
     quietly {
       val dir = newDir()
@@ -554,12 +564,12 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(!store.hasCommitted)
     assert(get(store, "a") === None)
     assert(store.iterator().isEmpty)
-    assert(store.numKeys() === 0)
+    assert(store.metrics.numKeys === 0)
 
     // Verify state after updating
     put(store, "a", 1)
     assert(get(store, "a") === Some(1))
-    assert(store.numKeys() === 1)
+    assert(store.metrics.numKeys === 1)
 
     assert(store.iterator().nonEmpty)
     assert(getLatestData(provider).isEmpty)
@@ -567,9 +577,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     // Make updates, commit and then verify state
     put(store, "b", 2)
     put(store, "aa", 3)
-    assert(store.numKeys() === 3)
+    assert(store.metrics.numKeys === 3)
     remove(store, _.startsWith("a"))
-    assert(store.numKeys() === 1)
+    assert(store.metrics.numKeys === 1)
     assert(store.commit() === 1)
 
     assert(store.hasCommitted)
@@ -587,9 +597,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     // New updates to the reloaded store with new version, and does not change old version
     val reloadedProvider = newStoreProvider(store.id)
     val reloadedStore = reloadedProvider.getStore(1)
-    assert(reloadedStore.numKeys() === 1)
+    assert(reloadedStore.metrics.numKeys === 1)
     put(reloadedStore, "c", 4)
-    assert(reloadedStore.numKeys() === 2)
+    assert(reloadedStore.metrics.numKeys === 2)
     assert(reloadedStore.commit() === 2)
     assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
     assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 0d9ca81349be5f26e1636b910ad79df74e63d914..9f2f0d195de9f2e01da20cd402beab621d27681e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -1077,7 +1077,7 @@ object FlatMapGroupsWithStateSuite {
     override def abort(): Unit = { }
     override def id: StateStoreId = null
     override def version: Long = 0
-    override def numKeys(): Long = map.size
+    override def metrics: StateStoreMetrics = StateStoreMetrics(map.size, 0, Map.empty)
     override def hasCommitted: Boolean = true
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 901cf34f289cc9401ef6cd453a9398652d764661..d3cafac4f17552aa619facdff102b2df2c69ffb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -33,16 +33,10 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._
 
 class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
-  implicit class EqualsIgnoreCRLF(source: String) {
-    def equalsIgnoreCRLF(target: String): Boolean = {
-      source.replaceAll("\r\n|\r|\n", System.lineSeparator) ===
-        target.replaceAll("\r\n|\r|\n", System.lineSeparator)
-    }
-  }
-
   test("StreamingQueryProgress - prettyJson") {
     val json1 = testProgress1.prettyJson
-    assert(json1.equalsIgnoreCRLF(
+    assertJson(
+      json1,
       s"""
         |{
         |  "id" : "${testProgress1.id.toString}",
@@ -62,7 +56,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
         |  },
         |  "stateOperators" : [ {
         |    "numRowsTotal" : 0,
-        |    "numRowsUpdated" : 1
+        |    "numRowsUpdated" : 1,
+        |    "memoryUsedBytes" : 2,
+        |    "numPartitions" : 4
         |  } ],
         |  "sources" : [ {
         |    "description" : "source",
@@ -75,13 +71,13 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
         |    "description" : "sink"
         |  }
         |}
-      """.stripMargin.trim))
+      """.stripMargin.trim)
     assert(compact(parse(json1)) === testProgress1.json)
 
     val json2 = testProgress2.prettyJson
-    assert(
-      json2.equalsIgnoreCRLF(
-        s"""
+    assertJson(
+      json2,
+      s"""
          |{
          |  "id" : "${testProgress2.id.toString}",
          |  "runId" : "${testProgress2.runId.toString}",
@@ -93,7 +89,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
          |  },
          |  "stateOperators" : [ {
          |    "numRowsTotal" : 0,
-         |    "numRowsUpdated" : 1
+         |    "numRowsUpdated" : 1,
+         |    "memoryUsedBytes" : 2,
+         |    "numPartitions" : 4
          |  } ],
          |  "sources" : [ {
          |    "description" : "source",
@@ -105,7 +103,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
          |    "description" : "sink"
          |  }
          |}
-      """.stripMargin.trim))
+      """.stripMargin.trim)
     assert(compact(parse(json2)) === testProgress2.json)
   }
 
@@ -121,14 +119,15 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
 
   test("StreamingQueryStatus - prettyJson") {
     val json = testStatus.prettyJson
-    assert(json.equalsIgnoreCRLF(
+    assertJson(
+      json,
       """
         |{
         |  "message" : "active",
         |  "isDataAvailable" : true,
         |  "isTriggerActive" : false
         |}
-      """.stripMargin.trim))
+      """.stripMargin.trim)
   }
 
   test("StreamingQueryStatus - json") {
@@ -209,6 +208,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
       }
     }
   }
+
+  def assertJson(source: String, expected: String): Unit = {
+    assert(
+      source.replaceAll("\r\n|\r|\n", System.lineSeparator) ===
+        expected.replaceAll("\r\n|\r|\n", System.lineSeparator))
+  }
 }
 
 object StreamingQueryStatusAndProgressSuite {
@@ -224,7 +229,8 @@ object StreamingQueryStatusAndProgressSuite {
       "min" -> "2016-12-05T20:54:20.827Z",
       "avg" -> "2016-12-05T20:54:20.827Z",
       "watermark" -> "2016-12-05T20:54:20.827Z").asJava),
-    stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
+    stateOperators = Array(new StateOperatorProgress(
+      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)),
     sources = Array(
       new SourceProgress(
         description = "source",
@@ -247,7 +253,8 @@ object StreamingQueryStatusAndProgressSuite {
     durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava),
     // empty maps should be handled correctly
     eventTime = new java.util.HashMap(Map.empty[String, String].asJava),
-    stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
+    stateOperators = Array(new StateOperatorProgress(
+      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)),
     sources = Array(
       new SourceProgress(
         description = "source",