From 9c5612f4e197dec82a5eac9542896d6216a866b7 Mon Sep 17 00:00:00 2001 From: Cheng Hao <hao.cheng@intel.com> Date: Mon, 27 Jul 2015 23:02:23 -0700 Subject: [PATCH] [MINOR] [SQL] Support mutable expression unit test with codegen projection This is actually contains 3 minor issues: 1) Enable the unit test(codegen) for mutable expressions (FormatNumber, Regexp_Replace/Regexp_Extract) 2) Use the `PlatformDependent.copyMemory` instead of the `System.arrayCopy` Author: Cheng Hao <hao.cheng@intel.com> Closes #7566 from chenghao-intel/codegen_ut and squashes the following commits: 24f43ea [Cheng Hao] enable codegen for mutable expression & UTF8String performance --- .../expressions/stringOperations.scala | 1 - .../spark/sql/StringFunctionsSuite.scala | 34 ++++++++++++++----- .../apache/spark/unsafe/types/UTF8String.java | 32 ++++++++--------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 38b0fb37de..edfffbc01c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -777,7 +777,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) override def dataType: DataType = IntegerType - protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 0f9c986f64..8e0ea76d15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -57,19 +57,27 @@ class StringFunctionsSuite extends QueryTest { } test("string regex_replace / regex_extract") { - val df = Seq(("100-200", "")).toDF("a", "b") + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") checkAnswer( df.select( regexp_replace($"a", "(\\d+)", "num"), regexp_extract($"a", "(\\d+)-(\\d+)", 1)), - Row("num-num", "100")) - - checkAnswer( - df.selectExpr( - "regexp_replace(a, '(\\d+)', 'num')", - "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), - Row("num-num", "200")) + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection followed by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } test("string ascii function") { @@ -290,5 +298,15 @@ class StringFunctionsSuite extends QueryTest { df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable Row("5.0000")) } + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection follows by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 85381cf0ef..3e1cc67dbf 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -300,13 +300,13 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { } public UTF8String reverse() { - byte[] bytes = getBytes(); - byte[] result = new byte[bytes.length]; + byte[] result = new byte[this.numBytes]; int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - System.arraycopy(bytes, i, result, result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -316,11 +316,11 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { public UTF8String repeat(int times) { if (times <=0) { - return fromBytes(new byte[0]); + return EMPTY_UTF8; } byte[] newBytes = new byte[numBytes * times]; - System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -385,16 +385,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -421,15 +420,14 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { int offset = 0; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - System.arraycopy(getBytes(), 0, data, offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -454,9 +452,9 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; } @@ -494,7 +492,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, len); @@ -503,7 +501,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { j++; // Add separator if this is not the last input. if (j < numInputs) { - PlatformDependent.copyMemory( + copyMemory( separator.base, separator.offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, separator.numBytes); -- GitLab