From a9a6b80c718008aac7c411dfe46355efe58dee2e Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Wed, 11 Nov 2015 12:48:51 -0800
Subject: [PATCH] [SPARK-11645][SQL] Remove OpenHashSet for the old aggregate.

Author: Reynold Xin <rxin@databricks.com>

Closes #9621 from rxin/SPARK-11645.
---
 .../expressions/codegen/CodeGenerator.scala   |   6 -
 .../codegen/GenerateUnsafeProjection.scala    |   7 +-
 .../spark/sql/catalyst/expressions/sets.scala | 194 ------------------
 .../sql/execution/SparkSqlSerializer.scala    | 103 +---------
 .../spark/sql/UserDefinedTypeSuite.scala      |  11 -
 5 files changed, 5 insertions(+), 316 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a4bba232b..ccd91d3549 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types._
 
 
-// These classes are here to avoid issues with serialization and integration with quasiquotes.
-class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
-class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
-
 /**
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
@@ -205,8 +201,6 @@ class CodeGenContext {
     case _: StructType => "InternalRow"
     case _: ArrayType => "ArrayData"
     case _: MapType => "MapData"
-    case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
-    case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case udt: UserDefinedType[_] => javaType(udt.sqlType)
     case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
     case ObjectType(cls) => cls.getName
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 9ef2261414..4c17d02a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
-    case dt: OpenHashSetUDT => false  // it's not a standard UDT
     case udt: UserDefinedType[_] => canSupport(udt.sqlType)
     case _ => false
   }
@@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(BindReferences.bindReference(_, inputSchema))
 
   def generate(
-    expressions: Seq[Expression],
-    subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+      expressions: Seq[Expression],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
     create(canonicalize(expressions), subexpressionEliminationEnabled)
   }
 
   protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    create(expressions, false)
+    create(expressions, subexpressionEliminationEnabled = false)
   }
 
   private def create(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
deleted file mode 100644
index d124d29d53..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-/** The data type for expressions returning an OpenHashSet as the result. */
-private[sql] class OpenHashSetUDT(
-    val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
-
-  override def sqlType: DataType = ArrayType(elementType)
-
-  /** Since we are using OpenHashSet internally, usually it will not be called. */
-  override def serialize(obj: Any): Seq[Any] = {
-    obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
-  }
-
-  /** Since we are using OpenHashSet internally, usually it will not be called. */
-  override def deserialize(datum: Any): OpenHashSet[Any] = {
-    val iterator = datum.asInstanceOf[Seq[Any]].iterator
-    val set = new OpenHashSet[Any]
-    while(iterator.hasNext) {
-      set.add(iterator.next())
-    }
-
-    set
-  }
-
-  override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
-
-  private[spark] override def asNullable: OpenHashSetUDT = this
-}
-
-/**
- * Creates a new set of the specified type
- */
-case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback {
-
-  override def nullable: Boolean = false
-
-  override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
-
-  override def eval(input: InternalRow): Any = {
-    new OpenHashSet[Any]()
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    elementType match {
-      case IntegerType | LongType =>
-        ev.isNull = "false"
-        s"""
-          ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}();
-        """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-
-  override def toString: String = s"new Set($dataType)"
-}
-
-/**
- * Adds an item to a set.
- * For performance, this expression mutates its input during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class AddItemToSet(item: Expression, set: Expression)
-  extends Expression with CodegenFallback {
-
-  override def children: Seq[Expression] = item :: set :: Nil
-
-  override def nullable: Boolean = set.nullable
-
-  override def dataType: DataType = set.dataType
-
-  override def eval(input: InternalRow): Any = {
-    val itemEval = item.eval(input)
-    val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
-
-    if (itemEval != null) {
-      if (setEval != null) {
-        setEval.add(itemEval)
-        setEval
-      } else {
-        null
-      }
-    } else {
-      setEval
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
-    elementType match {
-      case IntegerType | LongType =>
-        val itemEval = item.gen(ctx)
-        val setEval = set.gen(ctx)
-        val htype = ctx.javaType(dataType)
-
-        ev.isNull = "false"
-        ev.value = setEval.value
-        itemEval.code + setEval.code +  s"""
-          if (!${itemEval.isNull} && !${setEval.isNull}) {
-           (($htype)${setEval.value}).add(${itemEval.value});
-          }
-         """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-
-  override def toString: String = s"$set += $item"
-}
-
-/**
- * Combines the elements of two sets.
- * For performance, this expression mutates its left input set during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CombineSets(left: Expression, right: Expression)
-  extends BinaryExpression with CodegenFallback {
-
-  override def nullable: Boolean = left.nullable
-  override def dataType: DataType = left.dataType
-
-  override def eval(input: InternalRow): Any = {
-    val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
-    if(leftEval != null) {
-      val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
-      if (rightEval != null) {
-        val iterator = rightEval.iterator
-        while(iterator.hasNext) {
-          val rightValue = iterator.next()
-          leftEval.add(rightValue)
-        }
-      }
-      leftEval
-    } else {
-      null
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
-    elementType match {
-      case IntegerType | LongType =>
-        val leftEval = left.gen(ctx)
-        val rightEval = right.gen(ctx)
-        val htype = ctx.javaType(dataType)
-
-        ev.isNull = leftEval.isNull
-        ev.value = leftEval.value
-        leftEval.code + rightEval.code + s"""
-          if (!${leftEval.isNull} && !${rightEval.isNull}) {
-            ${leftEval.value}.union((${htype})${rightEval.value});
-          }
-        """
-      case _ => super.genCode(ctx, ev)
-    }
-  }
-}
-
-/**
- * Returns the number of elements in the input set.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback {
-
-  override def dataType: DataType = LongType
-
-  protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[OpenHashSet[Any]].size.toLong
-
-  override def toString: String = s"$child.count()"
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index b19ad4f1c5..8317f648cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap}
 
 import scala.reflect.ClassTag
 
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Kryo, Serializer}
 import com.twitter.chill.ResourcePool
 
 import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.util.MutablePair
-import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.{SparkConf, SparkEnv}
 
+
 private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
   override def newKryo(): Kryo = {
     val kryo = super.newKryo()
@@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
-    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
-                  new HyperLogLogSerializer)
     kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
     kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
 
-    // Specific hashsets must come first TODO: Move to core.
-    kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
-    kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
-    kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
-                  new OpenHashSetSerializer)
     kryo.register(classOf[Decimal])
     kryo.register(classOf[JavaHashMap[_, _]])
 
@@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
 }
 
 private[execution] class KryoResourcePool(size: Int)
-    extends ResourcePool[SerializerInstance](size) {
+  extends ResourcePool[SerializerInstance](size) {
 
   val ser: SparkSqlSerializer = {
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
     new java.math.BigDecimal(input.readString())
   }
 }
-
-private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
-  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
-    val bytes = hyperLogLog.getBytes()
-    output.writeInt(bytes.length)
-    output.writeBytes(bytes)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
-    val length = input.readInt()
-    val bytes = input.readBytes(length)
-    HyperLogLog.Builder.build(bytes)
-  }
-}
-
-private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
-  def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val row = iterator.next()
-      rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
-    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
-    val numItems = input.readInt()
-    val set = new OpenHashSet[Any](numItems + 1)
-    var i = 0
-    while (i < numItems) {
-      val row =
-        new GenericInternalRow(rowSerializer.read(
-          kryo,
-          input,
-          classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
-      set.add(row)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
-  def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value: Int = iterator.next()
-      output.writeInt(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
-    val numItems = input.readInt()
-    val set = new IntegerHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readInt()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
-
-private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
-  def write(kryo: Kryo, output: Output, hs: LongHashSet) {
-    output.writeInt(hs.size)
-    val iterator = hs.iterator
-    while(iterator.hasNext) {
-      val value = iterator.next()
-      output.writeLong(value)
-    }
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
-    val numItems = input.readInt()
-    val set = new LongHashSet
-    var i = 0
-    while (i < numItems) {
-      val value = input.readLong()
-      set.add(value)
-      i += 1
-    }
-    set
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index e31c528f3a..f602f2fb89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,7 +23,6 @@ import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -131,15 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
   }
 
-  test("OpenHashSetUDT") {
-    val openHashSetUDT = new OpenHashSetUDT(IntegerType)
-    val set = new OpenHashSet[Int]
-    (1 to 10).foreach(i => set.add(i))
-
-    val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
-    assert(actual.iterator.toSet === set.iterator.toSet)
-  }
-
   test("UDTs with JSON") {
     val data = Seq(
       "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}",
@@ -163,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
   test("SPARK-10472 UserDefinedType.typeName") {
     assert(IntegerType.typeName === "integer")
     assert(new MyDenseVectorUDT().typeName === "mydensevector")
-    assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
   }
 
   test("Catalyst type converter null handling for UDTs") {
-- 
GitLab