Skip to content
Snippets Groups Projects
Commit c6ba7cca authored by Daoyuan Wang's avatar Daoyuan Wang Committed by Reynold Xin
Browse files

[SPARK-8215] [SPARK-8212] [SQL] add leaf math expression for e and pi

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #6716 from adrian-wang/epi and squashes the following commits:

e2e8dbd [Daoyuan Wang] move tests
11b351c [Daoyuan Wang] add tests and remove pu
db331c9 [Daoyuan Wang] py style
599ddd8 [Daoyuan Wang] add py
e6783ef [Daoyuan Wang] register function
82d426e [Daoyuan Wang] add function entry
dbf3ab5 [Daoyuan Wang] add PI and E
parent e90035e6
No related branches found
No related tags found
No related merge requests found
...@@ -106,6 +106,7 @@ object FunctionRegistry { ...@@ -106,6 +106,7 @@ object FunctionRegistry {
expression[Cbrt]("cbrt"), expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"), expression[Ceil]("ceil"),
expression[Cos]("cos"), expression[Cos]("cos"),
expression[EulerNumber]("e"),
expression[Exp]("exp"), expression[Exp]("exp"),
expression[Expm1]("expm1"), expression[Expm1]("expm1"),
expression[Floor]("floor"), expression[Floor]("floor"),
...@@ -113,6 +114,7 @@ object FunctionRegistry { ...@@ -113,6 +114,7 @@ object FunctionRegistry {
expression[Log]("log"), expression[Log]("log"),
expression[Log10]("log10"), expression[Log10]("log10"),
expression[Log1p]("log1p"), expression[Log1p]("log1p"),
expression[Pi]("pi"),
expression[Pow]("pow"), expression[Pow]("pow"),
expression[Rint]("rint"), expression[Rint]("rint"),
expression[Signum]("signum"), expression[Signum]("signum"),
......
...@@ -20,9 +20,34 @@ package org.apache.spark.sql.catalyst.expressions ...@@ -20,9 +20,34 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.sql.types.{DataType, DoubleType}
/**
* A leaf expression specifically for math constants. Math constants expect no input.
* @param c The math constant.
* @param name The short name of the function
*/
abstract class LeafMathExpression(c: Double, name: String)
extends LeafExpression with Serializable {
self: Product =>
override def dataType: DataType = DoubleType
override def foldable: Boolean = true
override def nullable: Boolean = false
override def toString: String = s"$name()"
override def eval(input: Row): Any = c
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name;
"""
}
}
/** /**
* A unary expression specifically for math functions. Math Functions expect a specific type of * A unary expression specifically for math functions. Math Functions expect a specific type of
* input format, therefore these functions extend `ExpectsInputTypes`. * input format, therefore these functions extend `ExpectsInputTypes`.
* @param f The math function.
* @param name The short name of the function * @param name The short name of the function
*/ */
abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class UnaryMathExpression(f: Double => Double, name: String)
...@@ -98,6 +123,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) ...@@ -98,6 +123,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Leaf math functions
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
case class EulerNumber() extends LeafMathExpression(math.E, "E")
case class Pi() extends LeafMathExpression(math.Pi, "PI")
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Unary math functions // Unary math functions
......
...@@ -22,6 +22,20 @@ import org.apache.spark.sql.types.DoubleType ...@@ -22,6 +22,20 @@ import org.apache.spark.sql.types.DoubleType
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
/**
* Used for testing leaf math expressions.
*
* @param e expression
* @param c The constants in scala.math
* @tparam T Generic type for primitives
*/
private def testLeaf[T](
e: () => Expression,
c: T): Unit = {
checkEvaluation(e(), c, EmptyRow)
checkEvaluation(e(), c, create_row(null))
}
/** /**
* Used for testing unary math expressions. * Used for testing unary math expressions.
* *
...@@ -74,6 +88,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -74,6 +88,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
} }
test("e") {
testLeaf(EulerNumber, math.E)
}
test("pi") {
testLeaf(Pi, math.Pi)
}
test("sin") { test("sin") {
testUnary(Sin, math.sin) testUnary(Sin, math.sin)
} }
......
...@@ -944,6 +944,15 @@ object functions { ...@@ -944,6 +944,15 @@ object functions {
*/ */
def cosh(columnName: String): Column = cosh(Column(columnName)) def cosh(columnName: String): Column = cosh(Column(columnName))
/**
* Returns the double value that is closer than any other to e, the base of the natural
* logarithms.
*
* @group math_funcs
* @since 1.5.0
*/
def e(): Column = EulerNumber()
/** /**
* Computes the exponential of the given value. * Computes the exponential of the given value.
* *
...@@ -1105,6 +1114,15 @@ object functions { ...@@ -1105,6 +1114,15 @@ object functions {
*/ */
def log1p(columnName: String): Column = log1p(Column(columnName)) def log1p(columnName: String): Column = log1p(Column(columnName))
/**
* Returns the double value that is closer than any other to pi, the ratio of the circumference
* of a circle to its diameter.
*
* @group math_funcs
* @since 1.5.0
*/
def pi(): Column = Pi()
/** /**
* Returns the value of the first argument raised to the power of the second argument. * Returns the value of the first argument raised to the power of the second argument.
* *
......
...@@ -85,6 +85,25 @@ class DataFrameFunctionsSuite extends QueryTest { ...@@ -85,6 +85,25 @@ class DataFrameFunctionsSuite extends QueryTest {
} }
} }
test("constant functions") {
checkAnswer(
testData2.select(e()).limit(1),
Row(scala.math.E)
)
checkAnswer(
testData2.select(pi()).limit(1),
Row(scala.math.Pi)
)
checkAnswer(
ctx.sql("SELECT E()"),
Row(scala.math.E)
)
checkAnswer(
ctx.sql("SELECT PI()"),
Row(scala.math.Pi)
)
}
test("bitwiseNOT") { test("bitwiseNOT") {
checkAnswer( checkAnswer(
testData2.select(bitwiseNOT($"a")), testData2.select(bitwiseNOT($"a")),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment