diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
new file mode 100644
index 0000000000000000000000000000000000000000..415ef4e4a37ec0d461bcf76f4106d77b90dd8b46
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -0,0 +1,788 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.security.{MessageDigest, NoSuchAlgorithmException}
+import java.util.zip.CRC32
+
+import scala.annotation.tailrec
+
+import org.apache.commons.codec.digest.DigestUtils
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.Platform
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines all the expressions for hashing.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * A function that calculates an MD5 128-bit checksum and returns it as a hex string.
+ * For input of type [[BinaryType]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns an MD5 128-bit checksum as a hex string of `expr`.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       8cde774d6f7333752ed72cacddb05126
+  """)
+case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
+
+  protected override def nullSafeEval(input: Any): Any =
+    UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]]))
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, c =>
+      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+  }
+}
+
+/**
+ * A function that calculates the SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and
+ * SHA-512) and returns the result as a hex string. The first argument is the string or binary to
+ * be hashed. The second argument indicates the desired bit length of the result, which must have
+ * a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting
+ * from Java 8. If asking for an unsupported SHA function, the return value is NULL.
+ * If either argument is NULL or the hash length is not one of the permitted values, the return
+ * value is NULL.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of `expr`.
+    SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.
+  """,
+  extended = """
+    Examples:
+      > SELECT _FUNC_('Spark', 256);
+       529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
+  """)
+// scalastyle:on line.size.limit
+case class Sha2(left: Expression, right: Expression)
+  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val bitLength = input2.asInstanceOf[Int]
+    val input = input1.asInstanceOf[Array[Byte]]
+    bitLength match {
+      case 224 =>
+        // DigestUtils doesn't support SHA-224 now
+        try {
+          val md = MessageDigest.getInstance("SHA-224")
+          md.update(input)
+          UTF8String.fromBytes(md.digest())
+        } catch {
+          // SHA-224 is not supported on the system, return null
+          case noa: NoSuchAlgorithmException => null
+        }
+      case 256 | 0 =>
+        UTF8String.fromString(DigestUtils.sha256Hex(input))
+      case 384 =>
+        UTF8String.fromString(DigestUtils.sha384Hex(input))
+      case 512 =>
+        UTF8String.fromString(DigestUtils.sha512Hex(input))
+      case _ => null
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      s"""
+        if ($eval2 == 224) {
+          try {
+            java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+            md.update($eval1);
+            ${ev.value} = UTF8String.fromBytes(md.digest());
+          } catch (java.security.NoSuchAlgorithmException e) {
+            ${ev.isNull} = true;
+          }
+        } else if ($eval2 == 256 || $eval2 == 0) {
+          ${ev.value} =
+            UTF8String.fromString($digestUtils.sha256Hex($eval1));
+        } else if ($eval2 == 384) {
+          ${ev.value} =
+            UTF8String.fromString($digestUtils.sha384Hex($eval1));
+        } else if ($eval2 == 512) {
+          ${ev.value} =
+            UTF8String.fromString($digestUtils.sha512Hex($eval1));
+        } else {
+          ${ev.isNull} = true;
+        }
+      """
+    })
+  }
+}
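+
+// For example, per the match above, a bit length of 0 falls into the 256 branch, so
+// sha2(expr, 0) and sha2(expr, 256) produce the same digest.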
+
+/**
+ * A function that calculates a sha1 hash value and returns it as a hex string.
+ * For input of type [[BinaryType]] or [[StringType]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns a sha1 hash value as a hex string of the `expr`.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
+  """)
+case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
+
+  protected override def nullSafeEval(input: Any): Any =
+    UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]]))
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, c =>
+      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))"
+    )
+  }
+}
+
+/**
+ * A function that computes a cyclic redundancy check value and returns it as a bigint.
+ * For input of type [[BinaryType]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns a cyclic redundancy check value of the `expr` as a bigint.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_('Spark');
+       1557323817
+  """)
+case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = LongType
+
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
+
+  protected override def nullSafeEval(input: Any): Any = {
+    val checksum = new CRC32
+    checksum.update(input.asInstanceOf[Array[Byte]], 0, input.asInstanceOf[Array[Byte]].length)
+    checksum.getValue
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val CRC32 = "java.util.zip.CRC32"
+    val checksum = ctx.freshName("checksum")
+    nullSafeCodeGen(ctx, ev, value => {
+      s"""
+        $CRC32 $checksum = new $CRC32();
+        $checksum.update($value, 0, $value.length);
+        ${ev.value} = $checksum.getValue();
+      """
+    })
+  }
+}
+
+
+/**
+ * A function that calculates a hash value for a group of expressions. Note that the `seed`
+ * argument is not exposed to users and should only be set inside Spark SQL.
+ *
+ * The hash value for an expression depends on its type and seed:
+ *  - null:               seed
+ *  - boolean:            turn boolean into int, 1 for true, 0 for false, and then use murmur3 to
+ *                        hash this int with seed.
+ *  - byte, short, int:   use murmur3 to hash the input as int with seed.
+ *  - long:               use murmur3 to hash the long input with seed.
+ *  - float:              turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
+ *  - double:             turn it into long: java.lang.Double.doubleToLongBits(input), and hash it.
+ *  - decimal:            if it's a small decimal, i.e. precision <= 18, turn it into long and hash
+ *                        it. Else, turn it into bytes and hash it.
+ *  - calendar interval:  hash `microseconds` first, and use the result as seed to hash `months`.
+ *  - binary:             use murmur3 to hash the bytes with seed.
+ *  - string:             get the bytes of string and hash it.
+ *  - array:              The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each element, and assign the element hash value
+ *                        to `result`.
+ *  - map:                The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each key-value, and assign the key-value hash
+ *                        value to `result`.
+ *  - struct:             The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each field, and assign the field hash value to
+ *                        `result`.
+ *
+ * Finally we aggregate the hash values for each expression in the same way as for a struct.
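+ *
+ * As an illustrative sketch, hashing a two-column row `(10, "ab")` with murmur3 therefore
+ * computes `murmur3(bytesOf("ab"), murmur3(10, seed))`: each value's hash value becomes the
+ * seed used for the next one.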
+ */
+abstract class HashExpression[E] extends Expression {
+  /** Seed of the HashExpression. */
+  val seed: E
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = false
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def eval(input: InternalRow = null): Any = {
+    var hash = seed
+    var i = 0
+    val len = children.length
+    while (i < len) {
+      hash = computeHash(children(i).eval(input), children(i).dataType, hash)
+      i += 1
+    }
+    hash
+  }
+
+  protected def computeHash(value: Any, dataType: DataType, seed: E): E
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ev.isNull = "false"
+    val childrenHash = children.map { child =>
+      val childGen = child.genCode(ctx)
+      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
+        computeHash(childGen.value, child.dataType, ev.value, ctx)
+      }
+    }.mkString("\n")
+
+    ev.copy(code = s"""
+      ${ctx.javaType(dataType)} ${ev.value} = $seed;
+      $childrenHash""")
+  }
+
+  protected def nullSafeElementHash(
+      input: String,
+      index: String,
+      nullable: Boolean,
+      elementType: DataType,
+      result: String,
+      ctx: CodegenContext): String = {
+    val element = ctx.freshName("element")
+
+    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
+      s"""
+        final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+        ${computeHash(element, elementType, result, ctx)}
+      """
+    }
+  }
+
+  protected def genHashInt(i: String, result: String): String =
+    s"$result = $hasherClassName.hashInt($i, $result);"
+
+  protected def genHashLong(l: String, result: String): String =
+    s"$result = $hasherClassName.hashLong($l, $result);"
+
+  protected def genHashBytes(b: String, result: String): String = {
+    val offset = "Platform.BYTE_ARRAY_OFFSET"
+    s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);"
+  }
+
+  protected def genHashBoolean(input: String, result: String): String =
+    genHashInt(s"$input ? 1 : 0", result)
+
+  protected def genHashFloat(input: String, result: String): String =
+    genHashInt(s"Float.floatToIntBits($input)", result)
+
+  protected def genHashDouble(input: String, result: String): String =
+    genHashLong(s"Double.doubleToLongBits($input)", result)
+
+  protected def genHashDecimal(
+      ctx: CodegenContext,
+      d: DecimalType,
+      input: String,
+      result: String): String = {
+    if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+      genHashLong(s"$input.toUnscaledLong()", result)
+    } else {
+      val bytes = ctx.freshName("bytes")
+      s"""
+        final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+        ${genHashBytes(bytes, result)}
+      """
+    }
+  }
+
+  protected def genHashCalendarInterval(input: String, result: String): String = {
+    val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)"
+    s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);"
+  }
+
+  protected def genHashString(input: String, result: String): String = {
+    val baseObject = s"$input.getBaseObject()"
+    val baseOffset = s"$input.getBaseOffset()"
+    val numBytes = s"$input.numBytes()"
+    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
+  }
+
+  protected def genHashForMap(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      keyType: DataType,
+      valueType: DataType,
+      valueContainsNull: Boolean): String = {
+    val index = ctx.freshName("index")
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    s"""
+      final ArrayData $keys = $input.keyArray();
+      final ArrayData $values = $input.valueArray();
+      for (int $index = 0; $index < $input.numElements(); $index++) {
+        ${nullSafeElementHash(keys, index, false, keyType, result, ctx)}
+        ${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)}
+      }
+    """
+  }
+
+  protected def genHashForArray(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      elementType: DataType,
+      containsNull: Boolean): String = {
+    val index = ctx.freshName("index")
+    s"""
+      for (int $index = 0; $index < $input.numElements(); $index++) {
+        ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)}
+      }
+    """
+  }
+
+  protected def genHashForStruct(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      fields: Array[StructField]): String = {
+    fields.zipWithIndex.map { case (field, index) =>
+      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
+    }.mkString("\n")
+  }
+
+  @tailrec
+  private def computeHashWithTailRec(
+      input: String,
+      dataType: DataType,
+      result: String,
+      ctx: CodegenContext): String = dataType match {
+    case NullType => ""
+    case BooleanType => genHashBoolean(input, result)
+    case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
+    case LongType | TimestampType => genHashLong(input, result)
+    case FloatType => genHashFloat(input, result)
+    case DoubleType => genHashDouble(input, result)
+    case d: DecimalType => genHashDecimal(ctx, d, input, result)
+    case CalendarIntervalType => genHashCalendarInterval(input, result)
+    case BinaryType => genHashBytes(input, result)
+    case StringType => genHashString(input, result)
+    case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
+    case MapType(kt, vt, valueContainsNull) =>
+      genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
+    case StructType(fields) => genHashForStruct(ctx, input, result, fields)
+    case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx)
+  }
+
+  protected def computeHash(
+      input: String,
+      dataType: DataType,
+      result: String,
+      ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx)
+
+  protected def hasherClassName: String
+}
+
+/**
+ * Base class for interpreted hash functions.
+ */
+abstract class InterpretedHashFunction {
+  protected def hashInt(i: Int, seed: Long): Long
+
+  protected def hashLong(l: Long, seed: Long): Long
+
+  protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long
+
+  def hash(value: Any, dataType: DataType, seed: Long): Long = {
+    value match {
+      case null => seed
+      case b: Boolean => hashInt(if (b) 1 else 0, seed)
+      case b: Byte => hashInt(b, seed)
+      case s: Short => hashInt(s, seed)
+      case i: Int => hashInt(i, seed)
+      case l: Long => hashLong(l, seed)
+      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
+      case d: Decimal =>
+        val precision = dataType.asInstanceOf[DecimalType].precision
+        if (precision <= Decimal.MAX_LONG_DIGITS) {
+          hashLong(d.toUnscaledLong, seed)
+        } else {
+          val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
+          hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
+        }
+      case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed))
+      case a: Array[Byte] =>
+        hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
+      case s: UTF8String =>
+        hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+
+      case array: ArrayData =>
+        val elementType = dataType match {
+          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
+          case ArrayType(et, _) => et
+        }
+        var result = seed
+        var i = 0
+        while (i < array.numElements()) {
+          result = hash(array.get(i, elementType), elementType, result)
+          i += 1
+        }
+        result
+
+      case map: MapData =>
+        val (kt, vt) = dataType match {
+          case udt: UserDefinedType[_] =>
+            val mapType = udt.sqlType.asInstanceOf[MapType]
+            mapType.keyType -> mapType.valueType
+          case MapType(kt, vt, _) => kt -> vt
+        }
+        val keys = map.keyArray()
+        val values = map.valueArray()
+        var result = seed
+        var i = 0
+        while (i < map.numElements()) {
+          result = hash(keys.get(i, kt), kt, result)
+          result = hash(values.get(i, vt), vt, result)
+          i += 1
+        }
+        result
+
+      case struct: InternalRow =>
+        val types: Array[DataType] = dataType match {
+          case udt: UserDefinedType[_] =>
+            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
+          case StructType(fields) => fields.map(_.dataType)
+        }
+        var result = seed
+        var i = 0
+        val len = struct.numFields
+        while (i < len) {
+          result = hash(struct.get(i, types(i)), types(i), result)
+          i += 1
+        }
+        result
+    }
+  }
+}
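+
+// As an illustrative sketch of the interpreted path above, hashing a two-element int array
+// unrolls to hashInt(a(1), hashInt(a(0), seed)): each element's hash is threaded through as
+// the seed for the next element, and the final result is the hash of the whole array.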
+
+/**
+ * A MurMur3 Hash expression.
+ *
+ * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
+ * and bucketing have the same data distribution.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_('Spark', array(123), 2);
+       -1321691492
+  """)
+case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
+  def this(arguments: Seq[Expression]) = this(arguments, 42)
+
+  override def dataType: DataType = IntegerType
+
+  override def prettyName: String = "hash"
+
+  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+    Murmur3HashFunction.hash(value, dataType, seed).toInt
+  }
+}
+
+object Murmur3HashFunction extends InterpretedHashFunction {
+  override protected def hashInt(i: Int, seed: Long): Long = {
+    Murmur3_x86_32.hashInt(i, seed.toInt)
+  }
+
+  override protected def hashLong(l: Long, seed: Long): Long = {
+    Murmur3_x86_32.hashLong(l, seed.toInt)
+  }
+
+  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+    Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
+  }
+}
+
+/**
+ * A xxHash64 64-bit hash expression.
+ */
+case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] {
+  def this(arguments: Seq[Expression]) = this(arguments, 42L)
+
+  override def dataType: DataType = LongType
+
+  override def prettyName: String = "xxHash"
+
+  override protected def hasherClassName: String = classOf[XXH64].getName
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
+    XxHash64Function.hash(value, dataType, seed)
+  }
+}
+
+object XxHash64Function extends InterpretedHashFunction {
+  override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed)
+
+  override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)
+
+  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+    XXH64.hashUnsafeBytes(base, offset, len, seed)
+  }
+}
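+
+// Note on seeds: both auxiliary constructors above default to seed 42, so e.g.
+// new Murmur3Hash(exprs) should behave the same as Murmur3Hash(exprs, 42), which is
+// what backs the SQL `hash(...)` function.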
+
+
+/**
+ * Simulates Hive's hashing function at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive.
+ *
+ * We should use this hash function for both shuffle and bucket of Hive tables, so that
+ * we can guarantee shuffle and bucketing have the same data distribution.
+ *
+ * TODO: Support Decimal and date related types
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
+case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
+  override val seed = 0
+
+  override def dataType: DataType = IntegerType
+
+  override def prettyName: String = "hive-hash"
+
+  override protected def hasherClassName: String = classOf[HiveHasher].getName
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+    HiveHashFunction.hash(value, dataType, seed).toInt
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ev.isNull = "false"
+    val childHash = ctx.freshName("childHash")
+    val childrenHash = children.map { child =>
+      val childGen = child.genCode(ctx)
+      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
+        computeHash(childGen.value, child.dataType, childHash, ctx)
+      } + s"${ev.value} = (31 * ${ev.value}) + $childHash;"
+    }.mkString(s"int $childHash = 0;", s"\n$childHash = 0;\n", "")
+
+    ev.copy(code = s"""
+      ${ctx.javaType(dataType)} ${ev.value} = $seed;
+      $childrenHash""")
+  }
+
+  override def eval(input: InternalRow = null): Int = {
+    var hash = seed
+    var i = 0
+    val len = children.length
+    while (i < len) {
+      hash = (31 * hash) + computeHash(children(i).eval(input), children(i).dataType, hash)
+      i += 1
+    }
+    hash
+  }
+
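+  // An illustrative unrolling: with two children c1 and c2 and seed 0, eval computes
+  // (31 * (31 * 0 + h(c1))) + h(c2), mirroring the 31-based combination used by Hive's
+  // ObjectInspectorUtils#hashcode().
+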
+  override protected def genHashInt(i: String, result: String): String =
+    s"$result = $hasherClassName.hashInt($i);"
+
+  override protected def genHashLong(l: String, result: String): String =
+    s"$result = $hasherClassName.hashLong($l);"
+
+  override protected def genHashBytes(b: String, result: String): String =
+    s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
+
+  override protected def genHashCalendarInterval(input: String, result: String): String = {
+    s"""
+      $result = (31 * $hasherClassName.hashInt($input.months)) +
+        $hasherClassName.hashLong($input.microseconds);
+    """
+  }
+
+  override protected def genHashString(input: String, result: String): String = {
+    val baseObject = s"$input.getBaseObject()"
+    val baseOffset = s"$input.getBaseOffset()"
+    val numBytes = s"$input.numBytes()"
+    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
+  }
+
+  override protected def genHashForArray(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      elementType: DataType,
+      containsNull: Boolean): String = {
+    val index = ctx.freshName("index")
+    val childResult = ctx.freshName("childResult")
+    s"""
+      int $childResult = 0;
+      for (int $index = 0; $index < $input.numElements(); $index++) {
+        $childResult = 0;
+        ${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)};
+        $result = (31 * $result) + $childResult;
+      }
+    """
+  }
+
+  override protected def genHashForMap(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      keyType: DataType,
+      valueType: DataType,
+      valueContainsNull: Boolean): String = {
+    val index = ctx.freshName("index")
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val keyResult = ctx.freshName("keyResult")
+    val valueResult = ctx.freshName("valueResult")
+    s"""
+      final ArrayData $keys = $input.keyArray();
+      final ArrayData $values = $input.valueArray();
+      int $keyResult = 0;
+      int $valueResult = 0;
+      for (int $index = 0; $index < $input.numElements(); $index++) {
+        $keyResult = 0;
+        ${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)}
+        $valueResult = 0;
+        ${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)}
+        $result += $keyResult ^ $valueResult;
+      }
+    """
+  }
+
+  override protected def genHashForStruct(
+      ctx: CodegenContext,
+      input: String,
+      result: String,
+      fields: Array[StructField]): String = {
+    val localResult = ctx.freshName("localResult")
+    val childResult = ctx.freshName("childResult")
+    fields.zipWithIndex.map { case (field, index) =>
+      s"""
+        $childResult = 0;
+        ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
+          childResult, ctx)}
+        $localResult = (31 * $localResult) + $childResult;
+      """
+    }.mkString(
+      s"""
+        int $localResult = 0;
+        int $childResult = 0;
+      """,
+      "",
+      s"$result = (31 * $result) + $localResult;"
+    )
+  }
+}
+
+object HiveHashFunction extends InterpretedHashFunction {
+  override protected def hashInt(i: Int, seed: Long): Long = {
+    HiveHasher.hashInt(i)
+  }
+
+  override protected def hashLong(l: Long, seed: Long): Long = {
+    HiveHasher.hashLong(l)
+  }
+
+  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+    HiveHasher.hashUnsafeBytes(base, offset, len)
+  }
+
+  override def hash(value: Any, dataType: DataType, seed: Long): Long = {
+    value match {
+      case null => 0
+      case array: ArrayData =>
+        val elementType = dataType match {
+          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
+          case ArrayType(et, _) => et
+        }
+
+        var result = 0
+        var i = 0
+        val length = array.numElements()
+        while (i < length) {
+          result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt
+          i += 1
+        }
+        result
+
+      case map: MapData =>
+        val (kt, vt) = dataType match {
+          case udt: UserDefinedType[_] =>
+            val mapType = udt.sqlType.asInstanceOf[MapType]
+            mapType.keyType -> mapType.valueType
+          case MapType(_kt, _vt, _) => _kt -> _vt
+        }
+        val keys = map.keyArray()
+        val values = map.valueArray()
+
+        var result = 0
+        var i = 0
+        val length = map.numElements()
+        while (i < length) {
+          result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt
+          i += 1
+        }
+        result
+
+      case struct: InternalRow =>
+        val types: Array[DataType] = dataType match {
+          case udt: UserDefinedType[_] =>
+            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
+          case StructType(fields) => fields.map(_.dataType)
+        }
+
+        var result = 0
+        var i = 0
+        val length = struct.numFields
+        while (i < length) {
+          result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+          i += 1
+        }
+        result
+
+      case _ => super.hash(value, dataType, seed)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 2ce10ef13215e8977c5b39ed002a0b5f823baa55..a874a1cf37086e99566135fb9c3f79e503188cb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,529 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.security.{MessageDigest, NoSuchAlgorithmException}
-import java.util.zip.CRC32
-
-import scala.annotation.tailrec
-
-import org.apache.commons.codec.digest.DigestUtils
-
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.hash.Murmur3_x86_32
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import org.apache.spark.unsafe.Platform
-
-/**
- * A function that calculates an MD5 128-bit checksum and returns it as a hex string
- * For input of type [[BinaryType]]
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns an MD5 128-bit checksum as a hex string of `expr`.",
-  extended = """
-    Examples:
-      > SELECT _FUNC_('Spark');
-       8cde774d6f7333752ed72cacddb05126
-  """)
-case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
-
-  override def dataType: DataType = StringType
-
-  override def inputTypes: Seq[DataType] = Seq(BinaryType)
-
-  protected override def nullSafeEval(input: Any): Any =
-    UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]]))
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c =>
-      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
-  }
-}
-
-/**
- * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
- * and returns it as a hex string. The first argument is the string or binary to be hashed. The
- * second argument indicates the desired bit length of the result, which must have a value of 224,
- * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If
- * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or
- * the hash length is not one of the permitted values, the return value is NULL.
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = """
-    _FUNC_(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of `expr`.
-    SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.
-  """,
-  extended = """
-    Examples:
-      > SELECT _FUNC_('Spark', 256);
-       529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
-  """)
-// scalastyle:on line.size.limit
-case class Sha2(left: Expression, right: Expression)
-  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
-
-  override def dataType: DataType = StringType
-  override def nullable: Boolean = true
-
-  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
-
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    val bitLength = input2.asInstanceOf[Int]
-    val input = input1.asInstanceOf[Array[Byte]]
-    bitLength match {
-      case 224 =>
-        // DigestUtils doesn't support SHA-224 now
-        try {
-          val md = MessageDigest.getInstance("SHA-224")
-          md.update(input)
-          UTF8String.fromBytes(md.digest())
-        } catch {
-          // SHA-224 is not supported on the system, return null
-          case noa: NoSuchAlgorithmException => null
-        }
-      case 256 | 0 =>
-        UTF8String.fromString(DigestUtils.sha256Hex(input))
-      case 384 =>
-        UTF8String.fromString(DigestUtils.sha384Hex(input))
-      case 512 =>
-        UTF8String.fromString(DigestUtils.sha512Hex(input))
-      case _ => null
-    }
-  }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
-    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
-      s"""
-        if ($eval2 == 224) {
-          try {
-            java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
-            md.update($eval1);
-            ${ev.value} = UTF8String.fromBytes(md.digest());
-          } catch (java.security.NoSuchAlgorithmException e) {
-            ${ev.isNull} = true;
-          }
-        } else if ($eval2 == 256 || $eval2 == 0) {
-          ${ev.value} =
-            UTF8String.fromString($digestUtils.sha256Hex($eval1));
-        } else if ($eval2 == 384) {
-          ${ev.value} =
-            UTF8String.fromString($digestUtils.sha384Hex($eval1));
-        } else if ($eval2 == 512) {
-          ${ev.value} =
-            UTF8String.fromString($digestUtils.sha512Hex($eval1));
-        } else {
-          ${ev.isNull} = true;
-        }
-      """
-    })
-  }
-}
-
-/**
- * A function that calculates a sha1 hash value and returns it as a hex string
- * For input of type [[BinaryType]] or [[StringType]]
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns a sha1 hash value as a hex string of the `expr`.",
-  extended = """
-    Examples:
-      > SELECT _FUNC_('Spark');
-       85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
-  """)
-case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
-
-  override def dataType: DataType = StringType
-
-  override def inputTypes: Seq[DataType] = Seq(BinaryType)
-
-  protected override def nullSafeEval(input: Any): Any =
-    UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]]))
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c =>
-      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))"
-    )
-  }
-}
-
-/**
- * A function that computes a cyclic redundancy check value and returns it as a bigint
- * For input of type [[BinaryType]]
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns a cyclic redundancy check value of the `expr` as a bigint.",
-  extended = """
-    Examples:
-      > SELECT _FUNC_('Spark');
-       1557323817
-  """)
-case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
-
-  override def dataType: DataType = LongType
-
-  override def inputTypes: Seq[DataType] = Seq(BinaryType)
-
-  protected override def nullSafeEval(input: Any): Any = {
-    val checksum = new CRC32
-    checksum.update(input.asInstanceOf[Array[Byte]], 0, input.asInstanceOf[Array[Byte]].length)
-    checksum.getValue
-  }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val CRC32 = "java.util.zip.CRC32"
-    val checksum = ctx.freshName("checksum")
-    nullSafeCodeGen(ctx, ev, value => {
-      s"""
-        $CRC32 $checksum = new $CRC32();
-        $checksum.update($value, 0, $value.length);
-        ${ev.value} = $checksum.getValue();
-      """
-    })
-  }
-}
-
-
-/**
- * A function that calculates hash value for a group of expressions. Note that the `seed` argument
- * is not exposed to users and should only be set inside spark SQL.
- *
- * The hash value for an expression depends on its type and seed:
- *  - null:               seed
- *  - boolean:            turn boolean into int, 1 for true, 0 for false, and then use murmur3 to
- *                        hash this int with seed.
- *  - byte, short, int:   use murmur3 to hash the input as int with seed.
- *  - long:               use murmur3 to hash the long input with seed.
- *  - float:              turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
- *  - double:             turn it into long: java.lang.Double.doubleToLongBits(input), and hash it.
- *  - decimal:            if it's a small decimal, i.e. precision <= 18, turn it into long and hash
- *                        it. Else, turn it into bytes and hash it.
- *  - calendar interval:  hash `microseconds` first, and use the result as seed to hash `months`.
- *  - binary:             use murmur3 to hash the bytes with seed.
- *  - string:             get the bytes of string and hash it.
- *  - array:              The `result` starts with seed, then use `result` as seed, recursively
- *                        calculate hash value for each element, and assign the element hash value
- *                        to `result`.
- *  - map:                The `result` starts with seed, then use `result` as seed, recursively
- *                        calculate hash value for each key-value, and assign the key-value hash
- *                        value to `result`.
- *  - struct:             The `result` starts with seed, then use `result` as seed, recursively
- *                        calculate hash value for each field, and assign the field hash value to
- *                        `result`.
- *
- * Finally we aggregate the hash values for each expression by the same way of struct.
- */
-abstract class HashExpression[E] extends Expression {
-  /** Seed of the HashExpression. */
-  val seed: E
-
-  override def foldable: Boolean = children.forall(_.foldable)
-
-  override def nullable: Boolean = false
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.isEmpty) {
-      TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    var hash = seed
-    var i = 0
-    val len = children.length
-    while (i < len) {
-      hash = computeHash(children(i).eval(input), children(i).dataType, hash)
-      i += 1
-    }
-    hash
-  }
-
-  protected def computeHash(value: Any, dataType: DataType, seed: E): E
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.isNull = "false"
-    val childrenHash = children.map { child =>
-      val childGen = child.genCode(ctx)
-      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
-        computeHash(childGen.value, child.dataType, ev.value, ctx)
-      }
-    }.mkString("\n")
-
-    ev.copy(code = s"""
-      ${ctx.javaType(dataType)} ${ev.value} = $seed;
-      $childrenHash""")
-  }
-
-  protected def nullSafeElementHash(
-      input: String,
-      index: String,
-      nullable: Boolean,
-      elementType: DataType,
-      result: String,
-      ctx: CodegenContext): String = {
-    val element = ctx.freshName("element")
-
-    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
-      s"""
-        final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
-        ${computeHash(element, elementType, result, ctx)}
-      """
-    }
-  }
-
-  protected def genHashInt(i: String, result: String): String =
-    s"$result = $hasherClassName.hashInt($i, $result);"
-
-  protected def genHashLong(l: String, result: String): String =
-    s"$result = $hasherClassName.hashLong($l, $result);"
-
-  protected def genHashBytes(b: String, result: String): String = {
-    val offset = "Platform.BYTE_ARRAY_OFFSET"
-    s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);"
-  }
-
-  protected def genHashBoolean(input: String, result: String): String =
-    genHashInt(s"$input ? 1 : 0", result)
-
-  protected def genHashFloat(input: String, result: String): String =
-    genHashInt(s"Float.floatToIntBits($input)", result)
-
-  protected def genHashDouble(input: String, result: String): String =
-    genHashLong(s"Double.doubleToLongBits($input)", result)
-
-  protected def genHashDecimal(
-      ctx: CodegenContext,
-      d: DecimalType,
-      input: String,
-      result: String): String = {
-    if (d.precision <= Decimal.MAX_LONG_DIGITS) {
-      genHashLong(s"$input.toUnscaledLong()", result)
-    } else {
-      val bytes = ctx.freshName("bytes")
-      s"""
-        final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
-        ${genHashBytes(bytes, result)}
-      """
-    }
-  }
-
-  protected def genHashCalendarInterval(input: String, result: String): String = {
-    val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)"
-    s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);"
-  }
-
-  protected def genHashString(input: String, result: String): String = {
-    val baseObject = s"$input.getBaseObject()"
-    val baseOffset = s"$input.getBaseOffset()"
-    val numBytes = s"$input.numBytes()"
-    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
-  }
-
-  protected def genHashForMap(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      keyType: DataType,
-      valueType: DataType,
-      valueContainsNull: Boolean): String = {
-    val index = ctx.freshName("index")
-    val keys = ctx.freshName("keys")
-    val values = ctx.freshName("values")
-    s"""
-      final ArrayData $keys = $input.keyArray();
-      final ArrayData $values = $input.valueArray();
-      for (int $index = 0; $index < $input.numElements(); $index++) {
-        ${nullSafeElementHash(keys, index, false, keyType, result, ctx)}
-        ${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)}
-      }
-    """
-  }
-
-  protected def genHashForArray(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      elementType: DataType,
-      containsNull: Boolean): String = {
-    val index = ctx.freshName("index")
-    s"""
-      for (int $index = 0; $index < $input.numElements(); $index++) {
-        ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)}
-      }
-    """
-  }
-
-  protected def genHashForStruct(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      fields: Array[StructField]): String = {
-    fields.zipWithIndex.map { case (field, index) =>
-      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
-    }.mkString("\n")
-  }
-
-  @tailrec
-  private def computeHashWithTailRec(
-      input: String,
-      dataType: DataType,
-      result: String,
-      ctx: CodegenContext): String = dataType match {
-    case NullType => ""
-    case BooleanType => genHashBoolean(input, result)
-    case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
-    case LongType | TimestampType => genHashLong(input, result)
-    case FloatType => genHashFloat(input, result)
-    case DoubleType => genHashDouble(input, result)
-    case d: DecimalType => genHashDecimal(ctx, d, input, result)
-    case CalendarIntervalType => genHashCalendarInterval(input, result)
-    case BinaryType => genHashBytes(input, result)
-    case StringType => genHashString(input, result)
-    case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
-    case MapType(kt, vt, valueContainsNull) =>
-      genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
-    case StructType(fields) => genHashForStruct(ctx, input, result, fields)
-    case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx)
-  }
-
-  protected def computeHash(
-      input: String,
-      dataType: DataType,
-      result: String,
-      ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx)
-
-  protected def hasherClassName: String
-}
-
-/**
- * Base class for interpreted hash functions.
- */
-abstract class InterpretedHashFunction {
-  protected def hashInt(i: Int, seed: Long): Long
-
-  protected def hashLong(l: Long, seed: Long): Long
-
-  protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long
-
-  def hash(value: Any, dataType: DataType, seed: Long): Long = {
-    value match {
-      case null => seed
-      case b: Boolean => hashInt(if (b) 1 else 0, seed)
-      case b: Byte => hashInt(b, seed)
-      case s: Short => hashInt(s, seed)
-      case i: Int => hashInt(i, seed)
-      case l: Long => hashLong(l, seed)
-      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
-      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
-      case d: Decimal =>
-        val precision = dataType.asInstanceOf[DecimalType].precision
-        if (precision <= Decimal.MAX_LONG_DIGITS) {
-          hashLong(d.toUnscaledLong, seed)
-        } else {
-          val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
-          hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
-        }
-      case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed))
-      case a: Array[Byte] =>
-        hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
-      case s: UTF8String =>
-        hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
-
-      case array: ArrayData =>
-        val elementType = dataType match {
-          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
-          case ArrayType(et, _) => et
-        }
-        var result = seed
-        var i = 0
-        while (i < array.numElements()) {
-          result = hash(array.get(i, elementType), elementType, result)
-          i += 1
-        }
-        result
-
-      case map: MapData =>
-        val (kt, vt) = dataType match {
-          case udt: UserDefinedType[_] =>
-            val mapType = udt.sqlType.asInstanceOf[MapType]
-            mapType.keyType -> mapType.valueType
-          case MapType(kt, vt, _) => kt -> vt
-        }
-        val keys = map.keyArray()
-        val values = map.valueArray()
-        var result = seed
-        var i = 0
-        while (i < map.numElements()) {
-          result = hash(keys.get(i, kt), kt, result)
-          result = hash(values.get(i, vt), vt, result)
-          i += 1
-        }
-        result
-
-      case struct: InternalRow =>
-        val types: Array[DataType] = dataType match {
-          case udt: UserDefinedType[_] =>
-            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
-          case StructType(fields) => fields.map(_.dataType)
-        }
-        var result = seed
-        var i = 0
-        val len = struct.numFields
-        while (i < len) {
-          result = hash(struct.get(i, types(i)), types(i), result)
-          i += 1
-        }
-        result
-    }
-  }
-}
-
-/**
- * A MurMur3 Hash expression.
- *
- * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
- * and bucketing have same data distribution.
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.",
-  extended = """
-    Examples:
-      > SELECT _FUNC_('Spark', array(123), 2);
-       -1321691492
-  """)
-case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
-  def this(arguments: Seq[Expression]) = this(arguments, 42)
-
-  override def dataType: DataType = IntegerType
-
-  override def prettyName: String = "hash"
-
-  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
-
-  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    Murmur3HashFunction.hash(value, dataType, seed).toInt
-  }
-}
-
-object Murmur3HashFunction extends InterpretedHashFunction {
-  override protected def hashInt(i: Int, seed: Long): Long = {
-    Murmur3_x86_32.hashInt(i, seed.toInt)
-  }
-
-  override protected def hashLong(l: Long, seed: Long): Long = {
-    Murmur3_x86_32.hashLong(l, seed.toInt)
-  }
-
-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
-    Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
-  }
-}
 
 /**
  * Print the result of an expression to stderr (used for debugging codegen).
@@ -608,33 +88,6 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
   override def sql: String = s"assert_true(${child.sql})"
 }
 
-/**
- * A xxHash64 64-bit hash expression.
- */
-case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] {
-  def this(arguments: Seq[Expression]) = this(arguments, 42L)
-
-  override def dataType: DataType = LongType
-
-  override def prettyName: String = "xxHash"
-
-  override protected def hasherClassName: String = classOf[XXH64].getName
-
-  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
-    XxHash64Function.hash(value, dataType, seed)
-  }
-}
-
-object XxHash64Function extends InterpretedHashFunction {
-  override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed)
-
-  override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)
-
-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
-    XXH64.hashUnsafeBytes(base, offset, len, seed)
-  }
-}
-
 /**
  * Returns the current database of the SessionCatalog.
  */
@@ -651,217 +104,3 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
   override def nullable: Boolean = false
   override def prettyName: String = "current_database"
 }
-
-/**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
- *
- * We should use this hash function for both shuffle and bucket of Hive tables, so that
- * we can guarantee shuffle and bucketing have same data distribution
- *
- * TODO: Support Decimal and date related types
- */
-@ExpressionDescription(
-  usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
-case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
-  override val seed = 0
-
-  override def dataType: DataType = IntegerType
-
-  override def prettyName: String = "hive-hash"
-
-  override protected def hasherClassName: String = classOf[HiveHasher].getName
-
-  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, seed).toInt
-  }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.isNull = "false"
-    val childHash = ctx.freshName("childHash")
-    val childrenHash = children.map { child =>
-      val childGen = child.genCode(ctx)
-      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
-        computeHash(childGen.value, child.dataType, childHash, ctx)
-      } + s"${ev.value} = (31 * ${ev.value}) + $childHash;"
-    }.mkString(s"int $childHash = 0;", s"\n$childHash = 0;\n", "")
-
-    ev.copy(code = s"""
-      ${ctx.javaType(dataType)} ${ev.value} = $seed;
-      $childrenHash""")
-  }
-
-  override def eval(input: InternalRow): Int = {
-    var hash = seed
-    var i = 0
-    val len = children.length
-    while (i < len) {
-      hash = (31 * hash) + computeHash(children(i).eval(input), children(i).dataType, hash)
-      i += 1
-    }
-    hash
-  }
-
-  override protected def genHashInt(i: String, result: String): String =
-    s"$result = $hasherClassName.hashInt($i);"
-
-  override protected def genHashLong(l: String, result: String): String =
-    s"$result = $hasherClassName.hashLong($l);"
-
-  override protected def genHashBytes(b: String, result: String): String =
-    s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
-
-  override protected def genHashCalendarInterval(input: String, result: String): String = {
-    s"""
-      $result = (31 * $hasherClassName.hashInt($input.months)) +
-        $hasherClassName.hashLong($input.microseconds);"
-    """
-  }
-
-  override protected def genHashString(input: String, result: String): String = {
-    val baseObject = s"$input.getBaseObject()"
-    val baseOffset = s"$input.getBaseOffset()"
-    val numBytes = s"$input.numBytes()"
-    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
-  }
-
-  override protected def genHashForArray(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      elementType: DataType,
-      containsNull: Boolean): String = {
-    val index = ctx.freshName("index")
-    val childResult = ctx.freshName("childResult")
-    s"""
-      int $childResult = 0;
-      for (int $index = 0; $index < $input.numElements(); $index++) {
-        $childResult = 0;
-        ${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)};
-        $result = (31 * $result) + $childResult;
-      }
-    """
-  }
-
-  override protected def genHashForMap(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      keyType: DataType,
-      valueType: DataType,
-      valueContainsNull: Boolean): String = {
-    val index = ctx.freshName("index")
-    val keys = ctx.freshName("keys")
-    val values = ctx.freshName("values")
-    val keyResult = ctx.freshName("keyResult")
-    val valueResult = ctx.freshName("valueResult")
-    s"""
-      final ArrayData $keys = $input.keyArray();
-      final ArrayData $values = $input.valueArray();
-      int $keyResult = 0;
-      int $valueResult = 0;
-      for (int $index = 0; $index < $input.numElements(); $index++) {
-        $keyResult = 0;
-        ${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)}
-        $valueResult = 0;
-        ${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)}
-        $result += $keyResult ^ $valueResult;
-      }
-    """
-  }
-
-  override protected def genHashForStruct(
-      ctx: CodegenContext,
-      input: String,
-      result: String,
-      fields: Array[StructField]): String = {
-    val localResult = ctx.freshName("localResult")
-    val childResult = ctx.freshName("childResult")
-    fields.zipWithIndex.map { case (field, index) =>
-      s"""
-        $childResult = 0;
-        ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
-          childResult, ctx)}
-        $localResult = (31 * $localResult) + $childResult;
-      """
-    }.mkString(
-      s"""
-        int $localResult = 0;
-        int $childResult = 0;
-      """,
-      "",
-      s"$result = (31 * $result) + $localResult;"
-    )
-  }
-}
-
-object HiveHashFunction extends InterpretedHashFunction {
-  override protected def hashInt(i: Int, seed: Long): Long = {
-    HiveHasher.hashInt(i)
-  }
-
-  override protected def hashLong(l: Long, seed: Long): Long = {
-    HiveHasher.hashLong(l)
-  }
-
-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
-    HiveHasher.hashUnsafeBytes(base, offset, len)
-  }
-
-  override def hash(value: Any, dataType: DataType, seed: Long): Long = {
-    value match {
-      case null => 0
-      case array: ArrayData =>
-        val elementType = dataType match {
-          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
-          case ArrayType(et, _) => et
-        }
-
-        var result = 0
-        var i = 0
-        val length = array.numElements()
-        while (i < length) {
-          result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt
-          i += 1
-        }
-        result
-
-      case map: MapData =>
-        val (kt, vt) = dataType match {
-          case udt: UserDefinedType[_] =>
-            val mapType = udt.sqlType.asInstanceOf[MapType]
-            mapType.keyType -> mapType.valueType
-          case MapType(_kt, _vt, _) => _kt -> _vt
-        }
-        val keys = map.keyArray()
-        val values = map.valueArray()
-
-        var result = 0
-        var i = 0
-        val length = map.numElements()
-        while (i < length) {
-          result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt
-          i += 1
-        }
-        result
-
-      case struct: InternalRow =>
-        val types: Array[DataType] = dataType match {
-          case udt: UserDefinedType[_] =>
-            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
-          case StructType(fields) => fields.map(_.dataType)
-        }
-
-        var result = 0
-        var i = 0
-        val length = struct.numFields
-        while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
-          i += 1
-        }
-        result
-
-      case _ => super.hash(value, dataType, seed)
-    }
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c714bc03dc0d5b8e149db3697bdb49b4f28dada2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.nio.charset.StandardCharsets
+
+import org.apache.commons.codec.digest.DigestUtils
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.types._
+
+class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("md5") {
+    checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
+      "902fbdd2b1df0c4f70b4a5d23525e932")
+    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+      "6ac1e56bc78f031059be7be854522c4c")
+    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
+    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
+  }
+
+  test("sha1") {
+    checkEvaluation(Sha1(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
+      "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
+    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
+    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
+    checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))),
+      "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
+  }
+
+  test("sha2") {
+    checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)),
+      DigestUtils.sha256Hex("ABC"))
+    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
+      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
+    // unsupported bit length
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
+    checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)),
+      Literal.create(null, IntegerType)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
+  }
+
+  test("crc32") {
+    checkEvaluation(Crc32(Literal("ABC".getBytes(StandardCharsets.UTF_8))), 2743272264L)
+    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+      2180413220L)
+    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
+    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
+  }
+
+  private val structOfString = new StructType().add("str", StringType)
+  private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
+  private val arrayOfString = ArrayType(StringType)
+  private val arrayOfNull = ArrayType(NullType)
+  private val mapOfString = MapType(StringType, StringType)
+  private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
+
+  testHash(
+    new StructType()
+      .add("null", NullType)
+      .add("boolean", BooleanType)
+      .add("byte", ByteType)
+      .add("short", ShortType)
+      .add("int", IntegerType)
+      .add("long", LongType)
+      .add("float", FloatType)
+      .add("double", DoubleType)
+      .add("bigDecimal", DecimalType.SYSTEM_DEFAULT)
+      .add("smallDecimal", DecimalType.USER_DEFAULT)
StringType)
+      .add("binary", BinaryType)
+      .add("date", DateType)
+      .add("timestamp", TimestampType)
+      .add("udt", new ExamplePointUDT))
+
+  testHash(
+    new StructType()
+      .add("arrayOfNull", arrayOfNull)
+      .add("arrayOfString", arrayOfString)
+      .add("arrayOfArrayOfString", ArrayType(arrayOfString))
+      .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
+      .add("arrayOfMap", ArrayType(mapOfString))
+      .add("arrayOfStruct", ArrayType(structOfString))
+      .add("arrayOfUDT", arrayOfUDT))
+
+  testHash(
+    new StructType()
+      .add("mapOfIntAndString", MapType(IntegerType, StringType))
+      .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
+      .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
+      .add("mapOfArray", MapType(arrayOfString, arrayOfString))
+      .add("mapOfStringAndStruct", MapType(StringType, structOfString))
+      .add("mapOfStructAndString", MapType(structOfString, StringType))
+      .add("mapOfStruct", MapType(structOfString, structOfString)))
+
+  testHash(
+    new StructType()
+      .add("structOfString", structOfString)
+      .add("structOfStructOfString", new StructType().add("struct", structOfString))
+      .add("structOfArray", new StructType().add("array", arrayOfString))
+      .add("structOfMap", new StructType().add("map", mapOfString))
+      .add("structOfArrayAndMap",
+        new StructType().add("array", arrayOfString).add("map", mapOfString))
+      .add("structOfUDT", structOfUDT))
+
+  private def testHash(inputSchema: StructType): Unit = {
+    val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
+    val encoder = RowEncoder(inputSchema)
+    val seed = scala.util.Random.nextInt()
+    test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") {
+      for (_ <- 1 to 10) {
+        val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
+        val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
+          case (value, dt) => Literal.create(value, dt)
+        }
+        // Only verify that the interpreted version produces the same result as
+        // the codegen version.
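+        // `eval()` with no input row yields the interpreted result, and
+        // checkEvaluation also exercises the generated-code path, so any
+        // divergence between the two implementations fails the test.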
+ checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) + checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) + checkEvaluation(HiveHash(literals), HiveHash(literals).eval()) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 13ce5884620287cba30dc2c7fb28ede7979fd740..ed82efe7be2e8e574b08950f7e70247ef538627a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -17,58 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import java.nio.charset.StandardCharsets - -import org.apache.commons.codec.digest.DigestUtils - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("md5") { - checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))), - "902fbdd2b1df0c4f70b4a5d23525e932") - checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), - "6ac1e56bc78f031059be7be854522c4c") - checkEvaluation(Md5(Literal.create(null, BinaryType)), null) - checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) - } - - test("sha1") { - checkEvaluation(Sha1(Literal("ABC".getBytes(StandardCharsets.UTF_8))), - "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") - checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), - "5d211bad8f4ee70e16c7d343a838fc344a1ed961") - checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) - checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))), - "da39a3ee5e6b4b0d3255bfef95601890afd80709") - checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) - } - - test("sha2") { - checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)), - DigestUtils.sha256Hex("ABC")) - checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), - DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) - // unsupported bit length - checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) - checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) - checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), - Literal.create(null, IntegerType)), null) - checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) - } - - test("crc32") { - checkEvaluation(Crc32(Literal("ABC".getBytes(StandardCharsets.UTF_8))), 2743272264L) - checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), - 2180413220L) - checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) - checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) - } - test("assert_true") { intercept[RuntimeException] { checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) @@ -86,76 +39,4 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) } - private val structOfString = new StructType().add("str", StringType) - private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, 
false) - private val arrayOfString = ArrayType(StringType) - private val arrayOfNull = ArrayType(NullType) - private val mapOfString = MapType(StringType, StringType) - private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) - - testHash( - new StructType() - .add("null", NullType) - .add("boolean", BooleanType) - .add("byte", ByteType) - .add("short", ShortType) - .add("int", IntegerType) - .add("long", LongType) - .add("float", FloatType) - .add("double", DoubleType) - .add("bigDecimal", DecimalType.SYSTEM_DEFAULT) - .add("smallDecimal", DecimalType.USER_DEFAULT) - .add("string", StringType) - .add("binary", BinaryType) - .add("date", DateType) - .add("timestamp", TimestampType) - .add("udt", new ExamplePointUDT)) - - testHash( - new StructType() - .add("arrayOfNull", arrayOfNull) - .add("arrayOfString", arrayOfString) - .add("arrayOfArrayOfString", ArrayType(arrayOfString)) - .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) - .add("arrayOfMap", ArrayType(mapOfString)) - .add("arrayOfStruct", ArrayType(structOfString)) - .add("arrayOfUDT", arrayOfUDT)) - - testHash( - new StructType() - .add("mapOfIntAndString", MapType(IntegerType, StringType)) - .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) - .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) - .add("mapOfArray", MapType(arrayOfString, arrayOfString)) - .add("mapOfStringAndStruct", MapType(StringType, structOfString)) - .add("mapOfStructAndString", MapType(structOfString, StringType)) - .add("mapOfStruct", MapType(structOfString, structOfString))) - - testHash( - new StructType() - .add("structOfString", structOfString) - .add("structOfStructOfString", new StructType().add("struct", structOfString)) - .add("structOfArray", new StructType().add("array", arrayOfString)) - .add("structOfMap", new StructType().add("map", mapOfString)) - .add("structOfArrayAndMap", - new StructType().add("array", arrayOfString).add("map", mapOfString)) - .add("structOfUDT", structOfUDT)) - - private def testHash(inputSchema: StructType): Unit = { - val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get - val encoder = RowEncoder(inputSchema) - val seed = scala.util.Random.nextInt() - test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") { - for (_ <- 1 to 10) { - val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] - val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { - case (value, dt) => Literal.create(value, dt) - } - // Only test the interpreted version has same result with codegen version. - checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) - checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) - checkEvaluation(HiveHash(literals), HiveHash(literals).eval()) - } - } - } }
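
Note (not part of the patch): a minimal, self-contained sketch of the container hashing that HiveHashFunction implements above, with hypothetical pre-computed element hashes standing in for the real per-type hashing. Arrays and structs fold element hashes in order (java.util.List.hashCode style), while maps take an order-insensitive sum of keyHash ^ valueHash (java.util.Map.hashCode style):

object HiveHashSketch {
  // Arrays and structs: ordered fold, result = 31 * result + elementHash.
  def hashOrdered(elementHashes: Seq[Int]): Int =
    elementHashes.foldLeft(0)((acc, h) => 31 * acc + h)

  // Maps: order-insensitive sum of keyHash ^ valueHash over the entries.
  def hashMap(entryHashes: Seq[(Int, Int)]): Int =
    entryHashes.map { case (k, v) => k ^ v }.sum

  def main(args: Array[String]): Unit = {
    println(hashOrdered(Seq(1, 2, 3)))    // ((0*31 + 1)*31 + 2)*31 + 3 = 1026
    println(hashMap(Seq((1, 2), (3, 4)))) // (1 ^ 2) + (3 ^ 4) = 3 + 7 = 10
  }
}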