Skip to content
Snippets Groups Projects
Commit 9cc74f95 authored by Sameer Agarwal's avatar Sameer Agarwal Committed by Reynold Xin
Browse files

[SPARK-16488] Fix codegen variable namespace collision in pmod and partitionBy

## What changes were proposed in this pull request?

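This patch fixes a codegen variable namespace collision in `pmod` that can surface when it is combined with `partitionBy`. The Java code generated by `Pmod.doGenCode` declared a local variable with the hard-coded name `r`; it now uses a unique name obtained from `ctx.freshName("remainder")`.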

## How was this patch tested?

Added a regression test (`pmod with partitionBy` in `DataFrameReaderWriterSuite`) that covers one possible occurrence of the collision. A more general fix in `ExpressionEvalHelper.checkEvaluation` will follow in a subsequent PR.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #14144 from sameeragarwal/codegen-bug.
parent e50efd53
No related branches found
No related tags found
No related merge requests found
......@@ -498,34 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val remainder = ctx.freshName("remainder")
dataType match {
case dt: DecimalType =>
val decimalAdd = "$plus"
s"""
${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2);
${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2);
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2);
} else {
${ev.value} = r;
${ev.value} = $remainder;
}
"""
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
if (r < 0) {
${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2);
if ($remainder < 0) {
${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2);
} else {
${ev.value} = r;
${ev.value} = $remainder;
}
"""
case _ =>
s"""
${ctx.javaType(dataType)} r = $eval1 % $eval2;
if (r < 0) {
${ev.value} = (r + $eval2) % $eval2;
${ctx.javaType(dataType)} $remainder = $eval1 % $eval2;
if ($remainder < 0) {
${ev.value} = ($remainder + $eval2) % $eval2;
} else {
${ev.value} = r;
${ev.value} = $remainder;
}
"""
}
......
......@@ -449,6 +449,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
}
}
// Regression test for SPARK-16488 (codegen variable name collision in pmod).
//
// The test intentionally contains no assertions: it passes as long as the
// query below can be planned, executed and written out without throwing.
test("pmod with partitionBy") {
  // A stable `val` is required so that `spark.implicits._` can be imported.
  val spark = this.spark
  import spark.implicits._

  // Tuples are used directly, so the columns are named `_1` and `_2`.
  val data = Seq((0, "a"), (1, "b"), (1, "a"))
  spark.createDataset(data).createOrReplaceTempView("test")

  // Combine `distribute by pmod(...)` with a partitioned write; this is the
  // combination the commit description reports as triggering the collision.
  // NOTE(review): `dir` is defined elsewhere in the suite — presumably a
  // per-test temporary directory; confirm against the suite's setup code.
  sql("select * from test distribute by pmod(_1, 2)")
    .write
    .partitionBy("_2")
    .mode("overwrite")
    .parquet(dir)
}
private def testRead(
df: => DataFrame,
expectedResult: Seq[String],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment