diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 82% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 7487aa1c12639da08a2614f808d265ffe113e776..0ace453ee9280bebc2cc0eaba285427d2490941f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,13 +31,13 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ -import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} -import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! */ private[kinesis] class KinesisTestUtils extends Logging { @@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -66,14 +66,12 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } - private lazy val kinesisProducer: KinesisProducer = { - val conf = new KinesisProducerConfiguration() - .setRecordMaxBufferedTime(1000) - .setMaxConnections(1) - .setRegion(regionName) - .setMetricsLevel("none") - - new KinesisProducer(conf) + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } } def streamName: String = { @@ -104,41 +102,8 @@ private[kinesis] class KinesisTestUtils extends Logging { */ def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) - if (aggregate) { - val future = kinesisProducer.addUserRecord(streamName, str, data) - val kinesisCallBack = new FutureCallback[UserRecordResult]() { - override def onFailure(t: Throwable): Unit = {} // do nothing - - override def onSuccess(result: UserRecordResult): Unit = { - val shardId = result.getShardId - val seqNumber = result.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - - Futures.addCallback(future, kinesisCallBack) - kinesisProducer.flushSync() // make sure we send all data before returning the map - } else { - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(data) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -264,3 +229,32 @@ private[kinesis] object KinesisTestUtils { } } } + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..fdb270eaad8c99b2f4e5eef1282f03ebbd73d6a6 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. */ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 52c61dfb1c02398620198e7c06f2ec70c9aed252..d85b4cda8ce986badd06cd2a66088ae408e296d2 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index dee30444d8cc670a74f4178f5600f61974530c86..78cec021b78c1af7924bbf80cb570f6a534c780e 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun sc = new SparkContext(conf) runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() } } diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d50c6b8d4a4285a276d5b09d193dab00f98c40a1..a2bfd79e1abcd16f0c0d2c87604356a5f9e39721 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1458,7 +1458,6 @@ class KinesisStreamTests(PySparkStreamingTestCase): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") - @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write(