From 5949e6c4477fd3cb07a6962dbee48b4416ea65dd Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Date: Thu, 9 Mar 2017 22:58:52 -0800
Subject: [PATCH] [SPARK-19008][SQL] Improve performance of Dataset.map by
 eliminating boxing/unboxing

## What changes were proposed in this pull request?

This PR improves the performance of Dataset.map() for primitive types by removing boxing/unboxing operations. It is based on [the discussion](https://github.com/apache/spark/pull/16391#discussion_r93788919) with cloud-fan.

Currently, Catalyst generates a call to the generic `apply()` method of an anonymous function written in Scala. Both the argument and return types of this method are `java.lang.Object`. As a result, each call with a primitive value involves boxing and unboxing of the argument when calling `apply()`, and boxing and unboxing of the result when returning from it.

This PR instead calls a specialized version of the `apply()` method that avoids boxing and unboxing. For example, if both the argument and return types are `int`, it generates a call to `apply$mcII$sp`. Any combination of `Int`, `Long`, `Float`, and `Double` is supported (plus `Boolean` as the return type for `filter`).
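
To make the naming concrete, here is a minimal sketch; the `specializedApplyName` helper is hypothetical and only mirrors the `FunctionUtils.getFunctionOneName` helper added by this patch.

```scala
// Hypothetical helper, for illustration only: scalac names the specialized
// Function1 entry points apply$mc<Return><Arg>$sp, using JVM-style type codes
// (I = Int, J = Long, F = Float, D = Double, Z = Boolean).
def specializedApplyName(argCode: Char, returnCode: Char): String =
  s"apply$$mc$returnCode$argCode$$sp"

specializedApplyName('I', 'I')  // "apply$mcII$sp" for Int => Int (map)
specializedApplyName('J', 'Z')  // "apply$mcZJ$sp" for Long => Boolean (filter)
```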

The following are benchmark results using [this program](https://github.com/apache/spark/pull/16391/files); the Dataset case becomes about 4.7x faster with this PR. Here is the Dataset part of the program.

Without this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3  3.20GHz
back-to-back map:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
RDD                                           1923 / 1952         52.0          19.2       1.0X
DataFrame                                      526 /  548        190.2           5.3       3.7X
Dataset                                       3094 / 3154         32.3          30.9       0.6X
```

With this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3  3.20GHz
back-to-back map:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
RDD                                           1883 / 1892         53.1          18.8       1.0X
DataFrame                                      502 /  642        199.1           5.0       3.7X
Dataset                                        657 /  784        152.2           6.6       2.9X
```

```scala
  def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
    import spark.implicits._
    val rdd = spark.sparkContext.range(0, numRows)
    val ds = spark.range(0, numRows)
    val func = (l: Long) => l + 1
    val benchmark = new Benchmark("back-to-back map", numRows)
...
    benchmark.addCase("Dataset") { iter =>
      var res = ds.as[Long]
      var i = 0
      while (i < numChains) {
        res = res.map(func)
        i += 1
      }
      res.queryExecution.toRdd.foreach(_ => Unit)
    }
    benchmark
  }
```

A motivating example
```scala
Seq(1, 2, 3).toDS.map(i => i * 7).show
```
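
The cost of the current call path can be sketched at the Scala level as follows. This is illustration only, not the generated code; the explicit `BoxesRunTime` calls stand in for the boxing that the generated code and the bytecode below perform.

```scala
import scala.runtime.BoxesRunTime

// Sketch of the call path without this PR: the generated code only sees the
// erased Function1 interface whose argument and result are Object, so every
// row pays for boxing the argument and unboxing the result.
val func: Int => Int = i => i * 7
val erased = func.asInstanceOf[Any => Any]           // the erased view Catalyst calls through

val boxedArg: AnyRef = BoxesRunTime.boxToInteger(3)  // int -> java.lang.Integer
val boxedResult = erased(boxedArg)                   // bridge apply(Object): unbox, multiply, box
val result: Int = BoxesRunTime.unboxToInt(boxedResult.asInstanceOf[AnyRef])
// With this PR, the generated code calls apply$mcII$sp(int) directly instead.
```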

Generated code without this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow deserializetoobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */   private int mapelements_argValue;
/* 013 */   private UnsafeRow mapelements_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */   private UnsafeRow serializefromobject_result;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */   public GeneratedIterator(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     deserializetoobject_result = new UnsafeRow(1);
/* 029 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */     mapelements_result = new UnsafeRow(1);
/* 033 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */     serializefromobject_result = new UnsafeRow(1);
/* 036 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */   }
/* 040 */
/* 041 */   protected void processNext() throws java.io.IOException {
/* 042 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */       int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */       boolean mapelements_isNull = true;
/* 047 */       int mapelements_value = -1;
/* 048 */       if (!false) {
/* 049 */         mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */         mapelements_isNull = false;
/* 052 */         if (!mapelements_isNull) {
/* 053 */           Object mapelements_funcResult = null;
/* 054 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 055 */           if (mapelements_funcResult == null) {
/* 056 */             mapelements_isNull = true;
/* 057 */           } else {
/* 058 */             mapelements_value = (Integer) mapelements_funcResult;
/* 059 */           }
/* 060 */
/* 061 */         }
/* 062 */
/* 063 */       }
/* 064 */
/* 065 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 066 */
/* 067 */       if (mapelements_isNull) {
/* 068 */         serializefromobject_rowWriter.setNullAt(0);
/* 069 */       } else {
/* 070 */         serializefromobject_rowWriter.write(0, mapelements_value);
/* 071 */       }
/* 072 */       append(serializefromobject_result);
/* 073 */       if (shouldStop()) return;
/* 074 */     }
/* 075 */   }
/* 076 */ }
```

Generated code with this PR (lines 48-56 are changed)
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow deserializetoobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */   private int mapelements_argValue;
/* 013 */   private UnsafeRow mapelements_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */   private UnsafeRow serializefromobject_result;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */   public GeneratedIterator(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     deserializetoobject_result = new UnsafeRow(1);
/* 029 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */     mapelements_result = new UnsafeRow(1);
/* 033 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */     serializefromobject_result = new UnsafeRow(1);
/* 036 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */   }
/* 040 */
/* 041 */   protected void processNext() throws java.io.IOException {
/* 042 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */       int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */       boolean mapelements_isNull = true;
/* 047 */       int mapelements_value = -1;
/* 048 */       if (!false) {
/* 049 */         mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */         mapelements_isNull = false;
/* 052 */         if (!mapelements_isNull) {
/* 053 */           mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue);
/* 054 */         }
/* 055 */
/* 056 */       }
/* 057 */
/* 058 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 059 */
/* 060 */       if (mapelements_isNull) {
/* 061 */         serializefromobject_rowWriter.setNullAt(0);
/* 062 */       } else {
/* 063 */         serializefromobject_rowWriter.write(0, mapelements_value);
/* 064 */       }
/* 065 */       append(serializefromobject_result);
/* 066 */       if (shouldStop()) return;
/* 067 */     }
/* 068 */   }
/* 069 */ }
```

Java bytecode of the methods generated for `i => i * 7`
```java
$ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class
Compiled from "Test.scala"
public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable {
  public static final long serialVersionUID;

  public final int apply(int);
    Code:
       0: aload_0
       1: iload_1
       2: invokevirtual #18                 // Method apply$mcII$sp:(I)I
       5: ireturn

  public int apply$mcII$sp(int);
    Code:
       0: iload_1
       1: bipush        7
       3: imul
       4: ireturn

  public final java.lang.Object apply(java.lang.Object);
    Code:
       0: aload_0
       1: aload_1
       2: invokestatic  #29                 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
       5: invokevirtual #31                 // Method apply:(I)I
       8: invokestatic  #35                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
      11: areturn

  public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5);
    Code:
       0: aload_0
       1: invokespecial #42                 // Method scala/runtime/AbstractFunction1$mcII$sp."<init>":()V
       4: return
}
```
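
As a quick sanity check (a sketch, assuming Scala 2.11 as in the javap output above; Scala 2.12 encodes lambdas differently), the specialized base class of such a closure can be inspected directly:

```scala
// Sketch: under Scala 2.11 a closure of type Int => Int compiles to a class
// extending the specialized base class, so the primitive entry point
// apply$mcII$sp(int) exists for callers such as the code generated by this PR.
val f: Int => Int = i => i * 7
println(f.getClass.getSuperclass)
// expected here: class scala.runtime.AbstractFunction1$mcII$sp
```
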
## How was this patch tested?

Added new test cases to `DatasetPrimitiveSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #17172 from kiszk/SPARK-19008.
---
 .../sql/catalyst/plans/logical/object.scala   |  38 +++++-
 .../apache/spark/sql/execution/objects.scala  |   6 +-
 .../apache/spark/sql/DatasetBenchmark.scala   | 122 +++++++++++++++++-
 .../spark/sql/DatasetPrimitiveSuite.scala     |  51 ++++++++
 4 files changed, 208 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 617239f56c..7f4462e583 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object CatalystSerde {
   def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -211,13 +212,48 @@ case class TypedFilter(
   def typedCondition(input: Expression): Expression = {
     val (funcClass, methodName) = func match {
       case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
-      case _ => classOf[Any => Boolean] -> "apply"
+      case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType)
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
     Invoke(funcObj, methodName, BooleanType, input :: Nil)
   }
 }
 
+object FunctionUtils {
+  private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = {
+    dt match {
+      case BooleanType if isOutput => Some("Z")
+      case IntegerType => Some("I")
+      case LongType => Some("J")
+      case FloatType => Some("F")
+      case DoubleType => Some("D")
+      case _ => None
+    }
+  }
+
+  def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = {
+    // load "scala.Function1" using Java API to avoid requirements of type parameters
+    Utils.classForName("scala.Function1") -> {
+      // If the argument and return types are among the specific types for which
+      // scalac generates a specialized method (apply$mc..$sp),
+      // Catalyst generates a direct method call to the specialized method.
+      // The following are references for this specialization:
+      //   http://www.scala-lang.org/api/2.12.0/scala/Function1.html
+      //   https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
+      //     SpecializeTypes.scala
+      //   http://www.cakesolutions.net/teamblogs/scala-dissection-functions
+      //   http://axel22.github.io/2013/11/03/specialization-quirks.html
+      val inputType = getMethodType(inputDT, false)
+      val outputType = getMethodType(outputDT, true)
+      if (inputType.isDefined && outputType.isDefined) {
+        s"apply$$mc${outputType.get}${inputType.get}$$sp"
+      } else {
+        "apply"
+      }
+    }
+  }
+}
+
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 199ba5ce69..fdd1bcc94b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
 import org.apache.spark.sql.execution.streaming.KeyedStateImpl
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 
 /**
@@ -219,7 +221,7 @@ case class MapElementsExec(
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val (funcClass, methodName) = func match {
       case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
-      case _ => classOf[Any => Any] -> "apply"
+      case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
     val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 66d94d6016..1a0672b887 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -31,6 +31,49 @@ object DatasetBenchmark {
 
   case class Data(l: Long, s: String)
 
+  def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+    import spark.implicits._
+
+    val rdd = spark.sparkContext.range(0, numRows)
+    val ds = spark.range(0, numRows)
+    val df = ds.toDF("l")
+    val func = (l: Long) => l + 1
+
+    val benchmark = new Benchmark("back-to-back map long", numRows)
+
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
+      var i = 0
+      while (i < numChains) {
+        res = res.map(func)
+        i += 1
+      }
+      res.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("DataFrame") { iter =>
+      var res = df
+      var i = 0
+      while (i < numChains) {
+        res = res.select($"l" + 1 as "l")
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("Dataset") { iter =>
+      var res = ds.as[Long]
+      var i = 0
+      while (i < numChains) {
+        res = res.map(func)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark
+  }
+
   def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
     import spark.implicits._
 
@@ -72,6 +115,49 @@ object DatasetBenchmark {
     benchmark
   }
 
+  def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+    import spark.implicits._
+
+    val rdd = spark.sparkContext.range(1, numRows)
+    val ds = spark.range(1, numRows)
+    val df = ds.toDF("l")
+    val func = (l: Long) => l % 2L == 0L
+
+    val benchmark = new Benchmark("back-to-back filter Long", numRows)
+
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
+      var i = 0
+      while (i < numChains) {
+        res = res.filter(func)
+        i += 1
+      }
+      res.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("DataFrame") { iter =>
+      var res = df
+      var i = 0
+      while (i < numChains) {
+        res = res.filter($"l" % 2L === 0L)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("Dataset") { iter =>
+      var res = ds.as[Long]
+      var i = 0
+      while (i < numChains) {
+        res = res.filter(func)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark
+  }
+
   def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
     import spark.implicits._
 
@@ -165,9 +251,22 @@ object DatasetBenchmark {
     val numRows = 100000000
     val numChains = 10
 
-    val benchmark = backToBackMap(spark, numRows, numChains)
-    val benchmark2 = backToBackFilter(spark, numRows, numChains)
-    val benchmark3 = aggregate(spark, numRows)
+    val benchmark0 = backToBackMapLong(spark, numRows, numChains)
+    val benchmark1 = backToBackMap(spark, numRows, numChains)
+    val benchmark2 = backToBackFilterLong(spark, numRows, numChains)
+    val benchmark3 = backToBackFilter(spark, numRows, numChains)
+    val benchmark4 = aggregate(spark, numRows)
+
+    /*
+    OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
+    Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
+    back-to-back map long:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    RDD                                           1883 / 1892         53.1          18.8       1.0X
+    DataFrame                                      502 /  642        199.1           5.0       3.7X
+    Dataset                                        657 /  784        152.2           6.6       2.9X
+    */
+    benchmark0.run()
 
     /*
     OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -178,7 +277,18 @@ object DatasetBenchmark {
     DataFrame                                     2647 / 3116         37.8          26.5       1.3X
     Dataset                                       4781 / 5155         20.9          47.8       0.7X
     */
-    benchmark.run()
+    benchmark1.run()
+
+    /*
+    OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic
+    Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
+    back-to-back filter Long:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    RDD                                            846 / 1120        118.1           8.5       1.0X
+    DataFrame                                      270 /  329        370.9           2.7       3.1X
+    Dataset                                        545 /  789        183.5           5.4       1.6X
+    */
+    benchmark2.run()
 
     /*
     OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -189,7 +299,7 @@ object DatasetBenchmark {
     DataFrame                                       59 /   72       1695.4           0.6      22.8X
     Dataset                                       2777 / 2805         36.0          27.8       0.5X
     */
-    benchmark2.run()
+    benchmark3.run()
 
     /*
     Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
@@ -201,6 +311,6 @@ object DatasetBenchmark {
     Dataset sum using Aggregator                  4656 / 4758         21.5          46.6       0.4X
     Dataset complex Aggregator                    6636 / 7039         15.1          66.4       0.3X
     */
-    benchmark3.run()
+    benchmark4.run()
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 6b50cb3e48..82b707537e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -62,6 +62,40 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
+  test("mapPrimitive") {
+    val dsInt = Seq(1, 2, 3).toDS()
+    checkDataset(dsInt.map(_ > 1), false, true, true)
+    checkDataset(dsInt.map(_ + 1), 2, 3, 4)
+    checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsLong = Seq(1L, 2L, 3L).toDS()
+    checkDataset(dsLong.map(_ > 1), false, true, true)
+    checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsFloat = Seq(1F, 2F, 3F).toDS()
+    checkDataset(dsFloat.map(_ > 1), false, true, true)
+    checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
+    checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsDouble = Seq(1D, 2D, 3D).toDS()
+    checkDataset(dsDouble.map(_ > 1), false, true, true)
+    checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
+      8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsBoolean = Seq(true, false).toDS()
+    checkDataset(dsBoolean.map(e => !e), false, true)
+  }
+
   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
     checkDataset(
@@ -69,6 +103,23 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       2, 4)
   }
 
+  test("filterPrimitive") {
+    val dsInt = Seq(1, 2, 3).toDS()
+    checkDataset(dsInt.filter(_ > 1), 2, 3)
+
+    val dsLong = Seq(1L, 2L, 3L).toDS()
+    checkDataset(dsLong.filter(_ > 1), 2L, 3L)
+
+    val dsFloat = Seq(1F, 2F, 3F).toDS()
+    checkDataset(dsFloat.filter(_ > 1), 2F, 3F)
+
+    val dsDouble = Seq(1D, 2D, 3D).toDS()
+    checkDataset(dsDouble.filter(_ > 1), 2D, 3D)
+
+    val dsBoolean = Seq(true, false).toDS()
+    checkDataset(dsBoolean.filter(e => !e), false)
+  }
+
   test("foreach") {
     val ds = Seq(1, 2, 3).toDS()
     val acc = sparkContext.longAccumulator
-- 
GitLab