Commit 231f6724 authored by Yuming Wang, committed by Herman van Hovell

[SPARK-21205][SQL] pmod(number, 0) should be null.

## What changes were proposed in this pull request?
Hive `pmod(3.13, 0)`:
```sql
hive> select pmod(3.13, 0);
OK
NULL
Time taken: 2.514 seconds, Fetched: 1 row(s)
hive>
```

Spark `mod(3.13, 0)`:
```sql
spark-sql> select mod(3.13, 0);
NULL
spark-sql>
```

But the Spark `pmod(3.13, 0)`:
```sql
spark-sql> select pmod(3.13, 0);
17/06/25 09:35:58 ERROR SparkSQLDriver: Failed in [select pmod(3.13, 0)]
java.lang.NullPointerException
	at org.apache.spark.sql.catalyst.expressions.Pmod.pmod(arithmetic.scala:504)
	at org.apache.spark.sql.catalyst.expressions.Pmod.nullSafeEval(arithmetic.scala:432)
	at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:419)
	at org.apache.spark.sql.catalyst.expressions.UnaryExpression.eval(Expression.scala:323)
...
```
This PR makes `pmod(number, 0)` return null, consistent with Hive's `pmod` and Spark's own `mod`.
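
For illustration only, a minimal Scala sketch (a hypothetical helper, not the actual Catalyst implementation, which is in the `arithmetic.scala` diff below) of the intended semantics: a null or zero divisor yields null, otherwise the remainder is normalized to be non-negative:

```scala
// Hypothetical helper illustrating the intended pmod semantics for Int inputs.
// The real fix lives in Pmod.eval / Pmod.doGenCode (see the diff below).
def pmodOrNull(a: Any, n: Any): Any = (a, n) match {
  case (null, _) | (_, null) => null   // null operand => null result
  case (_, 0)                => null   // SPARK-21205: pmod(number, 0) is null
  case (x: Int, y: Int) =>
    val r = x % y
    if (r < 0) (r + y) % y else r      // keep the result non-negative
  case _ => null                       // other numeric types omitted in this sketch
}

assert(pmodOrNull(-7, 3) == 2)
assert(pmodOrNull(3, 0) == null)
```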

## How was this patch tested?
Unit tests in `ArithmeticExpressionSuite` and new SQL query tests.
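
With this change, the failing query from the description is expected to behave like Hive (mirroring the new SQL query tests added below):

```sql
spark-sql> select pmod(3.13, 0);
NULL
```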

Author: Yuming Wang <wgyumg@gmail.com>

Closes #18413 from wangyum/SPARK-21205.
parent 1347b2a6
@@ -421,52 +421,101 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = NumericType
protected override def nullSafeEval(left: Any, right: Any) =
dataType match {
case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
}
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val remainder = ctx.freshName("remainder")
dataType match {
case dt: DecimalType =>
val decimalAdd = "$plus"
s"""
${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2);
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2);
} else {
${ev.value} = $remainder;
}
"""
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2);
if ($remainder < 0) {
${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2);
} else {
${ev.value} = $remainder;
}
"""
case _ =>
s"""
${ctx.javaType(dataType)} $remainder = $eval1 % $eval2;
if ($remainder < 0) {
${ev.value} = ($remainder + $eval2) % $eval2;
} else {
${ev.value} = $remainder;
}
"""
}
})
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val isZero = if (dataType.isInstanceOf[DecimalType]) {
s"${eval2.value}.isZero()"
} else {
s"${eval2.value} == 0"
}
val remainder = ctx.freshName("remainder")
val javaType = ctx.javaType(dataType)
val result = dataType match {
case DecimalType.Fixed(_, _) =>
val decimalAdd = "$plus"
s"""
${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value});
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value});
} else {
${ev.value}=$remainder;
}
"""
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
${ctx.javaType(dataType)} $remainder =
(${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value});
if ($remainder < 0) {
${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value});
} else {
${ev.value}=$remainder;
}
"""
case _ =>
s"""
${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value};
if ($remainder < 0) {
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
} else {
${ev.value}=$remainder;
}
"""
}
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
if ($isZero) {
${ev.isNull} = true;
} else {
${eval1.code}
$result
}""")
} else {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
${eval1.code}
if (${eval1.isNull}) {
${ev.isNull} = true;
} else {
$result
}
}""")
}
}
private def pmod(a: Int, n: Int): Int = {
@@ -501,7 +550,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
private def pmod(a: Decimal, n: Decimal): Decimal = {
val r = a % n
if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
if (r != null && r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
}
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
@@ -214,7 +214,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Pmod(left, right), convert(1))
checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
checkEvaluation(Pmod(left, Literal(convert(0))), null) // mod by 0
}
checkEvaluation(Pmod(Literal(-7), Literal(3)), 2)
checkEvaluation(Pmod(Literal(7.2D), Literal(4.1D)), 3.1000000000000005)
@@ -223,6 +223,13 @@
checkEvaluation(Pmod(positiveShort, negativeShort), positiveShort.toShort)
checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt)
checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong)
// mod by 0
checkEvaluation(Pmod(Literal(-7), Literal(0)), null)
checkEvaluation(Pmod(Literal(7.2D), Literal(0D)), null)
checkEvaluation(Pmod(Literal(7.2F), Literal(0F)), null)
checkEvaluation(Pmod(Literal(2.toByte), Literal(0.toByte)), null)
checkEvaluation(Pmod(positiveShort, 0.toShort), null)
}
test("function least") {
@@ -92,3 +92,7 @@ select abs(-3.13), abs('-2.19');
-- positive/negative
select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11);
-- pmod
select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null);
select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint));
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 57
-- Number of queries: 59
-- !query 0
@@ -468,3 +468,19 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11)
struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)>
-- !query 56 output
-1.11 -1.11 1.11 1.11
-- !query 57
select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null)
-- !query 57 schema
struct<pmod(-7, 2):int,pmod(0, 2):int,pmod(7, 0):int,pmod(7, CAST(NULL AS INT)):int,pmod(CAST(NULL AS INT), 2):int,pmod(CAST(NULL AS DOUBLE), CAST(NULL AS DOUBLE)):double>
-- !query 57 output
1 0 NULL NULL NULL NULL
-- !query 58
select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint))
-- !query 58 schema
struct<pmod(CAST(3.13 AS DECIMAL(10,0)), CAST(0 AS DECIMAL(10,0))):decimal(10,0),pmod(CAST(2 AS SMALLINT), CAST(0 AS SMALLINT)):smallint>
-- !query 58 output
NULL NULL