Skip to content
Snippets Groups Projects
Commit c5f745ed authored by Wenchen Fan's avatar Wenchen Fan Committed by Davies Liu
Browse files

[SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen

Simplify the generated code of the hash expression by removing several unnecessary local variables, and skip the null check whenever the input is known to be non-nullable.

Generated code comparison for `hash(int, double, string, array<int>)`:
**before:**
```
  public UnsafeRow apply(InternalRow i) {
    /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
    int value1 = 42;
    /* input[0, int] */
    int value3 = i.getInt(0);
    if (!false) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
    }
    /* input[1, double] */
    double value5 = i.getDouble(1);
    if (!false) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
    }
    /* input[2, string] */
    boolean isNull6 = i.isNullAt(2);
    UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));
    if (!isNull6) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
    }
    /* input[3, array<int>] */
    boolean isNull8 = i.isNullAt(3);
    ArrayData value9 = isNull8 ? null : (i.getArray(3));
    if (!isNull8) {
      int result10 = value1;
      for (int index11 = 0; index11 < value9.numElements(); index11++) {
        if (!value9.isNullAt(index11)) {
          final int element12 = value9.getInt(index11);
          result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10);
        }
      }
      value1 = result10;
    }
  }
```
**after:**
```
  public UnsafeRow apply(InternalRow i) {
    /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
    int value1 = 42;
    /* input[0, int] */
    int value3 = i.getInt(0);
    value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
    /* input[1, double] */
    double value5 = i.getDouble(1);
    value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
    /* input[2, string] */
    boolean isNull6 = i.isNullAt(2);
    UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));

    if (!isNull6) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
    }

    /* input[3, array<int>] */
    boolean isNull8 = i.isNullAt(3);
    ArrayData value9 = isNull8 ? null : (i.getArray(3));
    if (!isNull8) {
      for (int index10 = 0; index10 < value9.numElements(); index10++) {
        final int element11 = value9.getInt(index10);
        value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1);
      }
    }

    rowWriter14.write(0, value1);
    return result12;
  }
```

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10974 from cloud-fan/codegen.
parent e4c1162b
No related branches found
No related tags found
No related merge requests found
......@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
/**
 * Builds the generated source that computes this expression's hash into `ev.value`,
 * starting from `seed` and folding in the hash of every child expression.
 *
 * NOTE(review): this capture is a rendered diff whose +/- markers were lost, so the body
 * below contains two consecutive definitions of `childrenHash`. They appear to be the
 * pre-commit and post-commit versions respectively -- confirm against the commit.
 */
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
// The hash result itself is marked as never null.
ev.isNull = "false"
// NOTE(review): presumably the removed version -- every child's hash is wrapped in an
// `if (!isNull)` guard, and computeHash is consumed as an ExprCode (`.code` / `.value`).
val childrenHash = children.zipWithIndex.map {
case (child, dt) =>
val childGen = child.gen(ctx)
val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
s"""
${childGen.code}
if (!${childGen.isNull}) {
${childHash.code}
${ev.value} = ${childHash.value};
}
"""
// NOTE(review): presumably the added version -- the guard is delegated to
// generateNullCheck, which only emits it when `child.nullable` is true, and the string
// returned by computeHash is used directly as the guarded code.
val childrenHash = children.map { child =>
val childGen = child.gen(ctx)
childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
}.mkString("\n")
// Declare the result variable initialised to the seed, then append the per-child code.
s"""
int ${ev.value} = $seed;
$childrenHash
"""
}
/**
 * Wraps a fragment of generated code in a null guard when one is needed.
 *
 * @param nullable  whether the guarded value can be null; when false no guard is emitted
 * @param isNull    generated-code expression that is true when the value is null
 * @param execution generated code to run for a non-null value
 * @return `execution` enclosed in `if (!isNull) { ... }` when `nullable` is true,
 *         otherwise `execution` preceded by a single newline
 */
private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String =
  if (!nullable) {
    // No guard required: emit the code as-is on a fresh line.
    s"\n$execution"
  } else {
    // Leading and trailing empty entries reproduce the blank first/last lines of the block.
    Seq("", s"if (!$isNull) {", execution, "}", "").mkString("\n")
  }
/**
 * Generates code that reads one element out of `input` at `index` and folds its hash
 * into `result`, guarded by an `isNullAt` check only when `nullable` is true.
 *
 * @param input       generated-code name of the container the element is read from
 * @param index       generated-code expression for the element's position
 * @param nullable    whether the element may be null and therefore needs a guard
 * @param elementType data type of the element
 * @param result      generated-code name of the variable accumulating the hash
 * @param ctx         codegen context supplying fresh names, Java types and accessors
 * @return the generated code fragment
 */
private def nullSafeElementHash(
    input: String,
    index: String,
    nullable: Boolean,
    elementType: DataType,
    result: String,
    ctx: CodegenContext): String = {
  // Keep this first so fresh names are requested in the same order as before.
  val element = ctx.freshName("element")
  val elementJavaType = ctx.javaType(elementType)
  val elementAccessor = ctx.getValue(input, elementType, index)
  val elementHashCode = computeHash(element, elementType, result, ctx)
  val guardedBody = s"\nfinal $elementJavaType $element = $elementAccessor;\n$elementHashCode\n"
  generateNullCheck(nullable, s"$input.isNullAt($index)")(guardedBody)
}
private def computeHash(
input: String,
dataType: DataType,
seed: String,
ctx: CodegenContext): ExprCode = {
result: String,
ctx: CodegenContext): String = {
val hasher = classOf[Murmur3_x86_32].getName
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
def hashBytes(b: String): String =
s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"
dataType match {
case NullType => inlineValue(seed)
case NullType => ""
case BooleanType => hashInt(s"$input ? 1 : 0")
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
case LongType | TimestampType => hashLong(input)
......@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
hashLong(s"$input.toUnscaledLong()")
} else {
val bytes = ctx.freshName("bytes")
val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
val offset = "Platform.BYTE_ARRAY_OFFSET"
val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
ExprCode(code, "false", result)
s"""
final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
${hashBytes(bytes)}
"""
}
case CalendarIntervalType =>
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
inlineValue(monthsHash)
case BinaryType =>
val offset = "Platform.BYTE_ARRAY_OFFSET"
inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
case BinaryType => hashBytes(input)
case StringType =>
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
case ArrayType(et, _) =>
val result = ctx.freshName("result")
case ArrayType(et, containsNull) =>
val index = ctx.freshName("index")
val element = ctx.freshName("element")
val elementHash = computeHash(element, et, result, ctx)
val code =
s"""
int $result = $seed;
for (int $index = 0; $index < $input.numElements(); $index++) {
if (!$input.isNullAt($index)) {
final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
${elementHash.code}
$result = ${elementHash.value};
}
}
"""
ExprCode(code, "false", result)
s"""
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
}
"""
case MapType(kt, vt, _) =>
val result = ctx.freshName("result")
case MapType(kt, vt, valueContainsNull) =>
val index = ctx.freshName("index")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val key = ctx.freshName("key")
val value = ctx.freshName("value")
val keyHash = computeHash(key, kt, result, ctx)
val valueHash = computeHash(value, vt, result, ctx)
val code =
s"""
int $result = $seed;
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
for (int $index = 0; $index < $input.numElements(); $index++) {
final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
${keyHash.code}
$result = ${keyHash.value};
if (!$values.isNullAt($index)) {
final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
${valueHash.code}
$result = ${valueHash.value};
}
}
"""
ExprCode(code, "false", result)
s"""
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(keys, index, false, kt, result, ctx)}
${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
}
"""
case StructType(fields) =>
val result = ctx.freshName("result")
val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
case (dt, index) =>
val field = ctx.freshName("field")
val fieldHash = computeHash(field, dt, result, ctx)
s"""
if (!$input.isNullAt($index)) {
final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
${fieldHash.code}
$result = ${fieldHash.value};
}
"""
fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}.mkString("\n")
val code =
s"""
int $result = $seed;
$fieldsHash
"""
ExprCode(code, "false", result)
case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment