diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
index 027b5bbfab8d686042245af28f6473355b5451e1..c14feea91ed7d7608b2bee1950296b9517827da1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.io.IOException
+import java.io.{InputStream, IOException, OutputStream}
 import java.nio.charset.StandardCharsets.UTF_8
 
+import scala.io.{Source => IOSource}
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.fs.{Path, PathFilter}
@@ -93,20 +94,25 @@ abstract class CompactibleFileStreamLog[T: ClassTag](
     }
   }
 
-  override def serialize(logData: Array[T]): Array[Byte] = {
-    (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8)
+  override def serialize(logData: Array[T], out: OutputStream): Unit = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    out.write(metadataLogVersion.getBytes(UTF_8))
+    logData.foreach { data =>
+      out.write('\n')
+      out.write(serializeData(data).getBytes(UTF_8))
+    }
   }
 
-  override def deserialize(bytes: Array[Byte]): Array[T] = {
-    val lines = new String(bytes, UTF_8).split("\n")
-    if (lines.length == 0) {
+  override def deserialize(in: InputStream): Array[T] = {
+    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
+    if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file")
     }
-    val version = lines(0)
+    val version = lines.next()
     if (version != metadataLogVersion) {
       throw new IllegalStateException(s"Unknown log version: ${version}")
     }
-    lines.slice(1, lines.length).map(deserializeData)
+    lines.map(deserializeData).toArray
   }
 
   override def add(batchId: Long, logs: Array[T]): Boolean = {
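
The stream-based contract above keeps the same line-oriented layout as the old byte-array version: the metadata log version on the first line, then one serialized entry per line. A standalone round-trip sketch of that layout follows (illustration only, not part of the patch; version, serializeData and deserializeData are placeholders for metadataLogVersion and the subclass hooks):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}
import java.nio.charset.StandardCharsets.UTF_8

import scala.io.{Source => IOSource}

object LineFormatSketch {
  val version = "v1"                          // placeholder for metadataLogVersion
  def serializeData(d: String): String = d    // placeholder subclass hooks
  def deserializeData(s: String): String = s

  // Write the version header, then one entry per line, directly to the stream.
  def serialize(logData: Array[String], out: OutputStream): Unit = {
    out.write(version.getBytes(UTF_8))
    logData.foreach { data =>
      out.write('\n')
      out.write(serializeData(data).getBytes(UTF_8))
    }
  }

  // Read the version header, then map the remaining lines back to entries.
  def deserialize(in: InputStream): Array[String] = {
    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
    require(lines.hasNext, "Incomplete log file")
    require(lines.next() == version, "Unknown log version")
    lines.map(deserializeData).toArray
  }

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    serialize(Array("""{"path":"/a/b/x"}""", """{"path":"/a/b/y"}"""), out)
    val entries = deserialize(new ByteArrayInputStream(out.toByteArray))
    assert(entries.length == 2)
  }
}

Note that getLines() is an iterator over the still-open stream, so deserialize only works while the caller keeps the underlying input stream open; that is why the real methods carry the "closed in the caller" comment.
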
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 39a0f3341389c06e54cd13bce4ec23496f931ce7..c7235320fd6bde62d32616d3abfc47be7149c897 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.io.{FileNotFoundException, IOException}
-import java.nio.ByteBuffer
+import java.io.{FileNotFoundException, InputStream, IOException, OutputStream}
 import java.util.{ConcurrentModificationException, EnumSet, UUID}
 
 import scala.reflect.ClassTag
@@ -29,7 +28,6 @@ import org.apache.hadoop.fs._
 import org.apache.hadoop.fs.permission.FsPermission
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.UninterruptibleThread
@@ -88,12 +86,16 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
     }
   }
 
-  protected def serialize(metadata: T): Array[Byte] = {
-    JavaUtils.bufferToArray(serializer.serialize(metadata))
+  protected def serialize(metadata: T, out: OutputStream): Unit = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    val outStream = serializer.serializeStream(out)
+    outStream.writeObject(metadata)
   }
 
-  protected def deserialize(bytes: Array[Byte]): T = {
-    serializer.deserialize[T](ByteBuffer.wrap(bytes))
+  protected def deserialize(in: InputStream): T = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    val inStream = serializer.deserializeStream(in)
+    inStream.readObject[T]()
   }
 
   /**
@@ -114,7 +116,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
       // Only write metadata when the batch has not yet been written
       Thread.currentThread match {
         case ut: UninterruptibleThread =>
-          ut.runUninterruptibly { writeBatch(batchId, serialize(metadata)) }
+          ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) }
         case _ =>
           throw new IllegalStateException(
             "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread")
@@ -129,7 +131,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
   * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a
   * valid behavior, we still need to prevent it from destroying the files.
   */
-  private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = {
+  private def writeBatch(batchId: Long, metadata: T, writer: (T, OutputStream) => Unit): Unit = {
     // Use nextId to create a temp file
     var nextId = 0
     while (true) {
@@ -137,9 +139,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
       try {
        val output = fileManager.create(tempPath)
         try {
-          output.write(bytes)
+          writer(metadata, output)
         } finally {
-          output.close()
+          IOUtils.closeQuietly(output)
         }
         try {
           // Try to commit the batch
@@ -193,10 +195,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
     if (fileManager.exists(batchMetadataFile)) {
       val input = fileManager.open(batchMetadataFile)
       try {
-        val bytes = IOUtils.toByteArray(input)
-        Some(deserialize(bytes))
+        Some(deserialize(input))
       } finally {
-        input.close()
+        IOUtils.closeQuietly(input)
      }
     } else {
       logDebug(s"Unable to find batch $batchMetadataFile")
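
One way to read the writeBatch change above: callers no longer hand over a pre-built Array[Byte]; they pass the metadata together with a (T, OutputStream) => Unit writer, so serialization happens directly against the temp-file stream instead of being materialized in memory first. A minimal standalone sketch of that callback pattern follows (illustration only; ByteArrayOutputStream stands in for the stream returned by fileManager.create(tempPath), and writeBatchLike is a made-up name):

import java.io.{ByteArrayOutputStream, OutputStream}

object WriterCallbackSketch {
  // The file-handling code owns the stream lifecycle; the caller only
  // supplies how to turn T into bytes on that stream.
  def writeBatchLike[T](metadata: T, writer: (T, OutputStream) => Unit): Array[Byte] = {
    val output = new ByteArrayOutputStream()  // stands in for the temp-file stream
    try {
      writer(metadata, output)
    } finally {
      // The real code uses commons-io IOUtils.closeQuietly here so a failed
      // close does not mask an exception thrown by the writer.
      output.close()
    }
    output.toByteArray
  }

  def main(args: Array[String]): Unit = {
    val bytes = writeBatchLike[String]("batch-0", (s, out) => out.write(s.getBytes("UTF-8")))
    assert(new String(bytes, "UTF-8") == "batch-0")
  }
}
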
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index 41a8cc2400dff139f73d6b1feffa6958ad7f78da..e1bc674a280713891cb017008ee0a4f957b6e277 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.nio.charset.StandardCharsets.UTF_8
 
 import org.apache.spark.SparkFunSuite
@@ -133,9 +134,12 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
           |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"}
           |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin
       // scalastyle:on
-      assert(expected === new String(sinkLog.serialize(logs), UTF_8))
-
-      assert(VERSION === new String(sinkLog.serialize(Array()), UTF_8))
+      val baos = new ByteArrayOutputStream()
+      sinkLog.serialize(logs, baos)
+      assert(expected === baos.toString(UTF_8.name()))
+      baos.reset()
+      sinkLog.serialize(Array(), baos)
+      assert(VERSION === baos.toString(UTF_8.name()))
     }
   }
 
@@ -174,9 +178,9 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
           blockSize = 30000L,
           action = FileStreamSinkLog.ADD_ACTION))
 
-      assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8)))
+      assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8))))
 
-      assert(Nil === sinkLog.deserialize(VERSION.getBytes(UTF_8)))
+      assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8))))
     }
   }
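
The same in-memory-stream technique used in these tests also works for the object-based HDFSMetadataLog contract. A hedged, JDK-only sketch follows (java.io.ObjectOutputStream/ObjectInputStream stand in for Spark's JavaSerializer serializeStream/deserializeStream, and Batch is a made-up payload type, not something from the patch):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object ObjectStreamRoundTripSketch {
  // Hypothetical metadata payload, purely for illustration.
  case class Batch(id: Long, files: Seq[String])

  def main(args: Array[String]): Unit = {
    val original = Batch(42L, Seq("/a/b/x", "/a/b/y"))

    // Write the object straight to the output stream, as the new
    // serialize(metadata, out) contract does, with no intermediate byte array.
    val baos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(baos)
    out.writeObject(original)
    out.flush()

    // Read it back directly from an input stream, mirroring deserialize(in).
    val in = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))
    val restored = in.readObject().asInstanceOf[Batch]
    assert(restored == original)
  }
}

As in the patch, the streams here are closed by whoever created them; serialize and deserialize themselves never own the stream lifecycle.
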