Skip to content
Snippets Groups Projects
Commit fee3438a authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Reynold Xin
Browse files

[SPARK-8218][SQL] Add binary log math function

JIRA: https://issues.apache.org/jira/browse/SPARK-8218

Because a unary `log` function is already defined, the binary log function is named `logarithm` for now.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6725 from viirya/expr_binary_log and squashes the following commits:

bf96bd9 [Liang-Chi Hsieh] Compare log result in string.
102070d [Liang-Chi Hsieh] Round log result to better comparing in python test.
fd01863 [Liang-Chi Hsieh] For comments.
beed631 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
6089d11 [Liang-Chi Hsieh] Remove unnecessary override.
8cf37b7 [Liang-Chi Hsieh] For comments.
bc89597 [Liang-Chi Hsieh] For comments.
db7dc38 [Liang-Chi Hsieh] Use ctor instead of companion object.
0634ef7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
1750034 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
3d75bfc [Liang-Chi Hsieh] Fix scala style.
5b39c02 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
23c54a3 [Liang-Chi Hsieh] Fix scala style.
ebc9929 [Liang-Chi Hsieh] Let Logarithm accept one parameter too.
605574d [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
21c3bfd [Liang-Chi Hsieh] Fix scala style.
c6c187f [Liang-Chi Hsieh] For comments.
c795342 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
f373bac [Liang-Chi Hsieh] Add binary log expression.
parent 78a430ea
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
""" """
A collections of builtin functions A collections of builtin functions
""" """
import math
import sys import sys
if sys.version < "3": if sys.version < "3":
...@@ -143,7 +144,7 @@ _binary_mathfunctions = { ...@@ -143,7 +144,7 @@ _binary_mathfunctions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).', 'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.' 'pow': 'Returns the value of the first argument raised to the power of the second argument.',
} }
_window_functions = { _window_functions = {
...@@ -403,6 +404,21 @@ def when(condition, value): ...@@ -403,6 +404,21 @@ def when(condition, value):
return Column(jc) return Column(jc)
@since(1.4)
def log(col, base=math.e):
    """Returns the logarithm of `col` in the given `base`.

    When `base` is omitted it defaults to Euler's number, so the natural
    logarithm of `col` is returned.

    :param col: the column whose logarithm is computed
    :param base: the base of the logarithm (defaults to ``math.e``)

    >>> df.select(log(df.age, 10.0).alias('ten')).map(lambda l: str(l.ten)[:7]).collect()
    ['0.30102', '0.69897']

    >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect()
    ['0.69314', '1.60943']
    """
    sc = SparkContext._active_spark_context
    # The JVM function takes (base, column), the reverse of this wrapper's
    # argument order. Coerce to float so an integer base such as 10 is handed
    # to the JVM as a double.
    jc = sc._jvm.functions.log(float(base), _to_java_column(col))
    return Column(jc)
@since(1.4) @since(1.4)
def lag(col, count=1, default=None): def lag(col, count=1, default=None):
""" """
......
...@@ -112,6 +112,7 @@ object FunctionRegistry { ...@@ -112,6 +112,7 @@ object FunctionRegistry {
expression[Expm1]("expm1"), expression[Expm1]("expm1"),
expression[Floor]("floor"), expression[Floor]("floor"),
expression[Hypot]("hypot"), expression[Hypot]("hypot"),
expression[Logarithm]("log"),
expression[Log]("ln"), expression[Log]("ln"),
expression[Log10]("log10"), expression[Log10]("log10"),
expression[Log1p]("log1p"), expression[Log1p]("log1p"),
......
...@@ -255,3 +255,23 @@ case class Pow(left: Expression, right: Expression) ...@@ -255,3 +255,23 @@ case class Pow(left: Expression, right: Expression)
""" """
} }
} }
/**
 * Computes the logarithm of `right` in base `left`, i.e. `log(right) / log(left)`.
 *
 * The single-argument constructor uses [[EulerNumber]] as the base, which yields the natural
 * logarithm of its argument.
 */
case class Logarithm(left: Expression, right: Expression)
  extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {

  /** Builds a logarithm of `child` whose base is Euler's number. */
  def this(child: Expression) = {
    this(EulerNumber(), child)
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val logCode = left match {
      // When the base is Euler's number only the numerator is emitted.
      case _: EulerNumber =>
        defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
      case _ =>
        defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
    }
    // A NaN result is flagged as null. The static Double.isNaN check is used so the
    // generated code does not box the primitive result on every evaluation.
    logCode + s"""
      if (Double.isNaN(${ev.primitive})) {
        ${ev.isNull} = true;
      }
    """
  }
}
...@@ -204,4 +204,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -204,4 +204,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testBinary(Atan2, math.atan2) testBinary(Atan2, math.atan2)
} }
test("binary log") {
  // Reference implementation: logarithm of `x` in the given `base`.
  def expected(base: Double, x: Double): Double = math.log(x) / math.log(base)

  // Exercise both argument orders plus the single-argument (base e) constructor.
  for (i <- 1 to 20) {
    val a = i * 0.1
    val b = i * 0.2
    checkEvaluation(Logarithm(Literal(a), Literal(b)), expected(a, b), EmptyRow)
    checkEvaluation(Logarithm(Literal(b), Literal(a)), expected(b, a), EmptyRow)
    checkEvaluation(new Logarithm(Literal(a)), expected(math.E, a), EmptyRow)
  }

  // A null on either side must evaluate to null.
  val nullDouble = Literal.create(null, DoubleType)
  val one = Literal(1.0)
  Seq(Logarithm(nullDouble, one), Logarithm(one, nullDouble)).foreach { expr =>
    checkEvaluation(expr, null, create_row(null))
  }
}
} }
...@@ -1083,6 +1083,22 @@ object functions { ...@@ -1083,6 +1083,22 @@ object functions {
*/ */
def log(columnName: String): Column = log(Column(columnName)) def log(columnName: String): Column = log(Column(columnName))
/**
 * Computes the logarithm of the column `a` in the given `base`.
 *
 * @param base the base of the logarithm
 * @param a the column whose logarithm is taken
 * @group math_funcs
 * @since 1.4.0
 */
def log(base: Double, a: Column): Column = {
  val baseExpr = lit(base).expr
  Logarithm(baseExpr, a.expr)
}
/**
 * Computes the logarithm of the column named `columnName` in the given `base`.
 *
 * @param base the base of the logarithm
 * @param columnName name of the column whose logarithm is taken
 * @group math_funcs
 * @since 1.4.0
 */
def log(base: Double, columnName: String): Column = {
  val column = Column(columnName)
  log(base, column)
}
/** /**
* Computes the logarithm of the given value in base 10. * Computes the logarithm of the given value in base 10.
* *
......
...@@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest { ...@@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p) testOneToOneNonNegativeMathFunction(log1p, math.log1p)
} }
test("binary log") {
  // Alias the functions object so the fully qualified `log` overloads stay unambiguous.
  val fns = org.apache.spark.sql.functions

  val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
  // Natural log of a, base-2 log of a, and null for the null column b.
  val expected = Row(math.log(123), math.log(123) / math.log(2), null)

  // DataFrame API.
  checkAnswer(df.select(fns.log("a"), fns.log(2.0, "a"), fns.log("b")), expected)

  // SQL expression strings resolved through the function registry.
  checkAnswer(df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), expected)
}
test("abs") { test("abs") {
val input = val input =
Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment