diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 944739bcd207837b57ebdb14a9a899a44f3e20bb..edc7ca6f5146f49edb96421c7b2281b413e37970 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1362,8 +1362,8 @@ class DataFrame(object): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. - Values to_replace and value should contain either all numerics, all booleans, - or all strings. When replacing, the new value will be cast + Values to_replace and value must have the same type and can only be numerics, booleans, + or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`) @@ -1373,8 +1373,8 @@ class DataFrame(object): Value to be replaced. If the value is a dict, then `value` is ignored and `to_replace` must be a mapping between a value and a replacement. - :param value: int, long, float, string, or list. - The replacement value must be an int, long, float, or string. If `value` is a + :param value: bool, int, long, float, string, list or None. + The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. If `value` is a scalar and `to_replace` is a sequence, then `value` is used as a replacement for each item in `to_replace`. @@ -1393,6 +1393,16 @@ class DataFrame(object): |null| null| null| +----+------+-----+ + >>> df4.na.replace('Alice', None).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1425,12 +1435,13 @@ class DataFrame(object): valid_types = (bool, float, int, long, basestring, list, tuple) if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict. " + "to_replace should be a bool, float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + if not isinstance(value, valid_types) and value is not None \ + and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " - "a float, int, long, string, list, or tuple. " + "a bool, float, int, long, string, list, tuple or None. " "Got {0}".format(type(value))) if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): @@ -1446,21 +1457,21 @@ class DataFrame(object): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)): - value = [value for _ in range(len(to_replace))] - if isinstance(to_replace, dict): rep_dict = to_replace if value is not None: warnings.warn("to_replace is a dict and value is not None. value will be ignored.") else: + if isinstance(value, (float, int, long, basestring)) or value is None: + value = [value for _ in range(len(to_replace))] rep_dict = dict(zip(to_replace, value)) if isinstance(subset, basestring): subset = [subset] - # Verify we were not passed in mixed type generics." - if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + # Verify we were not passed in mixed type generics. + if not any(all_of_type(rep_dict.keys()) + and all_of_type(x for x in rep_dict.values() if x is not None) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cfd9c558ff67eddd59f5edd684d02eddbb63d620..cf2c473a1645c1066028e2c9479d90af23899b38 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1964,6 +1964,21 @@ class SQLTests(ReusedPySparkTestCase): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) + # replace list while value is not given (default to None) + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + self.assertTupleEqual(row, (None, 10, 80.0)) + + # replace string with None and then drop None rows + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() + self.assertEqual(row.count(), 0) + + # replace with number and None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() + self.assertTupleEqual(row, (u'Alice', 20, None)) + # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): self.spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 871fff71e5538744e317a2d167c77a296029d1bc..e068df3586f0611107d019b4debdd8165d772088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -260,9 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -277,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -288,8 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -301,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -312,10 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * If `col` is "*", - * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -328,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -343,8 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles , strings or booleans. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -354,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -366,14 +367,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, _] = replacement.head._2 match { - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + // Convert the NumericType in replacement map to DoubleType, + // while leaving StringType, BooleanType and null untouched. + val replacementMap: Map[_, _] = replacement.map { + case (k, v: String) => (k, v) + case (k, v: Boolean) => (k, v) + case (k: String, null) => (k, null) + case (k: Boolean, null) => (k, null) + case (k, null) => (convertToDouble(k), null) + case (k, v) => (convertToDouble(k), convertToDouble(v)) } - // targetColumnType is either DoubleType or StringType or BooleanType + // targetColumnType is either DoubleType, StringType or BooleanType, + // depending on the type of first key in replacement map. + // Only fields of targetColumnType will perform replacement. val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType case _: jl.Boolean => BooleanType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 47c9ba5847a4ff97d8d1a17907abb9f28237b582..e6983b6be555a6611f3cdb57fc31228e2a80635e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -262,4 +262,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) } + + test("replace with null") { + val input = Seq[(String, java.lang.Double, java.lang.Boolean)]( + ("Bob", 176.5, true), + ("Alice", 164.3, false), + ("David", null, true) + ).toDF("name", "height", "married") + + // Replace String with String and null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> null + )), + Row("Bravo", 176.5, true) :: + Row(null, 164.3, false) :: + Row("David", null, true) :: Nil) + + // Replace Double with null + checkAnswer( + input.na.replace("height", Map[Any, Any]( + 164.3 -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", null, false) :: + Row("David", null, true) :: Nil) + + // Replace Boolean with null + checkAnswer( + input.na.replace("*", Map[Any, Any]( + false -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", 164.3, null) :: + Row("David", null, true) :: Nil) + + // Replace String with null and then drop rows containing null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> null + )).na.drop("name" :: Nil).select("name"), + Row("Alice") :: Row("David") :: Nil) + } }