Skip to content
Snippets Groups Projects
Commit 8a634e9b authored by Nick Pritchard, committed by Xiangrui Meng
Browse files

[SPARK-10573] [ML] IndexToString output schema should be StringType

Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata.

Author: Nick Pritchard <nicholas.pritchard@falkonry.com>

Closes #8751 from pnpritchard/SPARK-10573.
parent ce6f3f16
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
/**
......@@ -229,8 +229,7 @@ class IndexToString private[ml] (
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
val outputFields = inputFields :+ StructField($(outputCol), StringType)
StructType(outputFields)
}
......
......@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
......@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(a === b)
}
}
// Regression test for SPARK-10573: given a schema with a DoubleType input column,
// transformSchema must declare the appended output column as StringType.
test("IndexToString.transformSchema (SPARK-10573)") {
  val inputSchema = StructType(Seq(StructField("input", DoubleType)))
  val transformer = new IndexToString()
    .setInputCol("input")
    .setOutputCol("output")
  val resultSchema = transformer.transformSchema(inputSchema)
  assert(resultSchema("output").dataType === StringType)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment