diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index c3a2675ee5f45c3901d45979ed4704318caec4b4..09864e3f8392dbb56bd061bf8c7468e0b46b4438 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet
- * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first.
- * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size
- * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work.
+ * When a class extends this trait, [[SizeEstimator]] will query the `estimatedSize` and
+ * use the returned size as the size of the object, skipping the usual reflection-based
+ * estimation work.
+ * The difference between a [[KnownSizeEstimation]] and
+ * [[org.apache.spark.util.collection.SizeTracker]] is that a
+ * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to
+ * estimate the size, whereas a [[KnownSizeEstimation]] can provide a more accurate
+ * estimation without invoking [[SizeEstimator]].
  */
-private[spark] trait SizeEstimation {
-  def estimatedSize: Option[Long]
+private[spark] trait KnownSizeEstimation {
+  def estimatedSize: Long
 }
 
 /**
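The contract change above is the crux of the patch: the old `SizeEstimation` let an implementer return `None` and defer to reflection, while `KnownSizeEstimation` must always produce a number. Below is a minimal sketch of an implementer, assuming a hypothetical `ByteBackedBuffer` compiled under the `org.apache.spark` package (the trait is `private[spark]`). Unlike a `SizeTracker`, which periodically samples `SizeEstimator.estimate`, this class keeps an exact count itself:

```scala
// Hypothetical example, not part of this patch: a class that tracks its own
// byte count and reports it through KnownSizeEstimation. It lives under
// org.apache.spark.example so it can see the private[spark] trait.
package org.apache.spark.example

import org.apache.spark.util.KnownSizeEstimation

class ByteBackedBuffer extends KnownSizeEstimation {
  private var bytesUsed: Long = 0L

  def append(chunk: Array[Byte]): Unit = {
    bytesUsed += chunk.length
  }

  // SizeEstimator uses this value verbatim; there is no longer a way to
  // decline and fall back to the reflective walk.
  override def estimatedSize: Long = bytesUsed
}
```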
@@ -209,18 +214,15 @@ object SizeEstimator extends Logging {
       // the size estimator since it references the whole REPL. Do nothing in this case. In
       // general all ClassLoaders and Classes will be shared between objects anyway.
     } else {
-      val estimatedSize = obj match {
-        case s: SizeEstimation => s.estimatedSize
-        case _ => None
-      }
-      if (estimatedSize.isDefined) {
-        state.size += estimatedSize.get
-      } else {
-        val classInfo = getClassInfo(cls)
-        state.size += alignSize(classInfo.shellSize)
-        for (field <- classInfo.pointerFields) {
-          state.enqueue(field.get(obj))
-        }
+      obj match {
+        case s: KnownSizeEstimation =>
+          state.size += s.estimatedSize
+        case _ =>
+          val classInfo = getClassInfo(cls)
+          state.size += alignSize(classInfo.shellSize)
+          for (field <- classInfo.pointerFields) {
+            state.enqueue(field.get(obj))
+          }
       }
     }
   }
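With the `Option` round-trip gone, `visitSingleObject` now writes into `state.size` directly from a single pattern match. A self-contained mirror of that control flow (the names here are illustrative stand-ins, and `reflectiveEstimate` abbreviates the shell-size-plus-pointer-fields walk the real estimator performs):

```scala
// Standalone mirror of the new dispatch; KnownSize and estimateOne are
// illustrative stand-ins, not Spark APIs.
object DispatchSketch {
  trait KnownSize { def estimatedSize: Long }

  def estimateOne(obj: AnyRef, reflectiveEstimate: AnyRef => Long): Long =
    obj match {
      // An object that accounts for itself is trusted verbatim.
      case s: KnownSize => s.estimatedSize
      // Everything else gets the reflective fallback.
      case _ => reflectiveEstimate(obj)
    }

  def main(args: Array[String]): Unit = {
    val known = new KnownSize { override def estimatedSize: Long = 2015L }
    println(estimateOne(known, _ => 16L))   // 2015: the self-reported size wins
    println(estimateOne("plain", _ => 16L)) // 16: the fallback is used
  }
}
```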
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 9b6261af123e6f116feb3b3f70f175114f4df151..101610e38014ea4b5397f3e41da13ec30c09f60f 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) {
   @transient val hash32: Int = 0
 }
 
-class DummyClass8 extends SizeEstimation {
+class DummyClass8 extends KnownSizeEstimation {
   val x: Int = 0
 
-  override def estimatedSize: Option[Long] = Some(2015)
-}
-
-class DummyClass9 extends SizeEstimation {
-  val x: Int = 0
-
-  override def estimatedSize: Option[Long] = None
+  override def estimatedSize: Long = 2015
 }
 
 class SizeEstimatorSuite
@@ -231,9 +225,5 @@ class SizeEstimatorSuite
     // DummyClass8 provides its size estimation.
     assertResult(2015)(SizeEstimator.estimate(new DummyClass8))
     assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8)))
-
-    // DummyClass9 does not provide its size estimation.
-    assertResult(16)(SizeEstimator.estimate(new DummyClass9))
-    assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9)))
   }
 }
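The surviving array assertion can be checked by hand: each element contributes its self-reported 2015 bytes, while the `Array[DummyClass8]` shell is still estimated reflectively. Under the suite's 64-bit, compressed-oops configuration that shell is a 16-byte array header plus ten 4-byte references. A worked check (the layout constants are assumptions about that configuration, not values read from the suite):

```scala
// Worked check of assertResult(20206), assuming a 64-bit JVM with
// compressed oops (16-byte array header, 4-byte object references).
object ArraySizeCheck extends App {
  val arrayHeader = 16L
  val refSize = 4L
  val n = 10L
  val shell = arrayHeader + n * refSize    // 56 bytes for the array object itself
  val perElement = 2015L                   // DummyClass8.estimatedSize
  assert(shell + n * perElement == 20206L) // 56 + 20150 = 20206
}
```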
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 49ae09bf53782830e39312703cd0c41434fffc88..aebfea5832402c8c1fe9da7c647b297a5dacd1eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.MemoryLocation
-import org.apache.spark.util.{SizeEstimation, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
 
@@ -190,7 +190,7 @@ private[execution] object HashedRelation {
 private[joins] final class UnsafeHashedRelation(
     private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
   extends HashedRelation
-  with SizeEstimation
+  with KnownSizeEstimation
   with Externalizable {
 
   private[joins] def this() = this(null)  // Needed for serialization
@@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation(
     }
   }
 
-  override def estimatedSize: Option[Long] = {
-    Option(binaryMap).map(_.getTotalMemoryConsumption)
+  override def estimatedSize: Long = {
+    if (binaryMap != null) {
+      binaryMap.getTotalMemoryConsumption
+    } else {
+      SizeEstimator.estimate(hashTable)
+    }
   }
 
   override def get(key: InternalRow): Seq[InternalRow] = {
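Because `estimatedSize` can no longer return `None`, `UnsafeHashedRelation` must produce a number in both of its states: exact off-heap accounting once `binaryMap` exists, and a reflective walk over the on-heap `hashTable` otherwise. In effect the `Option` did not disappear; the decision moved into the implementer, which is where the fallback knowledge lives. A hedged sketch of that exact-or-fallback shape (`OffHeapBacked` is a made-up class, placed under `org.apache.spark.example` so it can see the `private[spark]` APIs):

```scala
// Illustrative sketch, not Spark code: report exact consumption when an
// off-heap count is available, otherwise fall back to reflective estimation.
package org.apache.spark.example

import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

class OffHeapBacked(
    offHeapBytes: Option[Long], // exact count once the off-heap map exists
    onHeapFallback: AnyRef)     // on-heap form used before conversion
  extends KnownSizeEstimation {

  override def estimatedSize: Long = offHeapBytes match {
    // Mirrors binaryMap.getTotalMemoryConsumption above.
    case Some(exact) => exact
    // Mirrors the SizeEstimator.estimate(hashTable) fallback above.
    case None => SizeEstimator.estimate(onHeapFallback)
  }
}
```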