From 09a00510c4759ff87abb0b2fdf1630ddf36ca12c Mon Sep 17 00:00:00 2001
From: Lianhui Wang <lianhuiwang09@gmail.com>
Date: Thu, 19 May 2016 23:03:59 -0700
Subject: [PATCH] [SPARK-15335][SQL] Implement TRUNCATE TABLE Command

## What changes were proposed in this pull request?

Like TRUNCATE TABLE Command in Hive, TRUNCATE TABLE is also supported by Hive. See the link: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
Below is the related Hive JIRA: https://issues.apache.org/jira/browse/HIVE-446
This PR is to implement such a command for truncate table excluded column truncation(HIVE-4005).

## How was this patch tested?
Added a test case.

Author: Lianhui Wang <lianhuiwang09@gmail.com>

Closes #13170 from lianhuiwang/truncate.
---
 .../spark/sql/execution/SparkSqlParser.scala  | 19 +++++
 .../spark/sql/execution/command/tables.scala  | 53 +++++++++++++
 .../sql/hive/execution/HiveCommandSuite.scala | 79 +++++++++++++++++++
 3 files changed, 151 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 3045f3af36..8af6d07719 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -350,6 +350,25 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
     )
   }
 
+  /**
+   * Create a [[TruncateTable]] command.
+   *
+   * For example:
+   * {{{
+   *   TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)]
+   *   [COLUMNS (col1, col2)]
+   * }}}
+   */
+  override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
+    if (ctx.identifierList != null) {
+      throw operationNotAllowed("TRUNCATE TABLE ... COLUMNS", ctx)
+    }
+    TruncateTable(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)
+    )
+  }
+
   /**
    * Convert a table property list into a key-value map.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index a347274537..d13492e550 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -22,6 +22,9 @@ import java.net.URI
 import java.util.Date
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -270,6 +273,56 @@ case class LoadData(
   }
 }
 
+/**
+ * A command to truncate table.
+ *
+ * The syntax of this command is:
+ * {{{
+ *  TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)]
+ * }}}
+ */
+case class TruncateTable(
+    tableName: TableIdentifier,
+    partitionSpec: Option[TablePartitionSpec]) extends RunnableCommand {
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    val catalog = sparkSession.sessionState.catalog
+    if (!catalog.tableExists(tableName)) {
+      logError(s"table '$tableName' in TRUNCATE TABLE does not exist.")
+    } else if (catalog.isTemporaryTable(tableName)) {
+      logError(s"table '$tableName' in TRUNCATE TABLE is a temporary table.")
+    } else {
+      val locations = if (partitionSpec.isDefined) {
+        catalog.listPartitions(tableName, partitionSpec).map(_.storage.locationUri)
+      } else {
+        val table = catalog.getTableMetadata(tableName)
+        if (table.partitionColumnNames.nonEmpty) {
+          catalog.listPartitions(tableName).map(_.storage.locationUri)
+        } else {
+          Seq(table.storage.locationUri)
+        }
+      }
+      val hadoopConf = sparkSession.sessionState.newHadoopConf()
+      locations.foreach { location =>
+        if (location.isDefined) {
+          val path = new Path(location.get)
+          try {
+            val fs = path.getFileSystem(hadoopConf)
+            fs.delete(path, true)
+            fs.mkdirs(path)
+          } catch {
+            case NonFatal(e) =>
+              throw new AnalysisException(
+                s"Failed to truncate table '$tableName' when removing data of the path: $path " +
+                  s"because of ${e.toString}")
+          }
+        }
+      }
+    }
+    Seq.empty[Row]
+  }
+}
+
 /**
  * Command that looks like
  * {{{
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
index 8225bd69c1..df62ba08b8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -269,6 +270,84 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     }
   }
 
+  test("Truncate Table") {
+    withTable("non_part_table", "part_table") {
+      sql(
+        """
+          |CREATE TABLE non_part_table (employeeID INT, employeeName STRING)
+          |ROW FORMAT DELIMITED
+          |FIELDS TERMINATED BY '|'
+          |LINES TERMINATED BY '\n'
+        """.stripMargin)
+
+      val testData = hiveContext.getHiveFile("data/files/employee.dat").getCanonicalPath
+
+      sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE non_part_table""")
+      checkAnswer(
+        sql("SELECT * FROM non_part_table WHERE employeeID = 16"),
+        Row(16, "john") :: Nil)
+
+      val testResults = sql("SELECT * FROM non_part_table").collect()
+
+      intercept[ParseException] {
+        sql("TRUNCATE TABLE non_part_table COLUMNS (employeeID)")
+      }
+
+      sql("TRUNCATE TABLE non_part_table")
+      checkAnswer(sql("SELECT * FROM non_part_table"), Seq.empty[Row])
+
+      sql(
+        """
+          |CREATE TABLE part_table (employeeID INT, employeeName STRING)
+          |PARTITIONED BY (c STRING, d STRING)
+          |ROW FORMAT DELIMITED
+          |FIELDS TERMINATED BY '|'
+          |LINES TERMINATED BY '\n'
+        """.stripMargin)
+
+      sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="1", d="1")""")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '1'"),
+        testResults)
+
+      sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="1", d="2")""")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '2'"),
+        testResults)
+
+      sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="2", d="2")""")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '2' AND d = '2'"),
+        testResults)
+
+      intercept[ParseException] {
+        sql("TRUNCATE TABLE part_table PARTITION(c='1', d='1') COLUMNS (employeeID)")
+      }
+
+      sql("TRUNCATE TABLE part_table PARTITION(c='1', d='1')")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '1'"),
+        Seq.empty[Row])
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '2'"),
+        testResults)
+
+      intercept[ParseException] {
+        sql("TRUNCATE TABLE part_table PARTITION(c='1') COLUMNS (employeeID)")
+      }
+
+      sql("TRUNCATE TABLE part_table PARTITION(c='1')")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1'"),
+        Seq.empty[Row])
+
+      sql("TRUNCATE TABLE part_table")
+      checkAnswer(
+        sql("SELECT employeeID, employeeName FROM part_table"),
+        Seq.empty[Row])
+    }
+  }
+
   test("show columns") {
     checkAnswer(
       sql("SHOW COLUMNS IN parquet_tab3"),
-- 
GitLab