Skip to content
Snippets Groups Projects
Commit 2758ff0a authored by Daoyuan Wang's avatar Daoyuan Wang Committed by Reynold Xin
Browse files

[SPARK-8217] [SQL] math function log2

Author: Daoyuan Wang <daoyuan.wang@intel.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@databricks.com>

Closes #6718 from adrian-wang/udflog2 and squashes the following commits:

3909f48 [Daoyuan Wang] math function: log2
parent 9fe3adcc
No related branches found
No related tags found
No related merge requests found
...@@ -111,6 +111,7 @@ object FunctionRegistry { ...@@ -111,6 +111,7 @@ object FunctionRegistry {
expression[Log10]("log10"), expression[Log10]("log10"),
expression[Log1p]("log1p"), expression[Log1p]("log1p"),
expression[Pi]("pi"), expression[Pi]("pi"),
expression[Log2]("log2"),
expression[Pow]("pow"), expression[Pow]("pow"),
expression[Rint]("rint"), expression[Rint]("rint"),
expression[Signum]("signum"), expression[Signum]("signum"),
......
...@@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO ...@@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
}
"""
}
}
case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
......
...@@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
} }
test("log2") {
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
}
test("pow") { test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
......
...@@ -1084,7 +1084,7 @@ object functions { ...@@ -1084,7 +1084,7 @@ object functions {
def log(columnName: String): Column = log(Column(columnName)) def log(columnName: String): Column = log(Column(columnName))
/** /**
* Computes the logarithm of the given value in Base 10. * Computes the logarithm of the given value in base 10.
* *
* @group math_funcs * @group math_funcs
* @since 1.4.0 * @since 1.4.0
...@@ -1092,7 +1092,7 @@ object functions { ...@@ -1092,7 +1092,7 @@ object functions {
def log10(e: Column): Column = Log10(e.expr) def log10(e: Column): Column = Log10(e.expr)
/** /**
* Computes the logarithm of the given value in Base 10. * Computes the logarithm of the given value in base 10.
* *
* @group math_funcs * @group math_funcs
* @since 1.4.0 * @since 1.4.0
...@@ -1124,6 +1124,22 @@ object functions { ...@@ -1124,6 +1124,22 @@ object functions {
*/ */
def pi(): Column = Pi() def pi(): Column = Pi()
/**
* Computes the logarithm of the given column in base 2.
*
* @group math_funcs
* @since 1.5.0
*/
def log2(expr: Column): Column = Log2(expr.expr)
/**
* Computes the logarithm of the given value in base 2.
*
* @group math_funcs
* @since 1.5.0
*/
def log2(columnName: String): Column = log2(Column(columnName))
/** /**
* Returns the value of the first argument raised to the power of the second argument. * Returns the value of the first argument raised to the power of the second argument.
* *
......
...@@ -128,5 +128,17 @@ class DataFrameFunctionsSuite extends QueryTest { ...@@ -128,5 +128,17 @@ class DataFrameFunctionsSuite extends QueryTest {
}) })
} }
test("log2 functions test") {
val df = Seq((1, 2)).toDF("a", "b")
checkAnswer(
df.select(log2("b") + log2("a")),
Row(1))
checkAnswer(
ctx.sql("SELECT LOG2(8)"),
Row(3))
checkAnswer(
ctx.sql("SELECT LOG2(null)"),
Row(null))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment