Skip to content
Snippets Groups Projects
Commit 53c16b92 authored by Reynold Xin's avatar Reynold Xin Committed by Michael Armbrust
Browse files

[SPARK-8362] [SQL] Add unit tests for +, -, *, /, %

Added unit tests for all supported data types for:
- Add
- Subtract
- Multiply
- Divide
- UnaryMinus
- Remainder

Fixed bugs caught by the unit tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #6813 from rxin/SPARK-8362 and squashes the following commits:

fb3fe62 [Reynold Xin] Added Remainder.
3b266ba [Reynold Xin] [SPARK-8362] Add unit tests for +, -, *, /.
parent 9073a426
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.util.TypeUtils
...@@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { ...@@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType) private lazy val numeric = TypeUtils.getNumeric(dataType)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
} }
protected override def evalInternal(evalE: Any) = numeric.negate(evalE) protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
...@@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression { ...@@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide // byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType => case ByteType | ShortType =>
defineCodeGen(ctx, ev, (eval1, eval2) => defineCodeGen(ctx, ev,
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
case _ => case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
} }
...@@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti ...@@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "/" override def symbol: String = "/"
override def decimalMethod: String = "$divide" override def decimalMethod: String = "$div"
override def nullable: Boolean = true override def nullable: Boolean = true
...@@ -245,11 +244,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ...@@ -245,11 +244,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else { } else {
s"${eval2.primitive} == 0" s"${eval2.primitive} == 0"
} }
val method = if (left.dataType.isInstanceOf[DecimalType]) { val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
s".$decimalMethod" val javaType = ctx.javaType(left.dataType)
} else {
s"$symbol"
}
eval1.code + eval2.code + eval1.code + eval2.code +
s""" s"""
boolean ${ev.isNull} = false; boolean ${ev.isNull} = false;
...@@ -257,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ...@@ -257,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
if (${eval1.isNull} || ${eval2.isNull} || $test) { if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true; ${ev.isNull} = true;
} else { } else {
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
} }
""" """
} }
...@@ -265,7 +261,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ...@@ -265,7 +261,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "%" override def symbol: String = "%"
override def decimalMethod: String = "reminder" override def decimalMethod: String = "remainder"
override def nullable: Boolean = true override def nullable: Boolean = true
...@@ -305,11 +301,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ...@@ -305,11 +301,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else { } else {
s"${eval2.primitive} == 0" s"${eval2.primitive} == 0"
} }
val method = if (left.dataType.isInstanceOf[DecimalType]) { val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
s".$decimalMethod" val javaType = ctx.javaType(left.dataType)
} else {
s"$symbol"
}
eval1.code + eval2.code + eval1.code + eval2.code +
s""" s"""
boolean ${ev.isNull} = false; boolean ${ev.isNull} = false;
...@@ -317,7 +310,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ...@@ -317,7 +310,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
if (${eval1.isNull} || ${eval2.isNull} || $test) { if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true; ${ev.isNull} = true;
} else { } else {
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
} }
""" """
} }
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions package org.apache.spark.sql.catalyst.expressions
import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
...@@ -26,100 +24,103 @@ import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} ...@@ -26,100 +24,103 @@ import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("arithmetic") { /**
val row = create_row(1, 2, 3, null) * Runs through the testFunc for all numeric data types.
val c1 = 'a.int.at(0) *
val c2 = 'a.int.at(1) * @param testFunc a test function that accepts a conversion function to convert an integer
val c3 = 'a.int.at(2) * into another data type.
val c4 = 'a.int.at(3) */
private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = {
checkEvaluation(UnaryMinus(c1), -1, row) testFunc(_.toByte)
checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) testFunc(_.toShort)
testFunc(identity)
checkEvaluation(Add(c1, c4), null, row) testFunc(_.toLong)
checkEvaluation(Add(c1, c2), 3, row) testFunc(_.toFloat)
checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) testFunc(_.toDouble)
checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) testFunc(Decimal(_))
checkEvaluation(
Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
checkEvaluation(-c1, -1, row)
checkEvaluation(c1 + c2, 3, row)
checkEvaluation(c1 - c2, -1, row)
checkEvaluation(c1 * c2, 2, row)
checkEvaluation(c1 / c2, 0, row)
checkEvaluation(c1 % c2, 1, row)
} }
test("fractional arithmetic") { test("+ (Add)") {
val row = create_row(1.1, 2.0, 3.1, null) testNumericDataTypes { convert =>
val c1 = 'a.double.at(0) val left = Literal(convert(1))
val c2 = 'a.double.at(1) val right = Literal(convert(2))
val c3 = 'a.double.at(2) checkEvaluation(Add(left, right), convert(3))
val c4 = 'a.double.at(3) checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
checkEvaluation(UnaryMinus(c1), -1.1, row) }
checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0)
checkEvaluation(Add(c1, c4), null, row)
checkEvaluation(Add(c1, c2), 3.1, row)
checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row)
checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row)
checkEvaluation(
Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row)
checkEvaluation(-c1, -1.1, row)
checkEvaluation(c1 + c2, 3.1, row)
checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)
checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
} }
test("Abs") { test("- (UnaryMinus)") {
def testAbs(convert: (Int) => Any): Unit = { testNumericDataTypes { convert =>
checkEvaluation(Abs(Literal(convert(0))), convert(0)) val input = Literal(convert(1))
checkEvaluation(Abs(Literal(convert(1))), convert(1)) val dataType = input.dataType
checkEvaluation(Abs(Literal(convert(-1))), convert(1)) checkEvaluation(UnaryMinus(input), convert(-1))
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
} }
testAbs(_.toByte)
testAbs(_.toShort)
testAbs(identity)
testAbs(_.toLong)
testAbs(_.toFloat)
testAbs(_.toDouble)
testAbs(Decimal(_))
} }
test("Divide") { test("- (Minus)") {
checkEvaluation(Divide(Literal(2), Literal(1)), 2) testNumericDataTypes { convert =>
checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) val left = Literal(convert(1))
val right = Literal(convert(2))
checkEvaluation(Subtract(left, right), convert(-1))
checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null)
checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null)
}
}
test("* (Multiply)") {
testNumericDataTypes { convert =>
val left = Literal(convert(1))
val right = Literal(convert(2))
checkEvaluation(Multiply(left, right), convert(2))
checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null)
checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null)
}
}
test("/ (Divide) basic") {
testNumericDataTypes { convert =>
val left = Literal(convert(2))
val right = Literal(convert(1))
val dataType = left.dataType
checkEvaluation(Divide(left, right), convert(2))
checkEvaluation(Divide(Literal.create(null, dataType), right), null)
checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null)
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero
}
}
test("/ (Divide) for integral type") {
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
checkEvaluation(Divide(Literal(1), Literal(2)), 0) checkEvaluation(Divide(Literal(1), Literal(2)), 0)
checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null)
checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null)
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null)
checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null)
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null)
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
null)
} }
test("Remainder") { test("/ (Divide) for floating point") {
checkEvaluation(Remainder(Literal(2), Literal(1)), 0) checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
checkEvaluation(Remainder(Literal(1), Literal(2)), 1) checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
checkEvaluation(Remainder(Literal(1), Literal(0)), null) }
checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null)
checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) test("% (Remainder)") {
checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) testNumericDataTypes { convert =>
checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) val left = Literal(convert(1))
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) val right = Literal(convert(2))
checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) checkEvaluation(Remainder(left, right), convert(1))
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null)
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null)
null) checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
}
}
test("Abs") {
testNumericDataTypes { convert =>
checkEvaluation(Abs(Literal(convert(0))), convert(0))
checkEvaluation(Abs(Literal(convert(1))), convert(1))
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
}
} }
test("MaxOf") { test("MaxOf") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment