Skip to content
Snippets Groups Projects
Commit 1ddd0f2f authored by Tarek Auel's avatar Tarek Auel Committed by Reynold Xin
Browse files

[SPARK-9161][SQL] codegen FormatNumber

Jira https://issues.apache.org/jira/browse/SPARK-9161

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7545 from tarekauel/SPARK-9161 and squashes the following commits:

21425c8 [Tarek Auel] [SPARK-9161][SQL] codegen FormatNumber
parent 228ab65a
No related branches found
No related tags found
No related merge requests found
......@@ -902,22 +902,15 @@ case class FormatNumber(x: Expression, d: Expression)
@transient
private val numberFormat: DecimalFormat = new DecimalFormat("")
override def eval(input: InternalRow): Any = {
val xObject = x.eval(input)
if (xObject == null) {
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}
val dObject = d.eval(input)
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
return null
}
val dValue = dObject.asInstanceOf[Int]
if (dValue != lastDValue) {
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length())
pattern.delete(0, pattern.length)
pattern.append("#,###,###,###,###,###,##0")
// decimal place
......@@ -930,9 +923,10 @@ case class FormatNumber(x: Expression, d: Expression)
pattern.append("0")
}
}
val dFormat = new DecimalFormat(pattern.toString())
lastDValue = dValue;
numberFormat.applyPattern(dFormat.toPattern())
val dFormat = new DecimalFormat(pattern.toString)
lastDValue = dValue
numberFormat.applyPattern(dFormat.toPattern)
}
x.dataType match {
......@@ -947,6 +941,52 @@ case class FormatNumber(x: Expression, d: Expression)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (num, d) => {
def typeHelper(p: String): String = {
x.dataType match {
case _ : DecimalType => s"""$p.toJavaBigDecimal()"""
case _ => s"$p"
}
}
val sb = classOf[StringBuffer].getName
val df = classOf[DecimalFormat].getName
val lastDValue = ctx.freshName("lastDValue")
val pattern = ctx.freshName("pattern")
val numberFormat = ctx.freshName("numberFormat")
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;")
ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""")
s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDValue) {
$pattern.append("#,###,###,###,###,###,##0");
if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
}
}
$df $dFormat = new $df($pattern.toString());
$lastDValue = $d;
$numberFormat.applyPattern($dFormat.toPattern());
${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
}
} else {
${ev.primitive} = null;
${ev.isNull} = true;
}
"""
})
}
override def prettyName: String = "format_number"
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment