From c9db8eaa42387c03cde12c1d145a6f72872def71 Mon Sep 17 00:00:00 2001 From: Tarek Auel <tarek.auel@googlemail.com> Date: Mon, 20 Jul 2015 15:32:46 -0700 Subject: [PATCH] [SPARK-9159][SQL] codegen ascii, base64, unbase64 Jira: https://issues.apache.org/jira/browse/SPARK-9159 Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7542 from tarekauel/SPARK-9159 and squashes the following commits: 772e6bc [Tarek Auel] [SPARK-9159][SQL] codegen ascii, base64, unbase64 --- .../expressions/stringOperations.scala | 37 ++++++++++++++++--- .../expressions/StringExpressionsSuite.scala | 2 +- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index e42be85367..e660d499fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -742,8 +742,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { +case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -756,13 +755,25 @@ case class Ascii(child: Expression) 0 } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + val bytes = ctx.freshName("bytes") + s""" + byte[] $bytes = $child.getBytes(); + if ($bytes.length > 0) { + ${ev.primitive} = (int) $bytes[0]; + } else { + ${ev.primitive} = 0; + } + """}) + } } /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { +case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -772,19 +783,33 @@ case class Base64(child: Expression) org.apache.commons.codec.binary.Base64.encodeBase64( bytes.asInstanceOf[Array[Byte]])) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + s"""${ev.primitive} = UTF8String.fromBytes( + org.apache.commons.codec.binary.Base64.encodeBase64($child)); + """}) + } + } /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { +case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + s""" + ${ev.primitive} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); + """}) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index d5731229df..67d97cd30b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -290,7 +290,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) checkEvaluation(Base64(b), "", create_row(Array[Byte]())) checkEvaluation(Base64(b), null, create_row(null)) - checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef")) + checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef")) checkEvaluation(UnBase64(a), null, create_row(null)) checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef")) -- GitLab