From 6d8a6e4161176e391514153d7535da14b52194be Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Date: Wed, 5 Aug 2015 00:54:31 -0700
Subject: [PATCH] [SPARK-9360] [SQL] Support BinaryType in PrefixComparators
 for UnsafeExternalSort

The current implementation of UnsafeExternalSort uses NoOpPrefixComparator for binary-typed data.
So, we need to add BinaryPrefixComparator in PrefixComparators.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #7676 from maropu/BinaryTypePrefixComparator and squashes the following commits:

fe6f31b [Takeshi YAMAMURO] Apply comments
d943c04 [Takeshi YAMAMURO] Add a codegen'd entry for BinaryType in SortPrefix
ecf3ac5 [Takeshi YAMAMURO] Support BinaryType in PrefixComparator
---
 .../unsafe/sort/PrefixComparators.java        | 35 +++++++++++++++++
 .../unsafe/sort/PrefixComparatorsSuite.scala  | 38 +++++++++++++++++++
 .../sql/catalyst/expressions/SortOrder.scala  |  3 ++
 .../spark/sql/execution/SortPrefixUtils.scala |  2 +
 4 files changed, 78 insertions(+)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 4d7e5b3dfb..b5f661c0d5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection.unsafe.sort;
 import com.google.common.primitives.UnsignedLongs;
 
 import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.types.UTF8String;
 import org.apache.spark.util.Utils;
 
@@ -29,6 +30,8 @@ public class PrefixComparators {
 
   public static final StringPrefixComparator STRING = new StringPrefixComparator();
   public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+  public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
+  public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
   public static final LongPrefixComparator LONG = new LongPrefixComparator();
   public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
   public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
@@ -52,6 +55,38 @@ public class PrefixComparators {
     }
   }
 
+  public static final class BinaryPrefixComparator extends PrefixComparator {
+    @Override
+    public int compare(long aPrefix, long bPrefix) {
+      return UnsignedLongs.compare(aPrefix, bPrefix);
+    }
+
+    public static long computePrefix(byte[] bytes) {
+      if (bytes == null) {
+        return 0L;
+      } else {
+        /**
+         * TODO: If a wrapper for BinaryType is created (SPARK-8786),
+         * these codes below will be in the wrapper class.
+         */
+        final int minLen = Math.min(bytes.length, 8);
+        long p = 0;
+        for (int i = 0; i < minLen; ++i) {
+          p |= (128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + i))
+              << (56 - 8 * i);
+        }
+        return p;
+      }
+    }
+  }
+
+  public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+    @Override
+    public int compare(long bPrefix, long aPrefix) {
+      return UnsignedLongs.compare(aPrefix, bPrefix);
+    }
+  }
+
   public static final class LongPrefixComparator extends PrefixComparator {
     @Override
     public int compare(long a, long b) {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index 26a2e96eda..0326ed70b5 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -55,6 +55,44 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
 
+  test("Binary prefix comparator") {
+
+     def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
+      for (i <- 0 until x.length; if i < y.length) {
+        val res = x(i).compare(y(i))
+        if (res != 0) return res
+      }
+      x.length - y.length
+    }
+
+    def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
+      val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
+      val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
+      val prefixComparisonResult =
+        PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
+      assert(
+        (prefixComparisonResult == 0) ||
+        (prefixComparisonResult < 0 && compareBinary(x, y) < 0) ||
+        (prefixComparisonResult > 0 && compareBinary(x, y) > 0))
+    }
+
+    // scalastyle:off
+    val regressionTests = Table(
+      ("s1", "s2"),
+      ("abc", "世界"),
+      ("你好", "世界"),
+      ("你好123", "你好122")
+    )
+    // scalastyle:on
+
+    forAll (regressionTests) { (s1: String, s2: String) =>
+      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
+    }
+    forAll { (s1: String, s2: String) =>
+      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
+    }
+  }
+
   test("double prefix comparator handles NaNs properly") {
     val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
     val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index f6a872ba44..98e029035a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
 
 abstract sealed class SortDirection
@@ -63,6 +64,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val childCode = child.child.gen(ctx)
     val input = childCode.primitive
+    val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
     val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
 
     val (nullValue: Long, prefixCode: String) = child.child.dataType match {
@@ -76,6 +78,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
         (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
           s"$DoublePrefixCmp.computePrefix((double)$input)")
       case StringType => (0L, s"$input.getPrefix()")
+      case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
       case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
         val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
           s"$input.toUnscaledLong()"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 49adf21537..e17b50edc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -38,6 +38,8 @@ object SortPrefixUtils {
     sortOrder.dataType match {
       case StringType =>
         if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
+      case BinaryType =>
+        if (sortOrder.isAscending) PrefixComparators.BINARY else PrefixComparators.BINARY_DESC
       case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
         if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
       case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
-- 
GitLab