diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index e8b0dd61f34b17d924bc74911ba257a806e137d5..dc2a6f527558cde93d2b3abbaf163457ed39904f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -202,7 +202,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
           LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
         }
 
-      val converter = RowEncoder(requiredSchema)
+      val converter = RowEncoder(dataSchema)
 
       val unsafeRowIterator = points.map { pt =>
         val features = if (sparse) pt.features.toSparse else pt.features.toDense
@@ -213,9 +213,12 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
 
       // Appends partition values
-      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
+      val requiredOutput = fullOutput.filter { a =>
+        requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
+      }
       val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
 
       unsafeRowIterator.map { dataRow =>
         appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 0bd14978b2bbfe7dfe456cf96190afba2d64d1e0..e52fbd74a7b41c5aa37ce32113699ad5a55030b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -23,9 +23,9 @@ import java.nio.charset.StandardCharsets
 
 import com.google.common.io.Files
 
 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.util.Utils
 
@@ -104,4 +104,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       df.write.format("libsvm").save(path + "_2")
     }
   }
+
+  test("select features from libsvm relation") {
+    val df = sqlContext.read.format("libsvm").load(path)
+    df.select("features").rdd.map { case Row(d: Vector) => d }.first
+  }
 }
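
Why this change works: buildReader previously built its RowEncoder from requiredSchema, so when Catalyst pushed down a column-pruned schema (for example, only "features"), the encoder no longer matched the two-field (label, features) rows the reader actually produces. The patch encodes against the full dataSchema and then prunes to the required columns with a generated UnsafeProjection. Below is a minimal sketch of the query shape the new test exercises, assuming a SQLContext bound to sqlContext and an existing LIBSVM file at path; both names are stand-ins for test fixtures, not part of the patch.

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.Row

    // Load the libsvm source and prune the projection down to a single column.
    // Before this patch, the encoder built from the pruned schema disagreed
    // with the rows the reader emits, so this query failed at runtime.
    val df = sqlContext.read.format("libsvm").load(path)
    val firstFeatures = df.select("features").rdd.map { case Row(v: Vector) => v }.first()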