Skip to content
Snippets Groups Projects
Commit c9db8eaa authored by Tarek Auel's avatar Tarek Auel Committed by Reynold Xin
Browse files

[SPARK-9159][SQL] codegen ascii, base64, unbase64

Jira: https://issues.apache.org/jira/browse/SPARK-9159

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7542 from tarekauel/SPARK-9159 and squashes the following commits:

772e6bc [Tarek Auel] [SPARK-9159][SQL] codegen ascii, base64, unbase64
parent 4863c11e
No related branches found
No related tags found
No related merge requests found
...@@ -742,8 +742,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres ...@@ -742,8 +742,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
/** /**
* Returns the numeric value of the first character of str. * Returns the numeric value of the first character of str.
*/ */
case class Ascii(child: Expression) case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def dataType: DataType = IntegerType override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType) override def inputTypes: Seq[DataType] = Seq(StringType)
...@@ -756,13 +755,25 @@ case class Ascii(child: Expression) ...@@ -756,13 +755,25 @@ case class Ascii(child: Expression)
0 0
} }
} }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (child) => {
val bytes = ctx.freshName("bytes")
s"""
byte[] $bytes = $child.getBytes();
if ($bytes.length > 0) {
${ev.primitive} = (int) $bytes[0];
} else {
${ev.primitive} = 0;
}
"""})
}
} }
/** /**
* Converts the argument from binary to a base 64 string. * Converts the argument from binary to a base 64 string.
*/ */
case class Base64(child: Expression) case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def dataType: DataType = StringType override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType) override def inputTypes: Seq[DataType] = Seq(BinaryType)
...@@ -772,19 +783,33 @@ case class Base64(child: Expression) ...@@ -772,19 +783,33 @@ case class Base64(child: Expression)
org.apache.commons.codec.binary.Base64.encodeBase64( org.apache.commons.codec.binary.Base64.encodeBase64(
bytes.asInstanceOf[Array[Byte]])) bytes.asInstanceOf[Array[Byte]]))
} }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (child) => {
s"""${ev.primitive} = UTF8String.fromBytes(
org.apache.commons.codec.binary.Base64.encodeBase64($child));
"""})
}
} }
/** /**
* Converts the argument from a base 64 string to BINARY. * Converts the argument from a base 64 string to BINARY.
*/ */
case class UnBase64(child: Expression) case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def dataType: DataType = BinaryType override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType) override def inputTypes: Seq[DataType] = Seq(StringType)
protected override def nullSafeEval(string: Any): Any = protected override def nullSafeEval(string: Any): Any =
org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (child) => {
s"""
${ev.primitive} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString());
"""})
}
} }
/** /**
......
...@@ -290,7 +290,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -290,7 +290,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
checkEvaluation(Base64(b), "", create_row(Array[Byte]())) checkEvaluation(Base64(b), "", create_row(Array[Byte]()))
checkEvaluation(Base64(b), null, create_row(null)) checkEvaluation(Base64(b), null, create_row(null))
checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef")) checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef"))
checkEvaluation(UnBase64(a), null, create_row(null)) checkEvaluation(UnBase64(a), null, create_row(null))
checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef")) checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment