Skip to content
Snippets Groups Projects
Commit 5e492e9d authored by Eric Liang's avatar Eric Liang Committed by Xiangrui Meng
Browse files

[SPARK-12346][ML] Missing attribute names in GLM for vector-type features

Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names.

cc mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #10323 from ericl/spark-12346.
parent 44fcf992
No related branches found
No related tags found
No related merge requests found
......@@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String)
val group = AttributeGroup.fromStructField(field)
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.map { attr =>
group.attributes.get.zipWithIndex.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
} else {
attr
attr.withName(c + "_" + i)
}
}
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.fill(numAttrs)(NumericAttribute.defaultAttr)
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
......
......@@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(attrs === expectedAttrs)
}
test("vector attribute generation") {
val formula = new RFormula().setFormula("id ~ vec")
val original = sqlContext.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val model = formula.fit(original)
val result = model.transform(original)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec_0"), Some(1)),
new NumericAttribute(Some("vec_1"), Some(2))))
assert(attrs === expectedAttrs)
}
test("vector attribute generation with unnamed input attrs") {
val formula = new RFormula().setFormula("id ~ vec2")
val base = sqlContext.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
val metadata = new AttributeGroup(
"vec2",
Array[Attribute](
NumericAttribute.defaultAttr,
NumericAttribute.defaultAttr)).toMetadata
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
val result = model.transform(original)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec2_0"), Some(1)),
new NumericAttribute(Some("vec2_1"), Some(2))))
assert(attrs === expectedAttrs)
}
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame(
......
......@@ -111,8 +111,8 @@ class VectorAssemblerSuite
assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
val userSalaryOut = features.getAttr(4)
assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
}
test("read/write") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment