Commit 4f17fddc authored by Matei Zaharia, committed by Reynold Xin

[SPARK-16031] Add debug-only socket source in Structured Streaming

## What changes were proposed in this pull request?

This patch adds a text-based socket source similar to the one in Spark Streaming for debugging and tutorials. The source is clearly marked as debug-only so that users don't try to run it in production applications, because this type of source cannot provide HA without storing a lot of state in Spark.

## How was this patch tested?

Unit tests and manual tests in spark-shell.
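For reference, a minimal spark-shell session of the kind used in the manual tests might look like this (a sketch: it assumes `nc -lk 9999` is feeding lines of text on localhost, and the format name "socket" matches `shortName()` in the provider below):

```scala
// Read lines typed into the nc terminal and print each micro-batch to the console.
val socketDF = spark.readStream
  .format("socket")              // resolved to TextSocketSourceProvider via DataSourceRegister
  .option("host", "localhost")
  .option("port", "9999")
  .load()

val query = socketDF.writeStream
  .format("console")
  .start()

// ...type a few lines into nc, watch them appear, then:
query.stop()                     // invokes Source.stop(), which closes the socket
```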

Author: Matei Zaharia <matei@databricks.com>

Closes #13748 from mateiz/socket-source.
parent 5930d7a2
Showing 293 additions and 0 deletions
...@@ -4,3 +4,4 @@ org.apache.spark.sql.execution.datasources.json.JsonFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
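The added line registers the new provider with Java's `ServiceLoader`, which is how a short format name is mapped to a provider class. A sketch of that lookup (Spark's actual resolution logic lives in `DataSource` and may differ in detail):

```scala
import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.DataSourceRegister

// Scan all registered providers and pick the one whose shortName() matches.
val socketProvider = ServiceLoader.load(classOf[DataSourceRegister]).asScala
  .find(_.shortName() == "socket")
```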
...@@ -128,4 +128,6 @@ class FileStreamSource(
  override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1)

  override def toString: String = s"FileStreamSource[$qualifiedBasePath]"

  override def stop() {}
}
...@@ -39,4 +39,7 @@ trait Source {
   * same data for a particular `start` and `end` pair.
   */
  def getBatch(start: Option[Offset], end: Offset): DataFrame

  /** Stop this source and free any resources it has allocated. */
  def stop(): Unit
}
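To make the contract concrete, here is a hypothetical bare-bones implementation (illustrative only, not part of this patch; `FixedSource` and its fixed data are invented for the example, and it assumes the same package as the trait so that `Source`, `Offset`, and `LongOffset` resolve):

```scala
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A toy Source exposing a fixed in-memory sequence, to show where each method fits.
class FixedSource(sqlContext: SQLContext) extends Source {
  private val data = Seq("a", "b", "c")

  override def schema: StructType = StructType(StructField("value", StringType) :: Nil)

  // The highest offset available so far; None would mean "no data yet".
  override def getOffset: Option[Offset] = Some(LongOffset(data.size - 1))

  // Rows in the half-open interval (`start`, `end`], per the contract above.
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    val from = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0)
    val to = end.asInstanceOf[LongOffset].offset.toInt + 1
    import sqlContext.implicits._
    data.slice(from, to).toDF("value")
  }

  // The new lifecycle hook: release sockets, threads, or files here.
  override def stop(): Unit = {}
}
```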
...@@ -399,6 +399,7 @@ class StreamExecution(
      microBatchThread.interrupt()
      microBatchThread.join()
    }
    uniqueSources.foreach(_.stop())
    logInfo(s"Query $name was stopped")
  }
...
...@@ -110,6 +110,8 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
        sys.error("No data selected!")
      }
  }

  override def stop() {}
}
  /**
...
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import java.io.{BufferedReader, InputStreamReader, IOException}
import java.net.Socket
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object TextSocketSource {
  val SCHEMA = StructType(StructField("value", StringType) :: Nil)
}
/**
* A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
* This source will *not* work in production applications due to multiple reasons, including no
* support for fault recovery and keeping all of the text read in memory forever.
*/
class TextSocketSource(host: String, port: Int, sqlContext: SQLContext)
  extends Source with Logging
{
  @GuardedBy("this")
  private var socket: Socket = null

  @GuardedBy("this")
  private var readThread: Thread = null

  @GuardedBy("this")
  private var lines = new ArrayBuffer[String]

  initialize()

  private def initialize(): Unit = synchronized {
    socket = new Socket(host, port)
    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
    readThread = new Thread(s"TextSocketSource($host, $port)") {
      setDaemon(true)

      override def run(): Unit = {
        try {
          while (true) {
            val line = reader.readLine()
            if (line == null) {
              // End of file reached
              logWarning(s"Stream closed by $host:$port")
              return
            }
            TextSocketSource.this.synchronized {
              lines += line
            }
          }
        } catch {
          case e: IOException =>
        }
      }
    }
    readThread.start()
  }

  /** Returns the schema of the data from this source */
  override def schema: StructType = TextSocketSource.SCHEMA

  /** Returns the maximum available offset for this source. */
  override def getOffset: Option[Offset] = synchronized {
    if (lines.isEmpty) None else Some(LongOffset(lines.size - 1))
  }

  /** Returns the data that is between the offsets (`start`, `end`]. */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
    val startIdx = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0)
    val endIdx = end.asInstanceOf[LongOffset].offset.toInt + 1
    val data = lines.slice(startIdx, endIdx)
    import sqlContext.implicits._
    data.toDF("value")
  }

  /** Stop this source. */
  override def stop(): Unit = synchronized {
    if (socket != null) {
      try {
        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
        // stop the readThread is to close the socket.
        socket.close()
      } catch {
        case e: IOException =>
      }
      socket = null
    }
  }
}
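To illustrate the `(start, end]` offset arithmetic in `getBatch` above, a small worked example (a sketch; `source` is assumed to be a `TextSocketSource` that has received two lines):

```scala
// lines == ArrayBuffer("hello", "world"); offsets are zero-based indices.
source.getOffset                                     // Some(LongOffset(1))
source.getBatch(None, LongOffset(1))                 // "hello", "world" -- slice(0, 2)
source.getBatch(Some(LongOffset(0)), LongOffset(1))  // "world" only     -- slice(1, 2)
```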
class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
  /** Returns the name and schema of the source that can be used to continually read data. */
  override def sourceSchema(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    logWarning("The socket source should not be used for production applications! " +
      "It does not support recovery and stores state indefinitely.")
    if (!parameters.contains("host")) {
      throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
    }
    if (!parameters.contains("port")) {
      throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
    }
    ("textSocket", TextSocketSource.SCHEMA)
  }

  override def createSource(
      sqlContext: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    val host = parameters("host")
    val port = parameters("port").toInt
    new TextSocketSource(host, port, sqlContext)
  }

  /** String that represents the format that this data source provider uses. */
  override def shortName(): String = "socket"
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import java.io.{IOException, OutputStreamWriter}
import java.net.ServerSocket
import java.util.concurrent.LinkedBlockingQueue

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {

  import testImplicits._
  override def afterEach() {
    sqlContext.streams.active.foreach(_.stop())
    if (serverThread != null) {
      serverThread.interrupt()
      serverThread.join()
      serverThread = null
    }
    if (source != null) {
      source.stop()
      source = null
    }
  }

  private var serverThread: ServerThread = null
  private var source: Source = null
test("basic usage") {
serverThread = new ServerThread()
serverThread.start()
val provider = new TextSocketSourceProvider
val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2
assert(schema === StructType(StructField("value", StringType) :: Nil))
source = provider.createSource(sqlContext, "", None, "", parameters)
failAfter(streamingTimeout) {
serverThread.enqueue("hello")
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
val offset1 = source.getOffset.get
val batch1 = source.getBatch(None, offset1)
assert(batch1.as[String].collect().toSeq === Seq("hello"))
serverThread.enqueue("world")
while (source.getOffset.get === offset1) {
Thread.sleep(10)
}
val offset2 = source.getOffset.get
val batch2 = source.getBatch(Some(offset1), offset2)
assert(batch2.as[String].collect().toSeq === Seq("world"))
val both = source.getBatch(None, offset2)
assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
// Try stopping the source to make sure this does not block forever.
source.stop()
source = null
}
}
test("params not given") {
val provider = new TextSocketSourceProvider
intercept[AnalysisException] {
provider.sourceSchema(sqlContext, None, "", Map())
}
intercept[AnalysisException] {
provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost"))
}
intercept[AnalysisException] {
provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234"))
}
}
test("no server up") {
val provider = new TextSocketSourceProvider
val parameters = Map("host" -> "localhost", "port" -> "0")
intercept[IOException] {
source = provider.createSource(sqlContext, "", None, "", parameters)
}
}
  private class ServerThread extends Thread with Logging {
    private val serverSocket = new ServerSocket(0)
    private val messageQueue = new LinkedBlockingQueue[String]()

    val port = serverSocket.getLocalPort

    override def run(): Unit = {
      try {
        val clientSocket = serverSocket.accept()
        clientSocket.setTcpNoDelay(true)
        val out = new OutputStreamWriter(clientSocket.getOutputStream)
        while (true) {
          val line = messageQueue.take()
          out.write(line + "\n")
          out.flush()
        }
      } catch {
        case e: InterruptedException =>
      } finally {
        serverSocket.close()
      }
    }

    def enqueue(line: String): Unit = {
      messageQueue.put(line)
    }
  }
}
...@@ -282,6 +282,8 @@ class FakeDefaultSource extends StreamSourceProvider {
        val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
        spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
      }

      override def stop() {}
    }
  }
}
...@@ -84,6 +84,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
        Seq[Int]().toDS().toDF()
      }

      override def stop() {}
    }
  }
...