From 1221849f91739454b8e495889cba7498ba8beea7 Mon Sep 17 00:00:00 2001
From: Joseph Batchik <josephbatchik@gmail.com>
Date: Wed, 29 Jul 2015 23:35:55 -0700
Subject: [PATCH] [SPARK-8005][SQL] Input file name

Users can now get the name of the file backing the partition being read
in. A thread-local variable in `SqlNewHadoopRDD` is set when the partition
is computed, and `SqlNewHadoopRDD` is moved to core so that the catalyst
package can reach it.

This supports:

`df.select(inputFileName())`

and

`sqlContext.sql("select input_file_name() from table")`

Author: Joseph Batchik <josephbatchik@gmail.com>

Closes #7743 from JDrit/input_file_name and squashes the following commits:

abb8609 [Joseph Batchik] fixed failing test and changed the default value to be an empty string
d2f323d [Joseph Batchik] updates per review
102061f [Joseph Batchik] updates per review
75313f5 [Joseph Batchik] small fixes
c7f7b5a [Joseph Batchik] addeding input file name to Spark SQL
---
 .../apache/spark/rdd}/SqlNewHadoopRDD.scala   | 34 +++++++++++--
 .../catalyst/analysis/FunctionRegistry.scala  |  3 +-
 .../catalyst/expressions/InputFileName.scala  | 49 +++++++++++++++++++
 .../expressions/SparkPartitionID.scala        |  2 +
 .../expressions/NondeterministicSuite.scala   |  4 ++
 .../org/apache/spark/sql/functions.scala      |  9 ++++
 .../spark/sql/parquet/ParquetRelation.scala   |  3 +-
 .../spark/sql/ColumnExpressionSuite.scala     | 17 ++++++-
 .../scala/org/apache/spark/sql/UDFSuite.scala | 17 ++++++-
 .../org/apache/spark/sql/hive/UDFSuite.scala  |  6 ---
 10 files changed, 128 insertions(+), 16 deletions(-)
 rename {sql/core/src/main/scala/org/apache/spark/sql/execution => core/src/main/scala/org/apache/spark/rdd}/SqlNewHadoopRDD.scala (91%)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
similarity index 91%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala
rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
index 3d75b6a91d..35e44cb59c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
@@ -15,12 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.rdd
 
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.apache.spark.{Partition => SparkPartition, _}
+import scala.reflect.ClassTag
+
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
@@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Partition => SparkPartition, _}
 import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
-import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
-import scala.reflect.ClassTag
 
 private[spark] class SqlNewHadoopPartition(
     rddId: Int,
@@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition(
  * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be
  * folded into core.
  */
-private[sql] class SqlNewHadoopRDD[K, V](
+private[spark] class SqlNewHadoopRDD[K, V](
     @transient sc : SparkContext,
     broadcastedConf: Broadcast[SerializableConfiguration],
     @transient initDriverSideJobFuncOpt: Option[Job => Unit],
@@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V](
       val inputMetrics = context.taskMetrics
         .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
 
+      // Sets the thread local variable for the file's name
+      split.serializableHadoopSplit.value match {
+        case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
+        case _ => SqlNewHadoopRDD.unsetInputFileName()
+      }
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V](
         reader.close()
         reader = null
 
+        SqlNewHadoopRDD.unsetInputFileName()
+
         if (bytesReadCallback.isDefined) {
           inputMetrics.updateBytesRead()
         } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
@@ -250,6 +259,21 @@
 }
 
 private[spark] object SqlNewHadoopRDD {
+
+  /**
+   * The thread variable for the name of the current file being read. This is used by
+   * the InputFileName function in Spark SQL.
+   */
+  private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+    override protected def initialValue(): UTF8String = UTF8String.fromString("")
+  }
+
+  def getInputFileName(): UTF8String = inputFileName.get()
+
+  private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+  private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
   /**
    * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
    * the given function rather than the index of the partition.
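The change above boils down to a per-thread handoff: `compute()` publishes the split's path before handing out rows, and the close callback clears it again. A minimal standalone sketch of that pattern, using a plain `String` in place of `UTF8String`; `InputFileHolder` and `HolderDemo` are illustrative names, not part of the patch:

    // Sketch of the thread-local handoff; InputFileHolder stands in for the
    // SqlNewHadoopRDD companion object above.
    object InputFileHolder {
      private val current = new ThreadLocal[String] {
        override protected def initialValue(): String = ""  // default is "", not null
      }
      def get(): String = current.get()
      def set(file: String): Unit = current.set(file)
      def unset(): Unit = current.remove()
    }

    object HolderDemo extends App {
      // A reader thread sets the name before iterating over a FileSplit...
      InputFileHolder.set("hdfs://nn/warehouse/t/part-r-00000.gz.parquet")
      assert(InputFileHolder.get() == "hdfs://nn/warehouse/t/part-r-00000.gz.parquet")
      // ...and removes it once the split is exhausted, restoring the default.
      InputFileHolder.unset()
      assert(InputFileHolder.get() == "")
    }

Because the value is thread-local, an expression evaluated on the task thread sees the right file, while threads that never touch a FileSplit report the empty-string default.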
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 372f80d4a8..378df4f57d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -230,7 +230,8 @@
     expression[Sha1]("sha"),
     expression[Sha1]("sha1"),
     expression[Sha2]("sha2"),
-    expression[SparkPartitionID]("spark_partition_id")
+    expression[SparkPartitionID]("spark_partition_id"),
+    expression[InputFileName]("input_file_name")
   )
 
   val builtin: FunctionRegistry = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
new file mode 100644
index 0000000000..1e74f71695
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.rdd.SqlNewHadoopRDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]]
+ */
+case class InputFileName() extends LeafExpression with Nondeterministic {
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = StringType
+
+  override val prettyName = "INPUT_FILE_NAME"
+
+  override protected def initInternal(): Unit = {}
+
+  override protected def evalInternal(input: InternalRow): UTF8String = {
+    SqlNewHadoopRDD.getInputFileName()
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    ev.isNull = "false"
+    s"final ${ctx.javaType(dataType)} ${ev.primitive} = " +
+      "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();"
+  }
+
+}
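With the registry entry and the expression in place, `input_file_name()` becomes callable from SQL. A rough end-to-end sketch against a local context; the path and table name are illustrative, and a build containing this patch is assumed:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ifn-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Write a small Parquet table, then read it back through SqlNewHadoopRDD.
    sc.parallelize(0 to 10, 2).toDF("id").write.parquet("/tmp/ifn_demo")
    sqlContext.read.parquet("/tmp/ifn_demo").registerTempTable("demo")

    // Each row reports the part-file it was read from.
    sqlContext.sql("select input_file_name(), id from demo").collect().foreach(println)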
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 3f6480bbf0..4b1772a2de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -34,6 +34,8 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm
 
   @transient private[this] var partitionId: Int = _
 
+  override val prettyName = "SPARK_PARTITION_ID"
+
   override protected def initInternal(): Unit = {
     partitionId = TaskContext.getPartitionId()
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
index 82894822ab..bf1c930c0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
@@ -27,4 +27,8 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SparkPartitionID") {
     checkEvaluation(SparkPartitionID(), 0)
   }
+
+  test("InputFileName") {
+    checkEvaluation(InputFileName(), "")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4e68a88e7c..a2fece62f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -743,6 +743,15 @@
    */
   def sparkPartitionId(): Column = SparkPartitionID()
 
+  /**
+   * The file name of the current Spark task.
+   *
+   * Note that this is non-deterministic because it depends on what is currently being read in.
+   *
+   * @group normal_funcs
+   */
+  def inputFileName(): Column = InputFileName()
+
   /**
    * Computes the square root of the specified float value.
    *
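The new `functions.inputFileName()` column is the DataFrame-side twin of the SQL function. Continuing the illustrative session from the earlier sketch:

    import org.apache.spark.sql.functions.inputFileName

    // One distinct value per part-file; the demo table was written with two partitions.
    sqlContext.read.parquet("/tmp/ifn_demo")
      .select(inputFileName().as("file"))
      .distinct()
      .collect()
      .foreach(println)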
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index cc6fa2b886..1a8176d8a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog}
 
 import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD}
 import org.apache.spark.rdd.RDD._
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD}
 import org.apache.spark.sql.execution.datasources.PartitionSpec
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StructType}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1f9f7118c3..5c11024108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -22,13 +22,16 @@ import org.scalatest.Matchers._
 import org.apache.spark.sql.execution.Project
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.SQLTestUtils
 
-class ColumnExpressionSuite extends QueryTest {
+class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
   import org.apache.spark.sql.TestData._
 
   private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
 
+  override def sqlContext(): SQLContext = ctx
+
   test("alias") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     assert(df.select(df("a").as("b")).columns.head === "b")
@@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest {
     )
   }
 
+  test("InputFileName") {
+    withTempPath { dir =>
+      val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id")
+      data.write.parquet(dir.getCanonicalPath)
+      val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName())
+        .head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(inputFileName()).limit(1), Row(""))
+    }
+  }
+
   test("lift alias out of cast") {
     compareExpressions(
       col("1234").as("name").cast("int").expr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index d9c8b380ef..183dc3407b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.test.SQLTestUtils
 
 case class FunctionResult(f1: String, f2: String)
 
-class UDFSuite extends QueryTest {
+class UDFSuite extends QueryTest with SQLTestUtils {
 
   private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
 
+  override def sqlContext(): SQLContext = ctx
+
   test("built-in fixed arity expressions") {
     val df = ctx.emptyDataFrame
     df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
@@ -58,6 +61,18 @@ class UDFSuite extends QueryTest {
     ctx.dropTempTable("tmp_table")
   }
 
+  test("SPARK-8005 input_file_name") {
+    withTempPath { dir =>
+      val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
+      data.write.parquet(dir.getCanonicalPath)
+      ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+      val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+      assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+      ctx.dropTempTable("test_table")
+    }
+  }
+
   test("error reporting for incorrect number of arguments") {
     val df = ctx.emptyDataFrame
     val e = intercept[AnalysisException] {
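Besides the happy path, the ColumnExpressionSuite case above pins down the fallback: rows that never pass through SqlNewHadoopRDD evaluate to the empty string rather than null. Continuing the same illustrative session:

    // A DataFrame built from a local collection never sets the thread local,
    // so inputFileName() returns the empty-string default for every row.
    val local = sc.parallelize(Seq(1, 2, 3)).toDF("x")
    local.select(inputFileName()).collect().foreach(println)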
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 37afc2142a..9b3ede43ee 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -34,10 +34,4 @@ class UDFSuite extends QueryTest {
     assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
     assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
   }
-
-  test("SPARK-8003 spark_partition_id") {
-    val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying")
-    ctx.registerDataFrameAsTable(df, "test_table")
-    checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0))
-  }
 }
-- 
GitLab