Skip to content
Snippets Groups Projects
Commit 8098f158 authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Cheng Lian
Browse files

[SPARK-14843][ML] Fix encoding error in LibSVMRelation

## What changes were proposed in this pull request?

We use `RowEncoder` in libsvm data source to serialize the label and features read from libsvm files. However, the schema passed in this encoder is not correct. As the result, we can't correctly select `features` column from the DataFrame. We should use full data schema instead of `requiredSchema` to serialize the data read in. Then do projection to select required columns later.

## How was this patch tested?
`LibSVMRelationSuite`.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #12611 from viirya/fix-libsvm.
parent c089c6f4
No related branches found
No related tags found
No related merge requests found
......@@ -202,7 +202,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
}
val converter = RowEncoder(requiredSchema)
val converter = RowEncoder(dataSchema)
val unsafeRowIterator = points.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
......@@ -213,9 +213,12 @@ class DefaultSource extends FileFormat with DataSourceRegister {
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
// Appends partition values
val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
val requiredOutput = fullOutput.filter { a =>
requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
}
val joinedRow = new JoinedRow()
val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
unsafeRowIterator.map { dataRow =>
appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
......
......@@ -23,9 +23,9 @@ import java.nio.charset.StandardCharsets
import com.google.common.io.Files
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.util.Utils
......@@ -104,4 +104,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
df.write.format("libsvm").save(path + "_2")
}
}
test("select features from libsvm relation") {
  // Regression test: selecting the "features" column and materializing its
  // values as Vectors must succeed (previously failed due to the encoder
  // being built from the required schema instead of the full data schema).
  val frame = sqlContext.read.format("libsvm").load(path)
  val featureVectors = frame.select("features").rdd.map { case Row(v: Vector) => v }
  featureVectors.first
}
}
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment