From 3b6ac323b16f8f6d79ee7bac6e7a57f841897d96 Mon Sep 17 00:00:00 2001
From: Burak Yavuz <brkyvz@gmail.com>
Date: Mon, 9 Jan 2017 15:17:59 -0800
Subject: [PATCH] [SPARK-18952][BACKPORT] Regex strings not properly escaped in
 codegen for aggregations

## What changes were proposed in this pull request?

Backport for #16361 to 2.1 branch.

## How was this patch tested?

Unit tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #16518 from brkyvz/reg-break-2.1.
---
 .../aggregate/RowBasedHashMapGenerator.scala       | 12 +++++++-----
 .../aggregate/VectorizedHashMapGenerator.scala     | 12 +++++++-----
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  9 +++++++++
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index a77e178546..1b6e6d2f65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedKeySchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         groupingKeySchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedValueSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7418df90b8..586328a6ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         (groupingKeySchema ++ bufferSchema).map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedAggBufferSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 645175900f..7853b22fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+    val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
+
+    checkAnswer(
+      df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
+      Row("some", 1) :: Nil
+    )
+  }
+
   test("rollup") {
     checkAnswer(
       courseSales.rollup("course", "year").sum("earnings"),
-- 
GitLab
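
For context on the failure mode the patch above fixes: before this change, both generators spliced `key.name` verbatim into the Java source they emit. A grouping expression such as `regexp_extract('key, "([a-z]+)\\[", 1)` produces a column whose name carries the raw pattern, so the generated Java string literal ends up containing `\[`, which the Java compiler rejects as an illegal escape sequence (the offending character is a backslash, despite the test title saying "forward-slash"). Below is a plain-Scala sketch of the mechanism, not Spark's actual code: `inlined` mirrors the pre-patch interpolation, `addReferenceObj` is a simplified stand-in for `CodegenContext.addReferenceObj` (the exact text Spark emits may differ), and the column name is illustrative rather than the exact name Spark derives.

```scala
// A hedged sketch of the bug and of the fix's mechanism. `inlined` and
// `addReferenceObj` are simplified stand-ins, not Spark internals.
object EscapeSketch {
  // Pre-patch style: the column name lands verbatim inside a Java string
  // literal in the generated source.
  def inlined(keyName: String): String =
    s""".add("$keyName", org.apache.spark.sql.types.DataTypes.StringType)"""

  // Post-patch style (simplified): park the value in a reference table and
  // emit a lookup expression whose text is always escape-safe.
  private val references = scala.collection.mutable.ArrayBuffer.empty[AnyRef]
  def addReferenceObj(obj: AnyRef): String = {
    references += obj
    s"((java.lang.String) references[${references.length - 1}])"
  }

  def main(args: Array[String]): Unit = {
    // A regexp_extract grouping key drags its pattern into the column name,
    // including the sequence \[ (illustrative name).
    val keyName = """regexp_extract(key, ([a-z]+)\[, 1)"""

    // Pre-patch output embeds \[ inside a Java string literal -- an illegal
    // escape, so compiling the generated class fails.
    println(inlined(keyName))

    // Post-patch output interpolates only the lookup expression's text,
    // which contains no backslashes or quotes, so it always compiles.
    println(inlined(addReferenceObj(keyName)))
  }
}
```

The new test in `DataFrameAggregateSuite` exercises exactly this path: grouping `("some[thing]", "random-string")` by `regexp_extract('key, "([a-z]+)\\[", 1)` pushes the regex through the fast hash map codegen, which fails to compile on an unpatched 2.1 build and returns `Row("some", 1)` after the fix.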