Skip to content
Snippets Groups Projects
Commit 9fe38aba authored by sethah, committed by Nick Pentreath
Browse files

[SPARK-11108][ML] OneHotEncoder should support other numeric types

Adding support for other numeric types:

* Integer
* Short
* Long
* Float
* Decimal

Author: sethah <seth.hendrickson16@gmail.com>

Closes #9777 from sethah/SPARK-11108.
parent 9525c563
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
/**
* :: Experimental ::
......@@ -70,7 +70,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
val inputColName = $(inputCol)
val outputColName = $(outputCol)
SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
require(schema(inputColName).dataType.isInstanceOf[NumericType],
s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
val inputFields = schema.fields
require(!inputFields.exists(_.name == outputColName),
s"Output column $outputColName already exists.")
......@@ -133,7 +134,9 @@ class OneHotEncoder(override val uid: String) extends Transformer
val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
.aggregate(0.0)(
(m, x) => {
assert(x >=0.0 && x == x.toInt,
assert(x <= Int.MaxValue,
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
assert(x >= 0.0 && x == x.toInt,
s"Values from column $inputColName must be indices, but got $x.")
math.max(m, x)
},
......
......@@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
class OneHotEncoderSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
......@@ -111,4 +112,32 @@ class OneHotEncoderSuite
.setDropLast(false)
testDefaultReadWrite(t)
}
test("OneHotEncoder with varying types") {
  // Checks that OneHotEncoder yields the same encoded vectors no matter which
  // numeric type the input index column is stored as.
  val df = stringIndexed()

  // Extra columns to derive from "labelIndex", each paired with its cast target type.
  val castTargets: Seq[(String, DataType)] = Seq(
    "shortLabel" -> ShortType,
    "longLabel" -> LongType,
    "intLabel" -> IntegerType,
    "floatLabel" -> FloatType,
    "decimalLabel" -> DecimalType(10, 0))

  // Append one cast copy of "labelIndex" per target type, in declaration order.
  val dfWithTypes = castTargets.foldLeft(df) { case (acc, (name, dataType)) =>
    acc.withColumn(name, df("labelIndex").cast(dataType))
  }

  // The original column plus every cast variant.
  val inputColumns = ("labelIndex" +: castTargets.map(_._1)).toArray

  // Each tuple is (id, vec(0), vec(1), vec(2)).
  // a -> 0, b -> 2, c -> 1
  val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
    (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))

  inputColumns.foreach { inputColumn =>
    val encoder = new OneHotEncoder()
      .setInputCol(inputColumn)
      .setOutputCol("labelVec")
      .setDropLast(false)
    val rows = encoder.transform(dfWithTypes).select("id", "labelVec").rdd
    val output = rows.map { row =>
      val encodedVec = row.getAs[Vector](1)
      (row.getInt(0), encodedVec(0), encodedVec(1), encodedVec(2))
    }.collect().toSet
    assert(output === expected)
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment