From 1221849f91739454b8e495889cba7498ba8beea7 Mon Sep 17 00:00:00 2001
From: Joseph Batchik <josephbatchik@gmail.com>
Date: Wed, 29 Jul 2015 23:35:55 -0700
Subject: [PATCH] [SPARK-8005][SQL] Input file name

Users can now get the name of the file backing the partition currently being read. A thread-local variable in `SqlNewHadoopRDD` is set when each partition is computed, and `SqlNewHadoopRDD` is moved to core so that the catalyst package can reference it.
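
For reference, the mechanism reduces to a per-thread variable with a default value that the RDD sets before handing records downstream and clears when the reader is closed. A minimal sketch of that pattern, outside Spark and with illustrative names only:

```scala
// Minimal sketch of the thread-local pattern this patch uses (names are illustrative):
object CurrentInputFile {
  private val current = new ThreadLocal[String] {
    override protected def initialValue(): String = ""  // default when nothing is being read
  }
  def set(path: String): Unit = current.set(path)       // called when a partition starts
  def unset(): Unit = current.remove()                  // called when the reader is closed
  def get(): String = current.get()                     // read by the expression at eval time
}
```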

This supports:

`df.select(inputFileName())`

and

`sqlContext.sql("select input_file_name() from table")`
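
End to end, the new tests exercise it roughly as follows (a sketch; the Parquet path and table name are placeholders, and `sqlContext` is assumed to be a stable `val`):

```scala
import org.apache.spark.sql.functions._  // brings inputFileName() into scope
import sqlContext.implicits._            // for .toDF

// Write a small multi-partition Parquet dataset, then read it back.
val data = sqlContext.sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet("/tmp/input_file_name_demo")          // placeholder path

val df = sqlContext.read.parquet("/tmp/input_file_name_demo")
df.select(inputFileName()).show()                        // DataFrame API

df.registerTempTable("demo_table")
sqlContext.sql("select input_file_name() from demo_table").show()  // SQL API
```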

Author: Joseph Batchik <josephbatchik@gmail.com>

Closes #7743 from JDrit/input_file_name and squashes the following commits:

abb8609 [Joseph Batchik] fixed failing test and changed the default value to be an empty string
d2f323d [Joseph Batchik] updates per review
102061f [Joseph Batchik] updates per review
75313f5 [Joseph Batchik] small fixes
c7f7b5a [Joseph Batchik] adding input file name to Spark SQL
---
 .../apache/spark/rdd}/SqlNewHadoopRDD.scala   | 34 +++++++++++--
 .../catalyst/analysis/FunctionRegistry.scala  |  3 +-
 .../catalyst/expressions/InputFileName.scala  | 49 +++++++++++++++++++
 .../expressions/SparkPartitionID.scala        |  2 +
 .../expressions/NondeterministicSuite.scala   |  4 ++
 .../org/apache/spark/sql/functions.scala      |  9 ++++
 .../spark/sql/parquet/ParquetRelation.scala   |  3 +-
 .../spark/sql/ColumnExpressionSuite.scala     | 17 ++++++-
 .../scala/org/apache/spark/sql/UDFSuite.scala | 17 ++++++-
 .../org/apache/spark/sql/hive/UDFSuite.scala  |  6 ---
 10 files changed, 128 insertions(+), 16 deletions(-)
 rename {sql/core/src/main/scala/org/apache/spark/sql/execution => core/src/main/scala/org/apache/spark/rdd}/SqlNewHadoopRDD.scala (91%)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
similarity index 91%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala
rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
index 3d75b6a91d..35e44cb59c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
@@ -15,12 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.rdd
 
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.apache.spark.{Partition => SparkPartition, _}
+import scala.reflect.ClassTag
+
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
@@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Partition => SparkPartition, _}
 import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
-import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
-import scala.reflect.ClassTag
 
 private[spark] class SqlNewHadoopPartition(
     rddId: Int,
@@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition(
  * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be
  * folded into core.
  */
-private[sql] class SqlNewHadoopRDD[K, V](
+private[spark] class SqlNewHadoopRDD[K, V](
     @transient sc : SparkContext,
     broadcastedConf: Broadcast[SerializableConfiguration],
     @transient initDriverSideJobFuncOpt: Option[Job => Unit],
@@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V](
       val inputMetrics = context.taskMetrics
         .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
 
+      // Sets the thread local variable for the file's name
+      split.serializableHadoopSplit.value match {
+        case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
+        case _ => SqlNewHadoopRDD.unsetInputFileName()
+      }
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V](
             reader.close()
             reader = null
 
+            SqlNewHadoopRDD.unsetInputFileName()
+
             if (bytesReadCallback.isDefined) {
               inputMetrics.updateBytesRead()
             } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
@@ -250,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V](
 }
 
 private[spark] object SqlNewHadoopRDD {
+
+  /**
+   * The thread-local variable for the name of the current file being read. This is used by
+   * the InputFileName function in Spark SQL.
+   */
+  private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+    override protected def initialValue(): UTF8String = UTF8String.fromString("")
+  }
+
+  def getInputFileName(): UTF8String = inputFileName.get()
+
+  private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+  private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
   /**
    * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
    * the given function rather than the index of the partition.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 372f80d4a8..378df4f57d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -230,7 +230,8 @@ object FunctionRegistry {
     expression[Sha1]("sha"),
     expression[Sha1]("sha1"),
     expression[Sha2]("sha2"),
-    expression[SparkPartitionID]("spark_partition_id")
+    expression[SparkPartitionID]("spark_partition_id"),
+    expression[InputFileName]("input_file_name")
   )
 
   val builtin: FunctionRegistry = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
new file mode 100644
index 0000000000..1e74f71695
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.rdd.SqlNewHadoopRDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Expression that returns the name of the file currently being read in by [[SqlNewHadoopRDD]].
+ */
+case class InputFileName() extends LeafExpression with Nondeterministic {
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = StringType
+
+  override val prettyName = "INPUT_FILE_NAME"
+
+  override protected def initInternal(): Unit = {}
+
+  override protected def evalInternal(input: InternalRow): UTF8String = {
+    SqlNewHadoopRDD.getInputFileName()
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    ev.isNull = "false"
+    s"final ${ctx.javaType(dataType)} ${ev.primitive} = " +
+      "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();"
+  }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 3f6480bbf0..4b1772a2de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -34,6 +34,8 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm
 
   @transient private[this] var partitionId: Int = _
 
+  override val prettyName = "SPARK_PARTITION_ID"
+
   override protected def initInternal(): Unit = {
     partitionId = TaskContext.getPartitionId()
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
index 82894822ab..bf1c930c0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
@@ -27,4 +27,8 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SparkPartitionID") {
     checkEvaluation(SparkPartitionID(), 0)
   }
+
+  test("InputFileName") {
+    checkEvaluation(InputFileName(), "")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4e68a88e7c..a2fece62f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -743,6 +743,15 @@ object functions {
    */
   def sparkPartitionId(): Column = SparkPartitionID()
 
+  /**
+   * The file name of the current Spark task
+   *
+   * Note that this is non-deterministic because it depends on which file is currently being read in.
+   *
+   * @group normal_funcs
+   */
+  def inputFileName(): Column = InputFileName()
+
   /**
    * Computes the square root of the specified float value.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index cc6fa2b886..1a8176d8a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog}
 
 import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD}
 import org.apache.spark.rdd.RDD._
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD}
 import org.apache.spark.sql.execution.datasources.PartitionSpec
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StructType}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1f9f7118c3..5c11024108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -22,13 +22,16 @@ import org.scalatest.Matchers._
 import org.apache.spark.sql.execution.Project
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.SQLTestUtils
 
-class ColumnExpressionSuite extends QueryTest {
+class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
   import org.apache.spark.sql.TestData._
 
   private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
 
+  override def sqlContext(): SQLContext = ctx
+
   test("alias") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     assert(df.select(df("a").as("b")).columns.head === "b")
@@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest {
     )
   }
 
+  test("InputFileName") {
+    withTempPath { dir =>
+      val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id")
+      data.write.parquet(dir.getCanonicalPath)
+      val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName())
+        .head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(inputFileName()).limit(1), Row(""))
+    }
+  }
+
   test("lift alias out of cast") {
     compareExpressions(
       col("1234").as("name").cast("int").expr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index d9c8b380ef..183dc3407b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.test.SQLTestUtils
 
 case class FunctionResult(f1: String, f2: String)
 
-class UDFSuite extends QueryTest {
+class UDFSuite extends QueryTest with SQLTestUtils {
 
   private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
 
+  override def sqlContext(): SQLContext = ctx
+
   test("built-in fixed arity expressions") {
     val df = ctx.emptyDataFrame
     df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
@@ -58,6 +61,18 @@ class UDFSuite extends QueryTest {
     ctx.dropTempTable("tmp_table")
   }
 
+  test("SPARK-8005 input_file_name") {
+    withTempPath { dir =>
+      val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
+      data.write.parquet(dir.getCanonicalPath)
+      ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+      val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+      assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+      ctx.dropTempTable("test_table")
+    }
+  }
+
   test("error reporting for incorrect number of arguments") {
     val df = ctx.emptyDataFrame
     val e = intercept[AnalysisException] {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 37afc2142a..9b3ede43ee 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -34,10 +34,4 @@ class UDFSuite extends QueryTest {
     assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
     assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
   }
-
-  test("SPARK-8003 spark_partition_id") {
-    val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying")
-    ctx.registerDataFrameAsTable(df, "test_table")
-    checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0))
-  }
 }
-- 
GitLab