Skip to content
Snippets Groups Projects
Commit faabe69c authored by Burak Yavuz's avatar Burak Yavuz Committed by Josh Rosen
Browse files

[SPARK-18952] Regex strings not properly escaped in codegen for aggregations

## What changes were proposed in this pull request?

If I use the function regexp_extract, and then in my regex string, use `\`, i.e. escape character, this fails codegen, because the `\` character is not properly escaped when codegen'd.

Example stack trace:
```
/* 059 */     private int maxSteps = 2;
/* 060 */     private int numRows = 0;
/* 061 */     private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add("date_format(window#325.start, yyyy-MM-dd HH:mm)", org.apache.spark.sql.types.DataTypes.StringType)
/* 062 */     .add("regexp_extract(source#310.description, ([a-zA-Z]+)\[.*, 1)", org.apache.spark.sql.types.DataTypes.StringType);
/* 063 */     private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add("sum", org.apache.spark.sql.types.DataTypes.LongType);
/* 064 */     private Object emptyVBase;

...

org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 62, Column 58: Invalid escape sequence
	at org.codehaus.janino.Scanner.scanLiteralCharacter(Scanner.java:918)
	at org.codehaus.janino.Scanner.produce(Scanner.java:604)
	at org.codehaus.janino.Parser.peekRead(Parser.java:3239)
	at org.codehaus.janino.Parser.parseArguments(Parser.java:3055)
	at org.codehaus.janino.Parser.parseSelector(Parser.java:2914)
	at org.codehaus.janino.Parser.parseUnaryExpression(Parser.java:2617)
	at org.codehaus.janino.Parser.parseMultiplicativeExpression(Parser.java:2573)
	at org.codehaus.janino.Parser.parseAdditiveExpression(Parser.java:2552)
```

In the codegen'd expression, the literal should use `\\` instead of `\`

A similar problem was solved here: https://github.com/apache/spark/pull/15156.

## How was this patch tested?

Regression test in `DataFrameAggregationSuite`

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #16361 from brkyvz/reg-break.
parent 15c2bd01
No related branches found
No related tags found
No related merge requests found
......@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
protected def initializeAggregateHashMap(): String = {
override protected def initializeAggregateHashMap(): String = {
val generatedKeySchema: String =
s"new org.apache.spark.sql.types.StructType()" +
groupingKeySchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedValueSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
......
......@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
protected def initializeAggregateHashMap(): String = {
override protected def initializeAggregateHashMap(): String = {
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedAggBufferSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
......
......@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
// Regression test for SPARK-18952: a grouping expression whose generated key name
// contains a backslash escape (the regex `([a-z]+)\\[` below) previously produced
// invalid generated Java — the backslash was embedded unescaped into a string
// literal in the codegen'd StructType.add(...) call, causing a Janino
// "Invalid escape sequence" compile error. The fix routes the key name through
// ctx.addReferenceMinorObj so it is never inlined as a raw literal.
// NOTE(review): the test name says "forward-slash", but the character involved
// is actually the backslash (`\`) escape character.
test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
// One-row frame whose key column contains `[`, so the extracted group is "some".
val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
checkAnswer(
// Grouping by regexp_extract forces the regex (with its `\\[`) into the
// generated hash-map key schema — the exact path that used to fail codegen.
df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
Row("some", 1) :: Nil
)
}
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment