Commit 4be360d4 authored by BenFradet, committed by Xiangrui Meng

[SPARK-11902][ML] Unhandled case in VectorAssembler#transform

There is an unhandled case in the transform method of VectorAssembler: if one of the input columns does not have one of the supported types (DoubleType, NumericType, BooleanType or VectorUDT), the pattern match over the column type has no branch for it.

So, if you try to transform a column of StringType, you get a cryptic "scala.MatchError: StringType".

This PR fixes this by throwing a SparkException when an input column has an unsupported type.
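
The failure mode can be illustrated outside Spark with a minimal sketch. The ColType, DoubleT and StringT names below are hypothetical stand-ins and not the real Spark SQL types, and IllegalArgumentException stands in for SparkException so that the snippet has no dependencies. It only shows how a match without a catch-all case raises scala.MatchError, while a catch-all case can raise a descriptive error instead.

```scala
// Hypothetical stand-ins for column data types (not Spark's DataType classes).
trait ColType
case object DoubleT extends ColType
case object StringT extends ColType

object MatchSketch {
  // No catch-all case: any type other than DoubleT raises scala.MatchError.
  def widthUnchecked(t: ColType): Int = t match {
    case DoubleT => 1
  }

  // Catch-all case: unsupported types raise an error that names the type.
  // IllegalArgumentException is an assumption used here in place of SparkException.
  def widthChecked(t: ColType): Int = t match {
    case DoubleT => 1
    case otherType =>
      throw new IllegalArgumentException(
        s"VectorAssembler does not support the $otherType type")
  }
}
```

Calling MatchSketch.widthUnchecked(StringT) fails with a MatchError, whereas MatchSketch.widthChecked(StringT) fails with a message that names the offending type.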

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #9885 from BenFradet/SPARK-11902.
parent d9cf9c21
@@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String)
               val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
               Array.fill(numAttrs)(NumericAttribute.defaultAttr)
             }
+          case otherType =>
+            throw new SparkException(s"VectorAssembler does not support the $otherType type")
         }
     }
     val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
......
@@ -69,6 +69,17 @@ class VectorAssemblerSuite
     }
   }
 
+  test("transform should throw an exception in case of unsupported type") {
+    val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("a", "b", "c"))
+      .setOutputCol("features")
+    val thrown = intercept[SparkException] {
+      assembler.transform(df)
+    }
+    assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+  }
+
   test("ML attributes") {
     val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
     val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
......