diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 9f8bb5d38fd6c9b0d3cba446bb1777a4c716acc0..27d32b5dca4315c8a446ae889dc6f464c79210da 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -4,3 +4,4 @@ org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
+org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index bef56160f6bb8eaeae004bfe9b8a60837e5d269c..9886ad0b410899f4846e6759b3700a52f4fc5993 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -128,4 +128,6 @@ class FileStreamSource(
   override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1)
 
   override def toString: String = s"FileStreamSource[$qualifiedBasePath]"
+
+  override def stop() {}
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index 14450c2e2fd14d2ec3fd6712b08117e094d577f9..971147840d2fd6da38a34ac886001d7276710917 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -39,4 +39,11 @@ trait Source {
    * same data for a particular `start` and `end` pair.
    */
   def getBatch(start: Option[Offset], end: Offset): DataFrame
+
+  /**
+   * Stop this source and free any resources it has allocated.
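+   * [[StreamExecution]] invokes this when a query is stopped, and only after its micro-batch
+   * thread has been interrupted and joined, so no further `getBatch` calls will follow.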
+   */
+  def stop(): Unit
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 4aefd39b3646c334de2dd3e826531ea93073d333..bb42a11759b72eef9d74e867dcfdf4e353c27adc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -399,6 +399,7 @@ class StreamExecution(
       microBatchThread.interrupt()
       microBatchThread.join()
     }
+    uniqueSources.foreach(_.stop())
     logInfo(s"Query $name was stopped")
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 77fd043ef721971d94e16f7d8bffcdffcbbd9f27..e37f0c77795c331c72c0ec88aa8499afde0ebea2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -110,6 +110,8 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
       sys.error("No data selected!")
     }
   }
+
+  override def stop() {}
 }
 
 /**
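
With `stop()` added to the `Source` trait, every implementation now has a shutdown hook, and `StreamExecution.stop()` calls it once per unique source after the micro-batch thread has terminated. A minimal sketch of what a custom source must now provide (the `NoopSource` class below is illustrative only, not part of this patch):

    package org.apache.spark.sql.execution.streaming

    import org.apache.spark.sql.{DataFrame, SQLContext}
    import org.apache.spark.sql.types.{LongType, StructField, StructType}

    // Hypothetical source that never produces data; it only demonstrates the
    // full set of members a Source has to implement after this patch.
    class NoopSource(sqlContext: SQLContext) extends Source {

      override def schema: StructType =
        StructType(StructField("value", LongType) :: Nil)

      // No data is ever available from this source.
      override def getOffset: Option[Offset] = None

      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import sqlContext.implicits._
        Seq.empty[Long].toDF("value")
      }

      // The new obligation: release any resources held by the source. Sources
      // without external resources, like FileStreamSource and MemoryStream
      // above, can simply provide an empty body.
      override def stop(): Unit = {}
    }
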
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d07d88dcdcc44114a4b95a0dc88c41e79ef9a426
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+object TextSocketSource {
+  val SCHEMA = StructType(StructField("value", StringType) :: Nil)
+}
+
+/**
+ * A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
+ * This source will *not* work in production applications: it has no support for fault
+ * recovery, and it keeps all of the text it has read in memory forever.
+ */
+class TextSocketSource(host: String, port: Int, sqlContext: SQLContext)
+  extends Source with Logging
+{
+  @GuardedBy("this")
+  private var socket: Socket = null
+
+  @GuardedBy("this")
+  private var readThread: Thread = null
+
+  @GuardedBy("this")
+  private val lines = new ArrayBuffer[String]
+
+  initialize()
+
+  private def initialize(): Unit = synchronized {
+    socket = new Socket(host, port)
+    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
+    readThread = new Thread(s"TextSocketSource($host, $port)") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          while (true) {
+            val line = reader.readLine()
+            if (line == null) {
+              // End of file reached
+              logWarning(s"Stream closed by $host:$port")
+              return
+            }
+            TextSocketSource.this.synchronized {
+              lines += line
+            }
+          }
+        } catch {
+          case _: IOException => // Socket closed, most likely by stop(); let the thread exit.
+        }
+      }
+    }
+    readThread.start()
+  }
+
+  /** Returns the schema of the data from this source */
+  override def schema: StructType = TextSocketSource.SCHEMA
+
+  /** Returns the maximum available offset for this source. */
+  override def getOffset: Option[Offset] = synchronized {
+    if (lines.isEmpty) None else Some(LongOffset(lines.size - 1))
+  }
+
+  /** Returns the data that is between the offsets (`start`, `end`]. */
+  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+    val startIdx = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0)
+    val endIdx = end.asInstanceOf[LongOffset].offset.toInt + 1
+    val data = synchronized { lines.slice(startIdx, endIdx) }
+    import sqlContext.implicits._
+    data.toDF("value")
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    if (socket != null) {
+      try {
+        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
+        // stop the readThread is to close the socket.
+        socket.close()
+      } catch {
+        case _: IOException =>
+      }
+      socket = null
+    }
+  }
+}
+
+class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+  /** Returns the name and schema of the source that can be used to continually read data. */
+  override def sourceSchema(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): (String, StructType) = {
+    logWarning("The socket source should not be used for production applications! " +
+      "It does not support recovery and stores state indefinitely.")
+    if (!parameters.contains("host")) {
+      throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
+    }
+    if (!parameters.contains("port")) {
+      throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
+    }
+    ("textSocket", TextSocketSource.SCHEMA)
+  }
+
+  override def createSource(
+      sqlContext: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    val host = parameters("host")
+    val port = parameters("port").toInt
+    new TextSocketSource(host, port, sqlContext)
+  }
+
+  /**
+   * String that represents the format that this data source provider uses.
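+   * This is the name users pass to `format(...)` when building a streaming DataFrame; it is
+   * resolved to this provider via the META-INF/services entry added at the top of this patch.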
+   */
+  override def shortName(): String = "socket"
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ca577631854efb86ba76de2c1b0dff96f63dffe7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{IOException, OutputStreamWriter}
+import java.net.ServerSocket
+import java.util.concurrent.LinkedBlockingQueue
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {
+  import testImplicits._
+
+  override def afterEach() {
+    sqlContext.streams.active.foreach(_.stop())
+    if (serverThread != null) {
+      serverThread.interrupt()
+      serverThread.join()
+      serverThread = null
+    }
+    if (source != null) {
+      source.stop()
+      source = null
+    }
+  }
+
+  private var serverThread: ServerThread = null
+  private var source: Source = null
+
+  test("basic usage") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    val provider = new TextSocketSourceProvider
+    val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
+    val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2
+    assert(schema === StructType(StructField("value", StringType) :: Nil))
+
+    source = provider.createSource(sqlContext, "", None, "", parameters)
+
+    failAfter(streamingTimeout) {
+      serverThread.enqueue("hello")
+      while (source.getOffset.isEmpty) {
+        Thread.sleep(10)
+      }
+      val offset1 = source.getOffset.get
+      val batch1 = source.getBatch(None, offset1)
+      assert(batch1.as[String].collect().toSeq === Seq("hello"))
+
+      serverThread.enqueue("world")
+      while (source.getOffset.get === offset1) {
+        Thread.sleep(10)
+      }
+      val offset2 = source.getOffset.get
+      val batch2 = source.getBatch(Some(offset1), offset2)
+      assert(batch2.as[String].collect().toSeq === Seq("world"))
+
+      val both = source.getBatch(None, offset2)
+      assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
+
+      // Try stopping the source to make sure this does not block forever.
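+      // (stop() closes the socket, which is the only way to unblock the reader
+      // thread, since BufferedReader.readLine() cannot be interrupted.)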
+ source.stop() + source = null + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map()) + } + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) + } + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) + } + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + source = provider.createSource(sqlContext, "", None, "", parameters) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocket = new ServerSocket(0) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocket.getLocalPort + + override def run(): Unit = { + try { + val clientSocket = serverSocket.accept() + clientSocket.setTcpNoDelay(true) + val out = new OutputStreamWriter(clientSocket.getOutputStream) + while (true) { + val line = messageQueue.take() + out.write(line + "\n") + out.flush() + } + } catch { + case e: InterruptedException => + } finally { + serverSocket.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 786404a5895816b13742c6f2eafb63fd5ec8d9f6..b8e40e71bfce5873d797d09adb2df83dcec09204 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -282,6 +282,8 @@ class FakeDefaultSource extends StreamSourceProvider { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") } + + override def stop() {} } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 1aee1934c07910e0e0fb1bcaa86fc7b3ee0c42a3..943e7b761e6e958e635de2d80e8d52701ecb6bf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -84,6 +84,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { Seq[Int]().toDS().toDF() } + + override def stop() {} } }
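
Taken together, these changes enable the following end-to-end flow. A minimal usage sketch (the host, port, and the netcat command are illustrative; this assumes a SparkSession named `spark` and the `readStream`/`writeStream` API exercised by DataStreamReaderWriterSuite above):

    // Terminal 1: serve lines of text over TCP, e.g. with netcat:
    //   nc -lk 9999

    // Spark application: read from the socket source added by this patch.
    val lines = spark.readStream
      .format("socket")             // resolved via shortName() and the service registry entry
      .option("host", "localhost")  // host and port are both required; omitting either
      .option("port", "9999")       // raises an AnalysisException in sourceSchema()
      .load()                       // one column: value: String

    // Echo every line to stdout using the console sink registered in the same service file.
    val query = lines.writeStream
      .format("console")
      .start()

    // ...each line typed into the nc terminal becomes a row...
    query.stop() // StreamExecution.stop() now also stops the source, closing its socket.

Note that TextSocketSource connects eagerly in its constructor, so the server must already be listening when the query starts; the "no server up" test relies on exactly this by expecting an IOException from createSource.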