diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index b02ace786c66cc965179424a363fcf8fa697a288..feb133d44898a0a3b3157100a9577cd0ab39317c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -24,20 +24,16 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.types.DataType +import org.apache.spark.SparkException import org.apache.spark.util.SerializableJobConf private[hive] @@ -46,19 +42,12 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: SparkPlan, overwrite: Boolean, - ifNotExists: Boolean) extends UnaryNode with HiveInspectors { + ifNotExists: Boolean) extends UnaryNode { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @transient private lazy val catalog = sc.catalog - private def newSerializer(tableDesc: TableDesc): Serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) - serializer - } - def output: Seq[Attribute] = Seq.empty private def saveAsHiveFile( @@ -78,44 +67,10 @@ case class InsertIntoHiveTable( conf.value, SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writeToFile _) + sc.sparkContext.runJob(rdd, writerContainer.writeToFile _) writerContainer.commitJob() - // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] - - val fieldOIs = standardOI.getAllStructFieldRefs.asScala - .map(_.getFieldObjectInspector).toArray - val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray - val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} - val outputData = new Array[Any](fieldOIs.length) - - writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - - val proj = FromUnsafeProjection(child.schema) - iterator.foreach { row => - var i = 0 - val safeRow = proj(row) - while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) - i += 1 - } - - writerContainer - .getLocalFileWriter(safeRow, table.schema) - .write(serializer.serialize(outputData, standardOI)) - } - - writerContainer.close() - } } /** @@ -194,11 +149,21 @@ case class InsertIntoHiveTable( val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) - new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + new SparkHiveDynamicPartitionWriterContainer( + jobConf, + fileSinkConf, + dynamicPartColNames, + child.output, + table) } else { - new SparkHiveWriterContainer(jobConf, fileSinkConf) + new SparkHiveWriterContainer( + jobConf, + fileSinkConf, + child.output, + table) } + @transient val outputClass = writerContainer.newSerializer(table.tableDesc).getSerializedClass saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) val outputPath = FileOutputFormat.getOutputPath(jobConf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 22182ba00986f3eb94e54bfb9d11f4165e9e5f33..e9e08dbf8386a9449a1b108bf7a1318c6ad2a4b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import java.text.NumberFormat import java.util.Date -import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.FileUtils @@ -28,14 +28,18 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Serializer +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableJobConf @@ -45,9 +49,13 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging with Serializable { + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + inputSchema: Seq[Attribute], + table: MetastoreRelation) + extends Logging + with HiveInspectors + with Serializable { private val now = new Date() private val tableDesc: TableDesc = fileSinkConf.getTableInfo @@ -93,14 +101,12 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { - writer - } - def close() { // Seems the boolean value passed into close does not matter. - writer.close(false) - commit() + if (writer != null) { + writer.close(false) + commit() + } } def commitJob() { @@ -123,6 +129,13 @@ private[hive] class SparkHiveWriterContainer( SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } + def abortTask(): Unit = { + if (committer != null) { + committer.abortTask(taskContext) + } + logError(s"Task attempt $taskContext aborted.") + } + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { jobID = jobId splitID = splitId @@ -140,6 +153,44 @@ private[hive] class SparkHiveWriterContainer( conf.value.setBoolean("mapred.task.is.map", true) conf.value.setInt("mapred.task.partition", splitID) } + + def newSerializer(tableDesc: TableDesc): Serializer = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + + protected def prepareForWrite() = { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val dataTypes = inputSchema.map(_.dataType) + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) } + val outputData = new Array[Any](fieldOIs.length) + (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) + } + + // this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + i += 1 + } + writer.write(serializer.serialize(outputData, standardOI)) + } + + close() + } } private[hive] object SparkHiveWriterContainer { @@ -163,25 +214,22 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc, - dynamicPartColNames: Array[String]) - extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + dynamicPartColNames: Array[String], + inputSchema: Seq[Attribute], + table: MetastoreRelation) + extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema, table) { import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) - @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ - override protected def initWriters(): Unit = { - // NOTE: This method is executed at the executor side. - // Actual writers are created for each dynamic partition on the fly. - writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + // do nothing } override def close(): Unit = { - writers.values.foreach(_.close(false)) - commit() + // do nothing } override def commitJob(): Unit = { @@ -198,33 +246,89 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: InternalRow, schema: StructType) - : FileSinkOperator.RecordWriter = { - def convertToHiveRawString(col: String, value: Any): String = { - val raw = String.valueOf(value) - schema(col).dataType match { - case DateType => DateTimeUtils.dateToString(raw.toInt) - case _: DecimalType => BigDecimal(raw).toString() - case _ => raw - } + // this function is executed on executor side + override def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + val partitionOutput = inputSchema.takeRight(dynamicPartColNames.length) + val dataOutput = inputSchema.take(fieldOIs.length) + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionOutput, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataOutput, inputSchema) + + val fun: AnyRef = (pathString: String) => FileUtils.escapePathName(pathString, defaultPartName) + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionOutput.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF(fun, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartName), escaped) + val partitionName = Literal(dynamicPartColNames(i) + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionOutput) + + // If anything below fails, we should abort the task. + try { + val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionOutput), + StructType.fromAttributes(dataOutput), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } - def newWriter(): FileSinkOperator.RecordWriter = { + logInfo(s"Sorting complete. Writing out partition files one at a time.") + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: FileSinkOperator.RecordWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close(false) + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) + } + + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (sortedIterator.getValue.isNullAt(i)) { + null + } else { + wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i))) + } + i += 1 + } + currentWriter.write(serializer.serialize(outputData, standardOI)) + } + } finally { + if (currentWriter != null) { + currentWriter.close(false) + } + } + commit() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = { + val partitionPath = getPartitionString(key).getString(0) val newFileSinkDesc = new FileSinkDesc( - fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getDirName + partitionPath, fileSinkConf.getTableInfo, fileSinkConf.getCompressed) newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) @@ -234,7 +338,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // to avoid write to the same file when `spark.speculation=true` val path = FileOutputFormat.getTaskOutputPath( conf.value, - dynamicPartPath.stripPrefix("/") + "/" + getOutputName) + partitionPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, @@ -244,7 +348,5 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( path, Reporter.NULL) } - - writers.getOrElseUpdate(dynamicPartPath, newWriter()) } }