Skip to content
Snippets Groups Projects
Commit 1c58fa90 authored by Eric Liang's avatar Eric Liang Committed by Reynold Xin
Browse files

[SPARK-16514][SQL] Fix various regex codegen bugs

## What changes were proposed in this pull request?

RegexExtract and RegexReplace currently crash on non-nullable input due to the use of a hard-coded local variable name (e.g. compilation fails with `java.lang.Exception: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 85, Column 26: Redefinition of local variable "m"`).

This patch changes those variables to use fresh names, and applies the same fix in a few other places.

## How was this patch tested?

Unit tests. rxin

Author: Eric Liang <ekl@databricks.com>

Closes #14168 from ericl/sc-3906.
parent 56bd399a
No related branches found
No related tags found
No related merge requests found
...@@ -108,10 +108,11 @@ case class Like(left: Expression, right: Expression) ...@@ -108,10 +108,11 @@ case class Like(left: Expression, right: Expression)
""") """)
} }
} else { } else {
val rightStr = ctx.freshName("rightStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => { nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s""" s"""
String rightStr = ${eval2}.toString(); String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); ${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
""" """
}) })
...@@ -157,10 +158,11 @@ case class RLike(left: Expression, right: Expression) ...@@ -157,10 +158,11 @@ case class RLike(left: Expression, right: Expression)
""") """)
} }
} else { } else {
val rightStr = ctx.freshName("rightStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => { nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s""" s"""
String rightStr = ${eval2}.toString(); String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile(rightStr); ${patternClass} $pattern = ${patternClass}.compile($rightStr);
${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
""" """
}) })
...@@ -259,6 +261,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ...@@ -259,6 +261,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
val classNamePattern = classOf[Pattern].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
val matcher = ctx.freshName("matcher")
ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
...@@ -267,6 +271,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ...@@ -267,6 +271,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
ctx.addMutableState(classNameStringBuffer, ctx.addMutableState(classNameStringBuffer,
termResult, s"${termResult} = new $classNameStringBuffer();") termResult, s"${termResult} = new $classNameStringBuffer();")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
""
}
nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
s""" s"""
if (!$regexp.equals(${termLastRegex})) { if (!$regexp.equals(${termLastRegex})) {
...@@ -280,14 +290,14 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ...@@ -280,14 +290,14 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
} }
${termResult}.delete(0, ${termResult}.length()); ${termResult}.delete(0, ${termResult}.length());
java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
while (m.find()) { while (${matcher}.find()) {
m.appendReplacement(${termResult}, ${termLastReplacement}); ${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
} }
m.appendTail(${termResult}); ${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString()); ${ev.value} = UTF8String.fromString(${termResult}.toString());
${ev.isNull} = false; $setEvNotNull
""" """
}) })
} }
...@@ -334,10 +344,18 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ...@@ -334,10 +344,18 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
val termLastRegex = ctx.freshName("lastRegex") val termLastRegex = ctx.freshName("lastRegex")
val termPattern = ctx.freshName("pattern") val termPattern = ctx.freshName("pattern")
val classNamePattern = classOf[Pattern].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")
ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
""
}
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s""" s"""
if (!$regexp.equals(${termLastRegex})) { if (!$regexp.equals(${termLastRegex})) {
...@@ -345,15 +363,15 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ...@@ -345,15 +363,15 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
${termLastRegex} = $regexp.clone(); ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
} }
java.util.regex.Matcher m = java.util.regex.Matcher ${matcher} =
${termPattern}.matcher($subject.toString()); ${termPattern}.matcher($subject.toString());
if (m.find()) { if (${matcher}.find()) {
java.util.regex.MatchResult mr = m.toMatchResult(); java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
${ev.value} = UTF8String.fromString(mr.group($idx)); ${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
${ev.isNull} = false; $setEvNotNull
} else { } else {
${ev.value} = UTF8String.EMPTY_UTF8; ${ev.value} = UTF8String.EMPTY_UTF8;
${ev.isNull} = false; $setEvNotNull
}""" }"""
}) })
} }
......
...@@ -631,6 +631,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -631,6 +631,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, null, row4) checkEvaluation(expr, null, row4)
checkEvaluation(expr, null, row5) checkEvaluation(expr, null, row5)
checkEvaluation(expr, null, row6) checkEvaluation(expr, null, row6)
val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num"))
checkEvaluation(nonNullExpr, "num-num", row1)
} }
test("RegexExtract") { test("RegexExtract") {
...@@ -657,6 +660,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -657,6 +660,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val expr1 = new RegExpExtract(s, p) val expr1 = new RegExpExtract(s, p)
checkEvaluation(expr1, "100", row1) checkEvaluation(expr1, "100", row1)
val nonNullExpr = RegExpExtract(Literal("100-200"), Literal("(\\d+)-(\\d+)"), Literal(1))
checkEvaluation(nonNullExpr, "100", row1)
} }
test("SPLIT") { test("SPLIT") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment