From 3b6ac323b16f8f6d79ee7bac6e7a57f841897d96 Mon Sep 17 00:00:00 2001
From: Burak Yavuz <brkyvz@gmail.com>
Date: Mon, 9 Jan 2017 15:17:59 -0800
Subject: [PATCH] [SPARK-18952][BACKPORT] Regex strings not properly escaped in
 codegen for aggregations

## What changes were proposed in this pull request?

Backport for #16361 to 2.1 branch.

## How was this patch tested?

Unit tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #16518 from brkyvz/reg-break-2.1.
---
 .../aggregate/RowBasedHashMapGenerator.scala         | 12 +++++++-----
 .../aggregate/VectorizedHashMapGenerator.scala       | 12 +++++++-----
 .../apache/spark/sql/DataFrameAggregateSuite.scala   |  9 +++++++++
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index a77e178546..1b6e6d2f65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedKeySchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         groupingKeySchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                   |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedValueSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                   |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7418df90b8..586328a6ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         (groupingKeySchema ++ bufferSchema).map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                   |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedAggBufferSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                   |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 645175900f..7853b22fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+    val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
+
+    checkAnswer(
+      df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
+      Row("some", 1) :: Nil
+    )
+  }
+
   test("rollup") {
     checkAnswer(
       courseSales.rollup("course", "year").sum("earnings"),
-- 
GitLab