Commit b1ef6a60 authored by Joseph K. Bradley

[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column

Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.
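To illustrate the behavioral change, a minimal sketch (the DataFrame `data`, the fitted `model`, the column names, and the "source" key are all hypothetical, not from this commit): a non-ML metadata key attached to the input column used to be copied to the output column, and now is not.

    import org.apache.spark.sql.types.MetadataBuilder

    // Attach a non-ML metadata key to the input column (key and names hypothetical).
    val meta = new MetadataBuilder().putString("source", "sensor").build()
    val withMeta = data.select(data("features").as("features", meta))

    // After this patch, the output column carries only ML attribute metadata:
    val indexed = model.transform(withMeta)
    assert(!indexed.schema("indexed").metadata.contains("source"))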

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.
parent f8cbb0a4
@@ -233,6 +233,7 @@ private object VectorIndexer {
  * - Continuous features (columns) are left unchanged.
  * This also appends metadata to the output column, marking features as Numeric (continuous),
  * Nominal (categorical), or Binary (either continuous or categorical).
+ * Non-ML metadata is not carried over from the input to the output column.
  *
  * This maintains vector sparsity.
  *
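For context, a minimal usage sketch of the indexer this doc comment describes (the column names, the `data` DataFrame, and the maxCategories value are assumptions, not part of this commit):

    import org.apache.spark.ml.feature.VectorIndexer

    val indexer = new VectorIndexer()
      .setInputCol("features")  // assumed input column name
      .setOutputCol("indexed")  // assumed output column name
      .setMaxCategories(10)     // features with <= 10 distinct values become categorical
    val model = indexer.fit(data)           // `data`: assumed DataFrame with a vector column
    val indexedData = model.transform(data) // continuous features pass through unchanged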
@@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
   // TODO: Check more carefully about whether this whole class will be included in a closure.

   /** Per-vector transform function */
   private val transformFunc: Vector => Vector = {
-    val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
+    val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
     val localVectorMap = categoryMaps
-    val f: Vector => Vector = {
-      case dv: DenseVector =>
-        val tmpv = dv.copy
-        localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
-          tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
-        }
-        tmpv
-      case sv: SparseVector =>
-        // We use the fact that categorical value 0 is always mapped to index 0.
-        val tmpv = sv.copy
-        var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
-        var k = 0 // index into non-zero elements of sparse vector
-        while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
-          val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
-          if (featureIndex < tmpv.indices(k)) {
-            catFeatureIdx += 1
-          } else if (featureIndex > tmpv.indices(k)) {
-            k += 1
-          } else {
-            tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
-            catFeatureIdx += 1
-            k += 1
-          }
-        }
-        tmpv
-    }
+    val localNumFeatures = numFeatures
+    val f: Vector => Vector = { (v: Vector) =>
+      assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
+        s" $numFeatures but found length ${v.size}")
+      v match {
+        case dv: DenseVector =>
+          val tmpv = dv.copy
+          localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
+            tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+          }
+          tmpv
+        case sv: SparseVector =>
+          // We use the fact that categorical value 0 is always mapped to index 0.
+          val tmpv = sv.copy
+          var catFeatureIdx = 0 // index into sortedCatFeatureIndices
+          var k = 0 // index into non-zero elements of sparse vector
+          while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
+            val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
+            if (featureIndex < tmpv.indices(k)) {
+              catFeatureIdx += 1
+            } else if (featureIndex > tmpv.indices(k)) {
+              k += 1
+            } else {
+              tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+              catFeatureIdx += 1
+              k += 1
+            }
+          }
+          tmpv
+      }
+    }
     f
   }
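The sparse case above walks two sorted index arrays in a single forward pass. A standalone sketch of that merge-intersection technique (the function and parameter names are illustrative, not part of the patch):

    // Remap categorical values stored in a sparse vector's parallel arrays.
    // Both `indices` and `catIndices` are sorted, so one pass suffices: advance
    // whichever cursor points at the smaller index; on a match, remap the value.
    def remapSparse(
        indices: Array[Int],     // sorted non-zero positions
        values: Array[Double],   // corresponding values, mutated in place
        catIndices: Array[Int],  // sorted categorical feature positions
        maps: Map[Int, Map[Double, Int]]): Unit = {
      var i = 0 // cursor into catIndices
      var k = 0 // cursor into indices
      while (i < catIndices.length && k < indices.length) {
        if (catIndices(i) < indices(k)) {
          i += 1
        } else if (catIndices(i) > indices(k)) {
          k += 1
        } else {
          values(k) = maps(catIndices(i))(values(k)).toDouble
          i += 1
          k += 1
        }
      }
    }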
@@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
     val map = extractParamMap(paramMap)
     val newField = prepOutputField(dataset.schema, map)
     val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
-    // For now, just check the first row of inputCol for vector length.
-    val firstRow = dataset.select(map(inputCol)).take(1)
-    if (firstRow.length != 0) {
-      val actualNumFeatures = firstRow(0).getAs[Vector](0).size
-      require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
-        s" $numFeatures but found length $actualNumFeatures")
-    }
     dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
   }
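The removed block validated vector length eagerly on the driver, but only against the first row. The assert added to transformFunc above runs for every row inside the UDF instead; since the UDF executes lazily, a length mismatch now surfaces as a SparkException at action time rather than an IllegalArgumentException at transform time. Roughly (dataset name assumed):

    val out = model.transform(mismatchedData) // lazy: no error raised here
    out.collect()                             // per-row assert fires during execution,
                                              // wrapped in a SparkException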
@@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
       s"VectorIndexerModel requires output column parameter: $outputCol")
     SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
+    // If the input metadata specifies numFeatures, compare with expected numFeatures.
     val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
     val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
       Some(origAttrGroup.attributes.get.length)
@@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
    * Prepare the output column field, including per-feature metadata.
    * @param schema Input schema
    * @param map Parameter map (with this class' embedded parameter map folded in)
-   * @return Output column field
+   * @return Output column field. This field does not contain non-ML metadata.
    */
   private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
     val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
@@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
       partialFeatureAttributes
     }
     val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
-    newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
+    newAttributeGroup.toStructField()
   }
 }
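This one-argument-to-zero-argument change is the core of the fix: toStructField(existing) merges the ML attributes into caller-supplied metadata, so any non-ML keys on the input column survived, while toStructField() starts from empty metadata. A small sketch (the attribute, column name, and "origin" key are hypothetical):

    import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
    import org.apache.spark.sql.types.MetadataBuilder

    val attrs: Array[Attribute] = Array(NumericAttribute.defaultAttr.withIndex(0))
    val group = new AttributeGroup("indexed", attrs)
    val existing = new MetadataBuilder().putString("origin", "sensor-42").build()

    val carried = group.toStructField(existing) // old behavior: "origin" key kept
    val mlOnly = group.toStructField()          // new behavior: ML attributes only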
@@ -23,7 +23,6 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.util.TestingUtils
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
@@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
     val model = vectorIndexer.fit(densePoints1) // vectors of length 3
     model.transform(densePoints1) // should work
     model.transform(sparsePoints1) // should work
-    intercept[IllegalArgumentException] {
-      model.transform(densePoints2)
+    intercept[SparkException] {
+      model.transform(densePoints2).collect()
       println("Did not throw error when fit, transform were called on vectors of different lengths")
     }
     intercept[SparkException] {
@@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
         // TODO: Once input features marked as categorical are handled correctly, check that here.
       }
     }
-    // Check that non-ML metadata are preserved.
-    TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
   }
 }
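With the preservation check gone, the inverse property could be asserted instead. A sketch reusing the suite's names (densePoints1WithMeta, the "indexed" output column) and the fake key from the deleted utility below, assuming that key has been attached to the input column as the utility used to do:

    // After this change, extra input-column metadata must NOT reach the output column.
    val transMetadata = model.transform(densePoints1WithMeta).schema("indexed").metadata
    assert(!transMetadata.contains("__testPreserveMetadata__fake_key"))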
Deleted file: org/apache/spark/ml/util/TestingUtils.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util

import org.scalatest.FunSuite

import org.apache.spark.ml.Transformer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.MetadataBuilder

private[ml] object TestingUtils extends FunSuite {

  /**
   * Test whether unrelated metadata are preserved for this transformer.
   * This attaches extra metadata to a column, transforms the column, and checks to ensure the
   * extra metadata have not changed.
   * @param data Input dataset
   * @param transformer Transformer to test
   * @param inputCol Unique input column for Transformer. This must be the ONLY input column.
   * @param outputCol Output column to test for metadata presence.
   */
  def testPreserveMetadata(
      data: DataFrame,
      transformer: Transformer,
      inputCol: String,
      outputCol: String): Unit = {
    // Create some fake metadata
    val origMetadata = data.schema(inputCol).metadata
    val metaKey = "__testPreserveMetadata__fake_key"
    val metaValue = 12345
    assert(!origMetadata.contains(metaKey),
      s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
    val newMetadata =
      new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
    // Add metadata to the inputCol
    val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
    // Transform, and ensure extra metadata was not affected
    val transformed = transformer.transform(withMetadata)
    val transMetadata = transformed.schema(outputCol).metadata
    assert(transMetadata.contains(metaKey),
      "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
    assert(transMetadata.getLong(metaKey) === metaValue,
      "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
        s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
  }
}