diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 842dcb8c93dc2ca563349fb4f32761a6738c486d..f8e32d60a489a2a4a56d4d354009d288fd29f954 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -19,6 +19,7 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.ByteArrayInputStream;
+import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -36,6 +37,7 @@ import static org.apache.parquet.hadoop.ParquetFileReader.readFooter;
 import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.RecordReader;
@@ -56,6 +58,8 @@ import org.apache.parquet.hadoop.metadata.BlockMetaData;
 import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 import org.apache.parquet.hadoop.util.ConfigurationUtil;
 import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.Types;
+import org.apache.spark.sql.types.StructType;
 
 /**
  * Base class for custom RecordReaaders for Parquet that directly materialize to `T`.
@@ -69,7 +73,7 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
   protected Path file;
   protected MessageType fileSchema;
   protected MessageType requestedSchema;
-  protected ReadSupport<T> readSupport;
+  protected StructType sparkSchema;
 
   /**
    * The total number of rows this RecordReader will eventually read. The sum of the
@@ -125,20 +129,80 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
             + " in range " + split.getStart() + ", " + split.getEnd());
       }
     }
-    MessageType fileSchema = footer.getFileMetaData().getSchema();
+    this.fileSchema = footer.getFileMetaData().getSchema();
     Map<String, String> fileMetadata = footer.getFileMetaData().getKeyValueMetaData();
-    this.readSupport = getReadSupportInstance(
+    ReadSupport<T> readSupport = getReadSupportInstance(
         (Class<? extends ReadSupport<T>>) getReadSupportClass(configuration));
     ReadSupport.ReadContext readContext = readSupport.init(new InitContext(
         taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema));
     this.requestedSchema = readContext.getRequestedSchema();
-    this.fileSchema = fileSchema;
+    this.sparkSchema = new CatalystSchemaConverter(configuration).convert(requestedSchema);
     this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns());
     for (BlockMetaData block : blocks) {
       this.totalRowCount += block.getRowCount();
     }
   }
 
+  /**
+   * Returns the list of files at 'path' recursively. This skips files that are ignored normally
+   * by MapReduce.
+   */
+  public static List<String> listDirectory(File path) throws IOException {
+    List<String> result = new ArrayList<String>();
+    if (path.isDirectory()) {
+      for (File f: path.listFiles()) {
+        result.addAll(listDirectory(f));
+      }
+    } else {
+      char c = path.getName().charAt(0);
+      if (c != '.' && c != '_') {
+        result.add(path.getAbsolutePath());
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Initializes the reader to read the file at `path` with `columns` projected. If columns is
+   * null, all the columns are projected.
+   *
+   * This is exposed for testing to be able to create this reader without the rest of the Hadoop
+   * split machinery. It is not intended for general use and does not support all the
+   * configurations.
+   */
+  protected void initialize(String path, List<String> columns) throws IOException {
+    Configuration config = new Configuration();
+    config.set("spark.sql.parquet.binaryAsString", "false");
+    config.set("spark.sql.parquet.int96AsTimestamp", "false");
+    config.set("spark.sql.parquet.writeLegacyFormat", "false");
+
+    this.file = new Path(path);
+    long length = FileSystem.get(config).getFileStatus(this.file).getLen();
+    ParquetMetadata footer = readFooter(config, file, range(0, length));
+
+    List<BlockMetaData> blocks = footer.getBlocks();
+    this.fileSchema = footer.getFileMetaData().getSchema();
+
+    if (columns == null) {
+      this.requestedSchema = fileSchema;
+    } else {
+      Types.MessageTypeBuilder builder = Types.buildMessage();
+      for (String s: columns) {
+        if (!fileSchema.containsField(s)) {
+          throw new IOException("Can only project existing columns. Unknown field: " + s +
+            " File schema:\n" + fileSchema);
+        }
+        builder.addFields(fileSchema.getType(s));
+      }
+      this.requestedSchema = builder.named("spark_schema");
+    }
+    this.sparkSchema = new CatalystSchemaConverter(config).convert(requestedSchema);
+    this.reader = new ParquetFileReader(config, file, blocks, requestedSchema.getColumns());
+    for (BlockMetaData block : blocks) {
+      this.totalRowCount += block.getRowCount();
+    }
+  }
+
   @Override
   public Void getCurrentKey() throws IOException, InterruptedException {
     return null;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 198bfb6d67aeec7d7a521a8d9327cfe1101443d0..47818c0939f2ae336237b67f0a7625d821d5428b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
 
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.TaskAttemptContext;
 import org.apache.parquet.Preconditions;
@@ -121,14 +122,42 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
   public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
       throws IOException, InterruptedException {
     super.initialize(inputSplit, taskAttemptContext);
+    initializeInternal();
+  }
+
+  /**
+   * Utility API that will read all the data in path. This circumvents the need to create Hadoop
+   * objects to use this class. `columns` can contain the list of columns to project.
+   */
+  @Override
+  public void initialize(String path, List<String> columns) throws IOException {
+    super.initialize(path, columns);
+    initializeInternal();
+  }
+
+  @Override
+  public boolean nextKeyValue() throws IOException, InterruptedException {
+    if (batchIdx >= numBatched) {
+      if (!loadBatch()) return false;
+    }
+    ++batchIdx;
+    return true;
+  }
+
+  @Override
+  public UnsafeRow getCurrentValue() throws IOException, InterruptedException {
+    return rows[batchIdx - 1];
+  }
+
+  @Override
+  public float getProgress() throws IOException, InterruptedException {
+    return (float) rowsReturned / totalRowCount;
+  }
 
+  private void initializeInternal() throws IOException {
     /**
      * Check that the requested schema is supported.
      */
-    if (requestedSchema.getFieldCount() == 0) {
-      // TODO: what does this mean?
-      throw new IOException("Empty request schema not supported.");
-    }
     int numVarLenFields = 0;
     originalTypes = new OriginalType[requestedSchema.getFieldCount()];
     for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
@@ -182,25 +211,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
     }
   }
 
-  @Override
-  public boolean nextKeyValue() throws IOException, InterruptedException {
-    if (batchIdx >= numBatched) {
-      if (!loadBatch()) return false;
-    }
-    ++batchIdx;
-    return true;
-  }
-
-  @Override
-  public UnsafeRow getCurrentValue() throws IOException, InterruptedException {
-    return rows[batchIdx - 1];
-  }
-
-  @Override
-  public float getProgress() throws IOException, InterruptedException {
-    return (float) rowsReturned / totalRowCount;
-  }
-
   /**
    * Decodes a batch of values into `rows`. This function is the hot path.
   */
@@ -253,10 +263,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
         case INT96:
           throw new IOException("Unsupported " + columnReaders[i].descriptor.getType());
       }
-      numBatched = num;
-      batchIdx = 0;
     }
 
+    numBatched = num;
+    batchIdx = 0;
+
     // Update the total row lengths if the schema contained variable length. We did not maintain
     // this as we populated the columns.
     if (containsVarLenFields) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index b0581e8b35510136d5280b9154cefbd4947ac7ec..7f82cce0a122d1be20e2e8e137adf20713027436 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -18,8 +18,11 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.column.{Encoding, ParquetProperties}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.Utils
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
@@ -642,6 +645,77 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       }
     }
   }
+
+  test("UnsafeRowParquetRecordReader - direct path read") {
+    val data = (0 to 10).map(i => (i, ((i + 'a').toChar.toString)))
+    withTempPath { dir =>
+      sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath)
+      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0);
+      {
+        val reader = new UnsafeRowParquetRecordReader
+        try {
+          reader.initialize(file, null)
+          val result = mutable.ArrayBuffer.empty[(Int, String)]
+          while (reader.nextKeyValue()) {
+            val row = reader.getCurrentValue
+            val v = (row.getInt(0), row.getString(1))
+            result += v
+          }
+          assert(data == result)
+        } finally {
+          reader.close()
+        }
+      }
+
+      // Project just one column
+      {
+        val reader = new UnsafeRowParquetRecordReader
+        try {
+          reader.initialize(file, ("_2" :: Nil).asJava)
+          val result = mutable.ArrayBuffer.empty[(String)]
+          while (reader.nextKeyValue()) {
+            val row = reader.getCurrentValue
+            result += row.getString(0)
+          }
+          assert(data.map(_._2) == result)
+        } finally {
+          reader.close()
+        }
+      }
+
+      // Project columns in opposite order
+      {
+        val reader = new UnsafeRowParquetRecordReader
+        try {
+          reader.initialize(file, ("_2" :: "_1" :: Nil).asJava)
+          val result = mutable.ArrayBuffer.empty[(String, Int)]
+          while (reader.nextKeyValue()) {
+            val row = reader.getCurrentValue
+            val v = (row.getString(0), row.getInt(1))
+            result += v
+          }
+          assert(data.map { x => (x._2, x._1) } == result)
+        } finally {
+          reader.close()
+        }
+      }
+
+      // Empty projection
+      {
+        val reader = new UnsafeRowParquetRecordReader
+        try {
+          reader.initialize(file, List[String]().asJava)
+          var result = 0
+          while (reader.nextKeyValue()) {
+            result += 1
+          }
+          assert(result == data.length)
+        } finally {
+          reader.close()
+        }
+      }
+    }
+  }
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
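
For reference, a minimal, hypothetical sketch (not part of the patch) of how the direct path added above can be driven from plain Java, without any Hadoop InputSplit/TaskAttemptContext plumbing: list the part files with SpecificParquetRecordReaderBase.listDirectory, initialize the reader with a column projection, and iterate. The "/tmp/parquet-out" directory and the "_1" column name are placeholders, and the code assumes the projected column is a 32-bit integer.

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase;
import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader;

public class DirectPathReadSketch {
  public static void main(String[] args) throws IOException, InterruptedException {
    // Placeholder: a directory that already contains Parquet output written by Spark.
    List<String> files = SpecificParquetRecordReaderBase.listDirectory(new File("/tmp/parquet-out"));
    String file = files.get(0);

    UnsafeRowParquetRecordReader reader = new UnsafeRowParquetRecordReader();
    try {
      // Project a single column ("_1" is a placeholder name); pass null to read all columns.
      reader.initialize(file, Arrays.asList("_1"));
      while (reader.nextKeyValue()) {
        UnsafeRow row = reader.getCurrentValue();
        // Assumes the projected column is an INT32; use the accessor matching the real type.
        System.out.println(row.getInt(0));
      }
    } finally {
      reader.close();
    }
  }
}

As in the new ParquetIOSuite test, the reader is always closed in a finally block, and the column list controls both which columns are materialized and their order in the returned UnsafeRow.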