From b8ff2bc61c9835867f56afa1860ab5eb727c4a58 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Mon, 30 Mar 2015 20:47:10 -0700
Subject: [PATCH] [SPARK-6119][SQL] DataFrame support for missing data handling

This pull request adds variants of DataFrame.na.drop and DataFrame.na.fill to the Scala/Java API, and DataFrame.fillna and DataFrame.dropna to the Python API.

Author: Reynold Xin <rxin@databricks.com>

Closes #5274 from rxin/df-missing-value and squashes the following commits:

4ee1b98 [Reynold Xin] Improve error reporting in Python.
33a330c [Reynold Xin] Remove replace for now.
bc4fdbb [Reynold Xin] Added documentation for replace.
d56f5a5 [Reynold Xin] Added replace for Scala/Java.
2385d00 [Reynold Xin] Feedback from Xiangrui on "how".
914a374 [Reynold Xin] fill with map.
185c67e [Reynold Xin] Allow specifying column subsets in fill.
749eb47 [Reynold Xin] fillna
249b94e [Reynold Xin] Removing undefined functions.
6a73c68 [Reynold Xin] Missing file.
67d7003 [Reynold Xin] [SPARK-6119][SQL] DataFrame.na.drop (Scala/Java) and DataFrame.dropna (Python)
---
 python/pyspark/sql/dataframe.py               |  86 +++++++
 python/pyspark/sql/tests.py                   |  96 ++++++++
 .../catalyst/expressions/nullFunctions.scala  |  25 +-
 .../org/apache/spark/sql/DataFrame.scala      |  15 +-
 .../spark/sql/DataFrameNaFunctions.scala      | 228 ++++++++++++++++++
 .../org/apache/spark/sql/GroupedData.scala    |   5 +-
 .../org/apache/spark/sql/json/JsonRDD.scala   |   2 +-
 .../spark/sql/DataFrameNaFunctionsSuite.scala | 157 ++++++++++++
 8 files changed, 606 insertions(+), 8 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 23c0e63e77..4f174de811 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -690,6 +690,86 @@ class DataFrame(object):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
+    def dropna(self, how='any', thresh=None, subset=None):
+        """Returns a new :class:`DataFrame` omitting rows with null values.
+
+        :param how: 'any' or 'all'.
+            If 'any', drop a row if it contains any nulls.
+            If 'all', drop a row only if all its values are null.
+        :param thresh: int, default None
+            If specified, drop rows that have less than `thresh` non-null values.
+            This overwrites the `how` parameter.
+        :param subset: optional list of column names to consider.
+
+        >>> df4.dropna().show()
+        age height name
+        10  80     Alice
+        """
+        if how is not None and how not in ['any', 'all']:
+            raise ValueError("how ('" + how + "') should be 'any' or 'all'")
+
+        if subset is None:
+            subset = self.columns
+        elif isinstance(subset, basestring):
+            subset = [subset]
+        elif not isinstance(subset, (list, tuple)):
+            raise ValueError("subset should be a list or tuple of column names")
+
+        if thresh is None:
+            thresh = len(subset) if how == 'any' else 1
+
+        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
+        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
+        return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
+
+    def fillna(self, value, subset=None):
+        """Replace null values.
+
+        :param value: int, long, float, string, or dict.
+            Value to replace null values with.
+            If the value is a dict, then `subset` is ignored and `value` must be a mapping
+            from column name (string) to replacement value. The replacement value must be
+            an int, long, float, or string.
+        :param subset: optional list of column names to consider.
+            Columns specified in subset that do not have matching data type are ignored.
+            For example, if `value` is a string, and subset contains a non-string column,
+            then the non-string column is simply ignored.
+
+        >>> df4.fillna(50).show()
+        age height name
+        10  80     Alice
+        5   50     Bob
+        50  50     Tom
+        50  50     null
+
+        >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
+        age height name
+        10  80     Alice
+        5   null   Bob
+        50  null   Tom
+        50  null   unknown
+        """
+        if not isinstance(value, (float, int, long, basestring, dict)):
+            raise ValueError("value should be a float, int, long, string, or dict")
+
+        if isinstance(value, (int, long)):
+            value = float(value)
+
+        if isinstance(value, dict):
+            value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
+            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+        elif subset is None:
+            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+        else:
+            if isinstance(subset, basestring):
+                subset = [subset]
+            elif not isinstance(subset, (list, tuple)):
+                raise ValueError("subset should be a list or tuple of column names")
+
+            cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
+            cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
+            return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+
     def withColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.
 
@@ -1069,6 +1149,12 @@ def _test():
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
     globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
+
+    globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
+                                  Row(name='Bob', age=5, height=None),
+                                  Row(name='Tom', age=None, height=None),
+                                  Row(name=None, age=None, height=None)]).toDF()
+
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.dataframe, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2720439416..258464b7f2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -415,6 +415,102 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(_infer_type(2**61), LongType())
         self.assertEqual(_infer_type(2**71), LongType())
 
+    def test_dropna(self):
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", DoubleType(), True)])
+
+        # shouldn't drop a non-null row
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', 50, 80.1)], schema).dropna().count(),
+            1)
+
+        # dropping rows with a single null value
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, 80.1)], schema).dropna().count(),
+            0)
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
+            0)
+
+        # if how = 'all', only drop rows if all values are null
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
+            1)
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(None, None, None)], schema).dropna(how='all').count(),
+            0)
+
+        # how and subset
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
+            1)
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
+            0)
+
+        # threshold
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
+            1)
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
+            0)
+
+        # threshold and subset
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
+            1)
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
+            0)
+
+        # thresh should take precedence over how
+        self.assertEqual(self.sqlCtx.createDataFrame(
+            [(u'Alice', 50, None)], schema).dropna(
+                how='any', thresh=2, subset=['name', 'age']).count(),
+            1)
+
+    def test_fillna(self):
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", DoubleType(), True)])
+
+        # fillna shouldn't change non-null values
+        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+        self.assertEqual(row.age, 10)
+
+        # fillna with int
+        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+        self.assertEqual(row.age, 50)
+        self.assertEqual(row.height, 50.0)
+
+        # fillna with double
+        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+        self.assertEqual(row.age, 50)
+        self.assertEqual(row.height, 50.1)
+
+        # fillna with string
+        row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+        self.assertEqual(row.name, u"hello")
+        self.assertEqual(row.age, None)
+
+        # fillna with subset specified for numeric cols
+        row = self.sqlCtx.createDataFrame(
+            [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
+        self.assertEqual(row.name, None)
+        self.assertEqual(row.age, 50)
+        self.assertEqual(row.height, None)
+
+        # fillna with subset specified for numeric cols
+        row = self.sqlCtx.createDataFrame(
+            [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
+        self.assertEqual(row.name, "haha")
+        self.assertEqual(row.age, None)
+        self.assertEqual(row.height, None)
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index d1f3d4f4ee..f9161cf34f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -35,7 +35,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 
   override def toString: String = s"Coalesce(${children.mkString(",")})"
 
-  def dataType: DataType = if (resolved) {
+  override def dataType: DataType = if (resolved) {
     children.head.dataType
   } else {
     val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
@@ -74,3 +74,26 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
     child.eval(input) != null
   }
 }
+
+/**
+ * A predicate that is evaluated to be true if there are at least `n` non-null values.
+ */
+case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
+  override def nullable: Boolean = false
+  override def foldable: Boolean = false
+  override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
+
+  private[this] val childrenArray = children.toArray
+
+  override def eval(input: Row): Boolean = {
+    var numNonNulls = 0
+    var i = 0
+    while (i < childrenArray.length && numNonNulls < n) {
+      if (childrenArray(i).eval(input) != null) {
+        numNonNulls += 1
+      }
+      i += 1
+    }
+    numNonNulls >= n
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 423ef3912b..5cd0a18ff6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -237,8 +237,8 @@ class DataFrame private[sql](
   def toDF(colNames: String*): DataFrame = {
     require(schema.size == colNames.size,
       "The number of columns doesn't match.\n" +
-        "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
-        "New column names: " + colNames.mkString(", "))
+        s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" +
+        s"New column names (${colNames.size}): " + colNames.mkString(", "))
 
     val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
       apply(oldName).as(newName)
@@ -319,6 +319,17 @@ class DataFrame private[sql](
    */
   def show(): Unit = show(20)
 
+  /**
+   * Returns a [[DataFrameNaFunctions]] for working with missing data.
+   * {{{
+   *   // Dropping rows containing any null values.
+   *   df.na.drop()
+   * }}}
+   *
+   * @group dfops
+   */
+  def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
+
   /**
    * Cartesian join with another [[DataFrame]].
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
new file mode 100644
index 0000000000..3a3dc70f72
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -0,0 +1,228 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import java.{lang => jl}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Functionality for working with missing data in [[DataFrame]]s.
+ */
+final class DataFrameNaFunctions private[sql](df: DataFrame) {
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing any null values.
+   */
+  def drop(): DataFrame = drop("any", df.columns)
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing null values.
+   *
+   * If `how` is "any", then drop rows containing any null values.
+   * If `how` is "all", then drop rows only if every column is null for that row.
+   */
+  def drop(how: String): DataFrame = drop(how, df.columns)
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing any null values
+   * in the specified columns.
+   */
+  def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values
+   * in the specified columns.
+   */
+  def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing null values
+   * in the specified columns.
+   *
+   * If `how` is "any", then drop rows containing any null values in the specified columns.
+   * If `how` is "all", then drop rows only if every specified column is null for that row.
+   */
+  def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values
+   * in the specified columns.
+   *
+   * If `how` is "any", then drop rows containing any null values in the specified columns.
+   * If `how` is "all", then drop rows only if every specified column is null for that row.
+   */
+  def drop(how: String, cols: Seq[String]): DataFrame = {
+    how.toLowerCase match {
+      case "any" => drop(cols.size, cols)
+      case "all" => drop(1, cols)
+      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
+    }
+  }
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values.
+   */
+  def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)
+
+  /**
+   * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null
+   * values in the specified columns.
+   */
+  def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than
+   * `minNonNulls` non-null values in the specified columns.
+   */
+  def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
+    // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values.
+    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
+    df.filter(Column(predicate))
+  }
+
+  /**
+   * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`.
+   */
+  def fill(value: Double): DataFrame = fill(value, df.columns)
+
+  /**
+   * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`.
+   */
+  def fill(value: String): DataFrame = fill(value, df.columns)
+
+  /**
+   * Returns a new [[DataFrame]] that replaces null values in specified numeric columns.
+   * If a specified column is not a numeric column, it is ignored.
+   */
+  def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified
+   * numeric columns. If a specified column is not a numeric column, it is ignored.
+   */
+  def fill(value: Double, cols: Seq[String]): DataFrame = {
+    val columnEquals = df.sqlContext.analyzer.resolver
+    val projections = df.schema.fields.map { f =>
+      // Only fill if the column is part of the cols list.
+      if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
+        fillCol[Double](f, value)
+      } else {
+        df.col(f.name)
+      }
+    }
+    df.select(projections : _*)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] that replaces null values in specified string columns.
+   * If a specified column is not a string column, it is ignored.
+   */
+  def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in
+   * specified string columns. If a specified column is not a string column, it is ignored.
+   */
+  def fill(value: String, cols: Seq[String]): DataFrame = {
+    val columnEquals = df.sqlContext.analyzer.resolver
+    val projections = df.schema.fields.map { f =>
+      // Only fill if the column is part of the cols list.
+      if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
+        fillCol[String](f, value)
+      } else {
+        df.col(f.name)
+      }
+    }
+    df.select(projections : _*)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] that replaces null values.
+   *
+   * The key of the map is the column name, and the value of the map is the replacement value.
+   * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`.
+   *
+   * For example, the following replaces null values in column "A" with string "unknown", and
+   * null values in column "B" with numeric value 1.0.
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *   df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
+   * }}}
+   */
+  def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] that replaces null values.
+   *
+   * The key of the map is the column name, and the value of the map is the replacement value.
+   * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`.
+   *
+   * For example, the following replaces null values in column "A" with string "unknown", and
+   * null values in column "B" with numeric value 1.0.
+   * {{{
+   *   df.na.fill(Map(
+   *     "A" -> "unknown",
+   *     "B" -> 1.0
+   *   ))
+   * }}}
+   */
+  def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+
+  private def fill0(values: Seq[(String, Any)]): DataFrame = {
+    // Error handling
+    values.foreach { case (colName, replaceValue) =>
+      // Check column name exists
+      df.resolve(colName)
+
+      // Check data type
+      replaceValue match {
+        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String =>
+          // This is good
+        case _ => throw new IllegalArgumentException(
+          s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).")
+      }
+    }
+
+    val columnEquals = df.sqlContext.analyzer.resolver
+    val projections = df.schema.fields.map { f =>
+      values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
+        v match {
+          case v: jl.Float => fillCol[Double](f, v.toDouble)
+          case v: jl.Double => fillCol[Double](f, v)
+          case v: jl.Long => fillCol[Double](f, v.toDouble)
+          case v: jl.Integer => fillCol[Double](f, v.toDouble)
+          case v: String => fillCol[String](f, v)
+        }
+      }.getOrElse(df.col(f.name))
+    }
+    df.select(projections : _*)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+   */
+  private def fillCol[T](col: StructField, replacement: T): Column = {
+    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45a63ae26e..a5e6b638d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -127,10 +127,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
    *   import com.google.common.collect.ImmutableMap;
-   *   df.groupBy("department").agg(ImmutableMap.<String, String>builder()
-   *     .put("age", "max")
-   *     .put("expense", "sum")
-   *     .build());
+   *   df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum"));
    * }}}
    */
   def agg(exprs: java.util.Map[String, String]): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 2b0358c4e2..0b770f2251 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -49,7 +49,7 @@ private[sql] object JsonRDD extends Logging {
     val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
     val allKeys =
       if (schemaData.isEmpty()) {
-        Set.empty[(String,DataType)]
+        Set.empty[(String, DataType)]
       } else {
         parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
new file mode 100644
index 0000000000..0896f175c0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+
+class DataFrameNaFunctionsSuite extends QueryTest {
+
+  def createDF(): DataFrame = {
+    Seq[(String, java.lang.Integer, java.lang.Double)](
+      ("Bob", 16, 176.5),
+      ("Alice", null, 164.3),
+      ("David", 60, null),
+      ("Amy", null, null),
+      (null, null, null)).toDF("name", "age", "height")
+  }
+
+  test("drop") {
+    val input = createDF()
+    val rows = input.collect()
+
+    checkAnswer(
+      input.na.drop("name" :: Nil),
+      rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+
+    checkAnswer(
+      input.na.drop("age" :: Nil),
+      rows(0) :: rows(2) :: Nil)
+
+    checkAnswer(
+      input.na.drop("age" :: "height" :: Nil),
+      rows(0) :: Nil)
+
+    checkAnswer(
+      input.na.drop(),
+      rows(0))
+
+    // dropna on an a dataframe with no column should return an empty data frame.
+    val empty = input.sqlContext.emptyDataFrame.select()
+    assert(empty.na.drop().count() === 0L)
+
+    // Make sure the columns are properly named.
+    assert(input.na.drop().columns.toSeq === input.columns.toSeq)
+  }
+
+  test("drop with how") {
+    val input = createDF()
+    val rows = input.collect()
+
+    checkAnswer(
+      input.na.drop("all"),
+      rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+
+    checkAnswer(
+      input.na.drop("any"),
+      rows(0) :: Nil)
+
+    checkAnswer(
+      input.na.drop("any", Seq("age", "height")),
+      rows(0) :: Nil)
+
+    checkAnswer(
+      input.na.drop("all", Seq("age", "height")),
+      rows(0) :: rows(1) :: rows(2) :: Nil)
+  }
+
+  test("drop with threshold") {
+    val input = createDF()
+    val rows = input.collect()
+
+    checkAnswer(
+      input.na.drop(2, Seq("age", "height")),
+      rows(0) :: Nil)
+
+    checkAnswer(
+      input.na.drop(3, Seq("name", "age", "height")),
+      rows(0))
+
+    // Make sure the columns are properly named.
+    assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq)
+  }
+
+  test("fill") {
+    val input = createDF()
+
+    val fillNumeric = input.na.fill(50.6)
+    checkAnswer(
+      fillNumeric,
+      Row("Bob", 16, 176.5) ::
+        Row("Alice", 50, 164.3) ::
+        Row("David", 60, 50.6) ::
+        Row("Amy", 50, 50.6) ::
+        Row(null, 50, 50.6) :: Nil)
+
+    // Make sure the columns are properly named.
+    assert(fillNumeric.columns.toSeq === input.columns.toSeq)
+
+    // string
+    checkAnswer(
+      input.na.fill("unknown").select("name"),
+      Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil)
+    assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
+
+    // fill double with subset columns
+    checkAnswer(
+      input.na.fill(50.6, "age" :: Nil),
+      Row("Bob", 16, 176.5) ::
+        Row("Alice", 50, 164.3) ::
+        Row("David", 60, null) ::
+        Row("Amy", 50, null) ::
+        Row(null, 50, null) :: Nil)
+
+    // fill string with subset columns
+    checkAnswer(
+      Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
+      Row("test", null))
+  }
+
+  test("fill with map") {
+    val df = Seq[(String, String, java.lang.Long, java.lang.Double)](
+      (null, null, null, null)).toDF("a", "b", "c", "d")
+    checkAnswer(
+      df.na.fill(Map(
+        "a" -> "test",
+        "c" -> 1,
+        "d" -> 2.2
+      )),
+      Row("test", null, 1, 2.2))
+
+    // Test Java version
+    checkAnswer(
+      df.na.fill(mapAsJavaMap(Map(
+        "a" -> "test",
+        "c" -> 1,
+        "d" -> 2.2
+      ))),
+      Row("test", null, 1, 2.2))
+  }
+}
-- 
GitLab