diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 77a42c0873a6b2abc540df2c492e78cc024aa50a..f7be5f6b370ab29bdf6923e656a10b77ae7a8f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -198,7 +198,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: + * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -215,7 +216,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -232,7 +233,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ @@ -259,7 +261,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -282,8 +285,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. + * If `col` is "*", + * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -311,7 +316,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles , strings or booleans. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -333,15 +339,17 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] + // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] val replacementMap: Map[_, _] = replacement.head._2 match { case v: String => replacement + case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } } - // targetColumnType is either DoubleType or StringType + // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: jl.Boolean => BooleanType case _: String => StringType } @@ -367,7 +375,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // Check data type replaceValue match { - case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String => // This is good case _ => throw new IllegalArgumentException( s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") @@ -382,6 +390,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Double => fillCol[Double](f, v) case v: jl.Long => fillCol[Double](f, v.toDouble) case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } }.getOrElse(df.col(f.name)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 329ffb66083b1a3826bafeed9303ad28cbf2dfb1..e34875471f0930c784fcc960ee17d962648fa9f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,24 +141,26 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( - (null, null, null, null)).toDF("a", "b", "c", "d") + val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false )), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) // Test Java version checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false ).asJava), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) } test("replace") {