Skip to content
Snippets Groups Projects
Unverified Commit b26f2c2c authored by root's avatar root Committed by DB Tsai
Browse files

[SPARK-18555][SQL] DataFrameNaFunctions.fill miss up original values in long integers


## What changes were proposed in this pull request?

   DataSet.na.fill(0) used on a DataSet which has a long value column, it will change the original long value.

   The reason is that the type of the function fill's param is Double, and the numeric columns are always cast to double(`fillCol[Double](f, value)`) .
```
  def fill(value: Double, cols: Seq[String]): DataFrame = {
    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      // Only fill if the column is part of the cols list.
      if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
        fillCol[Double](f, value)
      } else {
        df.col(f.name)
      }
    }
    df.select(projections : _*)
  }
```

 For example:
```
scala> val df = Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)).toDF("a", "b")
df: org.apache.spark.sql.DataFrame = [a: bigint, b: bigint]

scala> df.show
+-------------------+-------------------+
|                  a|                  b|
+-------------------+-------------------+
|                  1|                  2|
|                 -1|                 -2|
|9123146099426677101|9123146560113991650|
+-------------------+-------------------+

scala> df.na.fill(0).show
+-------------------+-------------------+
|                  a|                  b|
+-------------------+-------------------+
|                  1|                  2|
|                 -1|                 -2|
|9123146099426676736|9123146560113991680|
+-------------------+-------------------+
 ```

the original values changed [which is not we expected result]:
```
 9123146099426677101 -> 9123146099426676736
 9123146560113991650 -> 9123146560113991680
```

## How was this patch tested?

unit test added.

Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>

Closes #15994 from windpiger/nafillMissupOriginalValue.

(cherry picked from commit 508de38c)
Signed-off-by: default avatarDB Tsai <dbtsai@dbtsai.com>
parent 489c1f35
No related branches found
No related tags found
No related merge requests found
...@@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
/** /**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* *
* @since 2.2.0
*/
def fill(value: Long): DataFrame = fill(value, df.columns)
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1 * @since 1.3.1
*/ */
def fill(value: Double): DataFrame = fill(value, df.columns) def fill(value: Double): DataFrame = fill(value, df.columns)
...@@ -139,6 +145,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -139,6 +145,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/ */
def fill(value: String): DataFrame = fill(value, df.columns) def fill(value: String): DataFrame = fill(value, df.columns)
/**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
* If a specified column is not a numeric column, it is ignored.
*
* @since 2.2.0
*/
def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
/** /**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
* If a specified column is not a numeric column, it is ignored. * If a specified column is not a numeric column, it is ignored.
...@@ -147,24 +161,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -147,24 +161,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/ */
def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored.
*
* @since 2.2.0
*/
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
/** /**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored. * numeric columns. If a specified column is not a numeric column, it is ignored.
* *
* @since 1.3.1 * @since 1.3.1
*/ */
def fill(value: Double, cols: Seq[String]): DataFrame = { def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
// Only fill if the column is part of the cols list.
if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
fillCol[Double](f, value)
} else {
df.col(f.name)
}
}
df.select(projections : _*)
}
/** /**
* Returns a new `DataFrame` that replaces null values in specified string columns. * Returns a new `DataFrame` that replaces null values in specified string columns.
...@@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* *
* @since 1.3.1 * @since 1.3.1
*/ */
def fill(value: String, cols: Seq[String]): DataFrame = { def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
// Only fill if the column is part of the cols list.
if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
fillCol[String](f, value)
} else {
df.col(f.name)
}
}
df.select(projections : _*)
}
/** /**
* Returns a new `DataFrame` that replaces null values. * Returns a new `DataFrame` that replaces null values.
...@@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* *
* @since 1.3.1 * @since 1.3.1
*/ */
def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq)
/** /**
* (Scala-specific) Returns a new `DataFrame` that replaces null values. * (Scala-specific) Returns a new `DataFrame` that replaces null values.
...@@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* *
* @since 1.3.1 * @since 1.3.1
*/ */
def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
/** /**
* Replaces values matching keys in `replacement` map with the corresponding values. * Replaces values matching keys in `replacement` map with the corresponding values.
...@@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
df.select(projections : _*) df.select(projections : _*)
} }
private def fill0(values: Seq[(String, Any)]): DataFrame = { private def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling // Error handling
values.foreach { case (colName, replaceValue) => values.foreach { case (colName, replaceValue) =>
// Check column name exists // Check column name exists
...@@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { ...@@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
case v => throw new IllegalArgumentException( case v => throw new IllegalArgumentException(
s"Unsupported value type ${v.getClass.getName} ($v).") s"Unsupported value type ${v.getClass.getName} ($v).")
} }
/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric, string columns. If a specified column is not a numeric, string column,
* it is ignored.
*/
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
// the fill[T] which T is Long/Double,
// should apply on all the NumericType Column, for example:
// val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
// input.na.fill(3.1)
// the result is (3,164.3), not (null, 164.3)
val targetType = value match {
case _: Double | _: Long => NumericType
case _: String => StringType
case _ => throw new IllegalArgumentException(
s"Unsupported value type ${value.getClass.getName} ($value).")
}
val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
val typeMatches = (targetType, f.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
case (StringType, dt) => dt == StringType
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
fillCol[T](f, value)
} else {
df.col(f.name)
}
}
df.select(projections : _*)
}
} }
...@@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { ...@@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer( checkAnswer(
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
Row("test", null)) Row("test", null))
checkAnswer(
Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L))
.toDF("a", "b").na.fill(0),
Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil
)
checkAnswer(
Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
.toDF("a", "b").na.fill(2.34),
Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil
)
checkAnswer(
Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
.toDF("a", "b").na.fill(5),
Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil
)
} }
test("fill with map") { test("fill with map") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment