diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 1812a1152cb48e514286bbeee5d9f3be1cc42223..c35e5638e9273859f2e93161164a9e42391dbe24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -78,7 +78,7 @@ case class GenerateExec(
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
-  val boundGenerator: Generator = BindReferences.bindReference(generator, child.output)
+  lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = {
     // boundGenerator.terminate() should be triggered after all of the rows in the partition
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index f796a4cb4a398f979c1471b4854d12a435a1d5b4..4345a70601c343eaf9dd07a4c5d088447da1b787 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
     )
   }
 
+  test("count distinct") {
+    val inputData = MemoryStream[(Int, Seq[Int])]
+
+    val aggregated =
+      inputData.toDF()
+        .select($"*", explode($"_2") as 'value)
+        .groupBy($"_1")
+        .agg(size(collect_set($"value")))
+        .as[(Int, Int)]
+
+    testStream(aggregated, Update)(
+      AddData(inputData, (1, Seq(1, 2))),
+      CheckLastBatch((1, 2))
+    )
+  }
+
   test("simple count, complete mode") {
     val inputData = MemoryStream[Int]