diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 0f1790bddcc3d765a26d95b17a0f4b1558c4349e..f31ebf1ec8da0f87368a27188501c18d6e15e085 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -82,8 +82,8 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, - val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider + val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, + val kinesisCreds: SparkAWSCredentials = DefaultCredentials ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -109,7 +109,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = kinesisCredsProvider.provider.getCredentials + val credentials = kinesisCreds.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index fbc6b99443ed70a3386f41346b0fddb58fece47b..8970ad2bafda0dbae4fdf41ad1336da2772107f1 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,24 +22,28 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider + val streamName: String, + val endpointUrl: String, + val regionName: String, + val initialPositionInStream: InitialPositionInStream, + val checkpointAppName: String, + val checkpointInterval: Duration, + val _storageLevel: StorageLevel, + val messageHandler: Record => T, + val kinesisCreds: SparkAWSCredentials, + val dynamoDBCreds: Option[SparkAWSCredentials], + val cloudWatchCreds: Option[SparkAWSCredentials] ) extends ReceiverInputDStream[T](_ssc) { 
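For orientation: the `messageHandler` threaded through both constructors above is an arbitrary `Record => T` function; it is shipped to executors, so it must be serializable. A minimal sketch of a custom handler (hypothetical; it mirrors `defaultMessageHandler` below but decodes payloads as UTF-8 text):

```scala
import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.model.Record

// Hypothetical custom handler: decode each Kinesis record's payload to a String.
// Runs on executors, so it must be serializable.
def recordToString(record: Record): String = {
  val buffer = record.getData()                   // ByteBuffer over the payload
  val bytes = new Array[Byte](buffer.remaining()) // copy out the remaining bytes
  buffer.get(bytes)
  new String(bytes, StandardCharsets.UTF_8)
}
```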
private[streaming] @@ -61,7 +65,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - kinesisCredsProvider = kinesisCredsProvider) + kinesisCreds = kinesisCreds) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,7 +75,238 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, - kinesisCredsProvider) + checkpointAppName, checkpointInterval, _storageLevel, messageHandler, + kinesisCreds, dynamoDBCreds, cloudWatchCreds) } } + +@InterfaceStability.Evolving +object KinesisInputDStream { + /** + * Builder for [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + // Required params + private var streamingContext: Option[StreamingContext] = None + private var streamName: Option[String] = None + private var checkpointAppName: Option[String] = None + + // Params with defaults + private var endpointUrl: Option[String] = None + private var regionName: Option[String] = None + private var initialPositionInStream: Option[InitialPositionInStream] = None + private var checkpointInterval: Option[Duration] = None + private var storageLevel: Option[StorageLevel] = None + private var kinesisCredsProvider: Option[SparkAWSCredentials] = None + private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None + private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param ssc [[StreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(ssc: StreamingContext): Builder = { + streamingContext = Option(ssc) + this + } + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param jssc [[JavaStreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(jssc: JavaStreamingContext): Builder = { + streamingContext = Option(jssc.ssc) + this + } + + /** + * Sets the name of the Kinesis stream that the DStream will read from. This is a required + * parameter. + * + * @param streamName Name of Kinesis stream that the DStream will read from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamName(streamName: String): Builder = { + this.streamName = Option(streamName) + this + } + + /** + * Sets the KCL application name to use when checkpointing state to DynamoDB. This is a + * required parameter. + * + * @param appName Value to use for the KCL app name (used when creating the DynamoDB checkpoint + * table and when writing metrics to CloudWatch) + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointAppName(appName: String): Builder = { + checkpointAppName = Option(appName) + this + } + + /** + * Sets the AWS Kinesis endpoint URL. 
Defaults to "https://kinesis.us-east-1.amazonaws.com" if + * no custom value is specified + * + * @param url Kinesis endpoint URL to use + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def endpointUrl(url: String): Builder = { + endpointUrl = Option(url) + this + } + + /** + * Sets the AWS region to construct clients for. Defaults to "us-east-1" if no custom value + * is specified. + * + * @param regionName Name of AWS region to use (e.g. "us-west-2") + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def regionName(regionName: String): Builder = { + this.regionName = Option(regionName) + this + } + + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[InitialPositionInStream.LATEST]] if no custom value is specified. + * + * @param initialPosition InitialPositionInStream value specifying where Spark Streaming + * will start reading records in the Kinesis stream from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { + initialPositionInStream = Option(initialPosition) + this + } + + /** + * Sets how often the KCL application state is checkpointed to DynamoDB. Defaults to the Spark + * Streaming batch interval if no custom value is specified. + * + * @param interval [[Duration]] specifying how often the KCL state should be checkpointed to + * DynamoDB. + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointInterval(interval: Duration): Builder = { + checkpointInterval = Option(interval) + this + } + + /** + * Sets the storage level of the blocks for the DStream created. Defaults to + * [[StorageLevel.MEMORY_AND_DISK_2]] if no custom value is specified. + * + * @param storageLevel [[StorageLevel]] to use for the DStream data blocks + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def storageLevel(storageLevel: StorageLevel): Builder = { + this.storageLevel = Option(storageLevel) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS Kinesis + * endpoint. Defaults to [[DefaultCredentialsProvider]] if no custom value is specified. + * + * @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication + */ + def kinesisCredentials(credentials: SparkAWSCredentials): Builder = { + kinesisCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS DynamoDB + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication + */ + def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = { + dynamoDBCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS CloudWatch + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication + */ + def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = { + cloudWatchCredsProvider = Option(credentials) + this + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided + * message handler. 
+ * + * @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]] + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def buildWithMessageHandler[T: ClassTag]( + handler: Record => T): KinesisInputDStream[T] = { + val ssc = getRequiredParam(streamingContext, "streamingContext") + new KinesisInputDStream( + ssc, + getRequiredParam(streamName, "streamName"), + endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), + regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), + initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + getRequiredParam(checkpointAppName, "checkpointAppName"), + checkpointInterval.getOrElse(ssc.graph.batchDuration), + storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), + handler, + kinesisCredsProvider.getOrElse(DefaultCredentials), + dynamoDBCredsProvider, + cloudWatchCredsProvider) + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and using the + * default message handler, which returns [[Array[Byte]]]. + * + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def build(): KinesisInputDStream[Array[Byte]] = buildWithMessageHandler(defaultMessageHandler) + + private def getRequiredParam[T](param: Option[T], paramName: String): T = param.getOrElse { + throw new IllegalArgumentException(s"No value provided for required parameter $paramName") + } + } + + /** + * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + * + * @return [[KinesisInputDStream.Builder]] instance + */ + def builder: Builder = new Builder + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + + private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = + "https://kinesis.us-east-1.amazonaws.com" + private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" + private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = + InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 320728f4bb221b90ae4283ef57310a99255dcf72..1026d0fcb59bd3700c8085edf11e3e92c756b412 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -70,9 +70,14 @@ import org.apache.spark.util.Utils * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to - * generate the AWSCredentialsProvider instance used for KCL - * authorization. + * @param kinesisCreds SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize Kinesis API calls. 
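To make the builder contract concrete, a usage sketch under stated assumptions (stream and app names are placeholders; `ssc` is a live `StreamingContext`):

```scala
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisInputDStream

// Sketch: the three required parameters plus a few overridden defaults.
def buildByteStream(ssc: StreamingContext): ReceiverInputDStream[Array[Byte]] =
  KinesisInputDStream.builder
    .streamingContext(ssc)                        // required
    .streamName("my-kinesis-stream")              // required
    .checkpointAppName("my-kcl-app")              // required
    .regionName("us-west-2")                      // default: us-east-1
    .checkpointInterval(Seconds(60))              // default: batch interval
    .storageLevel(StorageLevel.MEMORY_AND_DISK_2) // the default, shown for clarity
    .build()                                      // DStream of Array[Byte]
```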
+ * @param cloudWatchCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize CloudWatch API + * calls. Will use kinesisCreds if value is None. + * @param dynamoDBCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls. + * Will use kinesisCreds if value is None. */ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -83,7 +88,9 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider) + kinesisCreds: SparkAWSCredentials, + dynamoDBCreds: Option[SparkAWSCredentials], + cloudWatchCreds: Option[SparkAWSCredentials]) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -140,10 +147,13 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) + val kinesisProvider = kinesisCreds.provider val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( checkpointAppName, streamName, - kinesisCredsProvider.provider, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2d777982e760c06de11d2f0203379382d5cd8629..1298463bfba1e6a80636935fc3507702bb81dba0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -58,6 +58,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -73,7 +74,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, DefaultCredentialsProvider) + cleanedHandler, DefaultCredentials, None, None) } } @@ -108,6 +109,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. 
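This hunk is where the per-service credentials take effect: the KCL configuration receives one provider each for Kinesis, DynamoDB, and CloudWatch, with the optional two falling back to the Kinesis provider. A hedged sketch of what that enables (role ARN and names are hypothetical; `ssc` and imports as in the earlier sketch):

```scala
// Read the stream with the default provider chain, but checkpoint to a
// DynamoDB table owned by another account by assuming a role there.
val dynamoCreds = SparkAWSCredentials.builder
  .stsCredentials("arn:aws:iam::123456789012:role/checkpoint-writer", "ckpt-session")
  .build()

val stream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("events")
  .checkpointAppName("events-app")
  .dynamoDBCredentials(dynamoCreds) // Kinesis and CloudWatch keep DefaultCredentials
  .build()
```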
*/ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -123,12 +125,12 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -169,6 +171,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -187,16 +190,16 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = STSCredentialsProvider( + val kinesisCredsProvider = STSCredentials( stsRoleArn = stsAssumeRoleArn, stsSessionName = stsSessionName, stsExternalId = Option(stsExternalId), - longLivedCredsProvider = BasicCredentialsProvider( + longLivedCreds = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -227,6 +230,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -240,7 +244,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, DefaultCredentialsProvider) + KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -272,6 +276,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -284,12 +289,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, kinesisCredsProvider) + KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } @@ -323,6 +328,7 @@ object KinesisUtils { * on the workers. 
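Every `createStream` overload in this file now carries the same deprecation, so a migration sketch may help; the basic-keypair overload maps onto the two new builders roughly as follows (keys and names are placeholders; `ssc` assumed in scope):

```scala
// Deprecated style:
//   KinesisUtils.createStream(ssc, "app", "stream", endpointUrl, "us-east-1",
//     InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
//     "ACCESS_KEY", "SECRET_KEY")

// Builder style:
val migrated = KinesisInputDStream.builder
  .streamingContext(ssc)
  .checkpointAppName("app")
  .streamName("stream")
  .kinesisCredentials(
    SparkAWSCredentials.builder
      .basicCredentials("ACCESS_KEY", "SECRET_KEY")
      .build())
  .build()
```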
See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -372,6 +378,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -431,6 +438,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -482,6 +490,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -493,7 +502,8 @@ object KinesisUtils { storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) + initialPositionInStream, checkpointInterval, storageLevel, + KinesisInputDStream.defaultMessageHandler(_)) } /** @@ -524,6 +534,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -537,7 +548,7 @@ object KinesisUtils { awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel, - defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } private def validateRegion(regionName: String): String = { @@ -545,14 +556,6 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } - - private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { - if (record == null) return null - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - byteArray - } } /** @@ -597,7 +600,7 @@ private class KinesisUtilsPythonHelper { validateAwsCreds(awsAccessKeyId, awsSecretKey) KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) } else { validateAwsCreds(awsAccessKeyId, awsSecretKey) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala deleted file mode 100644 index aa6fe12edf74e71b4abee29b3ed404f74db81daf..0000000000000000000000000000000000000000 --- 
a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import scala.collection.JavaConverters._ - -import com.amazonaws.auth._ - -import org.apache.spark.internal.Logging - -/** - * Serializable interface providing a method executors can call to obtain an - * AWSCredentialsProvider instance for authenticating to AWS services. - */ -private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { - /** - * Return an AWSCredentialProvider instance that can be used by the Kinesis Client - * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). - */ - def provider: AWSCredentialsProvider -} - -/** Returns DefaultAWSCredentialsProviderChain for authentication. */ -private[kinesis] final case object DefaultCredentialsProvider - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain -} - -/** - * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using - * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain - * instance with the provided arguments (e.g. if they are null). - */ -private[kinesis] final case class BasicCredentialsProvider( - awsAccessKeyId: String, - awsSecretKey: String) extends SerializableCredentialsProvider with Logging { - - def provider: AWSCredentialsProvider = try { - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) - } catch { - case e: IllegalArgumentException => - logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + - "falling back to DefaultAWSCredentialsProviderChain.", e) - new DefaultAWSCredentialsProviderChain - } -} - -/** - * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM - * role in order to authenticate against resources in an external account. 
- */ -private[kinesis] final case class STSCredentialsProvider( - stsRoleArn: String, - stsSessionName: String, - stsExternalId: Option[String] = None, - longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = { - val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) - .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) - stsExternalId match { - case Some(stsExternalId) => - builder.withExternalId(stsExternalId) - .build() - case None => - builder.build() - } - } -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..9facfe8ff2b0fcc3597a6678a7c5d8319a27334d --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SparkAWSCredentials extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). 
+ */ +private[kinesis] final case class BasicCredentials( + awsAccessKeyId: String, + awsSecretKey: String) extends SparkAWSCredentials with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentials( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCreds: SparkAWSCredentials = DefaultCredentials) + extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCreds.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} + +@InterfaceStability.Evolving +object SparkAWSCredentials { + /** + * Builder for [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + private var basicCreds: Option[BasicCredentials] = None + private var stsCreds: Option[STSCredentials] = None + + // scalastyle:off + /** + * Use a basic AWS keypair for long-lived authorization. + * + * @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is + * enabled. Make sure that your checkpoint directory is secure. Prefer using the + * [[http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default default provider chain]] + * instead if possible. + * + * @param accessKeyId AWS access key ID + * @param secretKey AWS secret key + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + // scalastyle:on + def basicCredentials(accessKeyId: String, secretKey: String): Builder = { + basicCreds = Option(BasicCredentials( + awsAccessKeyId = accessKeyId, + awsSecretKey = secretKey)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String): Builder = { + stsCreds = Option(STSCredentials(stsRoleArn = roleArn, stsSessionName = sessionName)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). STS will validate the provided external ID with the one defined + * in the trust policy of the IAM role to be assumed (if one is present). 
+ * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @param externalId External ID to validate against assumed IAM role's trust policy + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String, externalId: String): Builder = { + stsCreds = Option(STSCredentials( + stsRoleArn = roleArn, + stsSessionName = sessionName, + stsExternalId = Option(externalId))) + this + } + + /** + * Returns the appropriate instance of [[SparkAWSCredentials]] given the configured + * parameters. + * + * - The long-lived credentials will either be [[DefaultCredentials]] or [[BasicCredentials]] + * if they were provided. + * + * - If STS credentials were provided, the configured long-lived credentials will be added to + * them and the result will be returned. + * + * - The long-lived credentials will be returned otherwise. + * + * @return [[SparkAWSCredentials]] to use for configured parameters + */ + def build(): SparkAWSCredentials = + stsCreds.map(_.copy(longLivedCreds = longLivedCreds)).getOrElse(longLivedCreds) + + private def longLivedCreds: SparkAWSCredentials = basicCreds.getOrElse(DefaultCredentials) + } + + /** + * Creates a [[SparkAWSCredentials.Builder]] for constructing + * [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + * + * @return [[SparkAWSCredentials.Builder]] instance + */ + def builder: Builder = new Builder +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..7205f6e27266ca4af7292a1feb2c5257338fa97a --- /dev/null +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Seconds; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; + +public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { + /** + * Basic test to ensure that the KinesisDStream.Builder interface is accessible from Java. 
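Putting the `build()` precedence rules documented above into one sketch (all values are placeholders):

```scala
// Nothing configured: falls through to DefaultCredentials.
val chain = SparkAWSCredentials.builder.build()

// Only a keypair configured: yields BasicCredentials.
val keypair = SparkAWSCredentials.builder
  .basicCredentials("ACCESS_KEY", "SECRET_KEY")
  .build()

// STS role on top of a keypair: yields STSCredentials whose long-lived
// credentials are the keypair; setter order does not matter.
val assumedRole = SparkAWSCredentials.builder
  .basicCredentials("ACCESS_KEY", "SECRET_KEY")
  .stsCredentials("arn:aws:iam::123456789012:role/reader", "session", "external-id")
  .build()
```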
+ */ + @Test + public void testJavaKinesisDStreamBuilder() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; + InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream<byte[]> kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1c130654f3f95d34d83cb4e7acfc22ec59ebac60 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import java.lang.IllegalArgumentException + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} + +class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach + with MockitoSugar { + import KinesisInputDStream._ + + private val ssc = new StreamingContext(conf, batchDuration) + private val streamName = "a-very-nice-kinesis-stream-name" + private val checkpointAppName = "a-very-nice-kcl-app-name" + private def baseBuilder = KinesisInputDStream.builder + private def builder = baseBuilder.streamingContext(ssc) + .streamName(streamName) + .checkpointAppName(checkpointAppName) + + override def afterAll(): Unit = { + ssc.stop() + } + + test("should raise an exception if the StreamingContext is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamName(streamName).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the stream name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the checkpoint app name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).streamName(streamName).build() + } + } + + test("should propagate required values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.context == ssc) + assert(dstream.streamName == streamName) + assert(dstream.checkpointAppName == checkpointAppName) + } + + test("should propagate default values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) + assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) + assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.checkpointInterval == batchDuration) + assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) + assert(dstream.kinesisCreds == DefaultCredentials) + assert(dstream.dynamoDBCreds == None) + assert(dstream.cloudWatchCreds == None) + } + + test("should propagate custom non-auth values to KinesisInputDStream") { + val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + val customRegion = "us-west-2" + val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customAppName = "a-very-nice-kinesis-app" + val customCheckpointInterval = Seconds(30) + val customStorageLevel = StorageLevel.MEMORY_ONLY + val customKinesisCreds = mock[SparkAWSCredentials] + val customDynamoDBCreds = mock[SparkAWSCredentials] + val customCloudWatchCreds = mock[SparkAWSCredentials] + + val dstream = builder + .endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPositionInStream(customInitialPosition) + .checkpointAppName(customAppName) + .checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstream.endpointUrl == customEndpointUrl) + assert(dstream.regionName == customRegion) + assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.checkpointAppName == customAppName) 
+ assert(dstream.checkpointInterval == customCheckpointInterval) + assert(dstream._storageLevel == customStorageLevel) + assert(dstream.kinesisCreds == customKinesisCreds) + assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index deb411d73e5889b02007dc71f0cd860fc9d507e7..3b14c8471e2057b8b04abdd776713ba346c53cca 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} -import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -62,28 +61,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of credential provider classes") { - Utils.deserialize[BasicCredentialsProvider]( - Utils.serialize(BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y"))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId")))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId"), - longLivedCredsProvider = BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y")))) - } - test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index afb55c84f81fe8a1568f7e66bd182842edac5fa5..ed7e35805026e1e6a9da4e548eaf1b3e888fb252 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -138,7 +138,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + assert(kinesisRDD.kinesisCreds === BasicCredentials( awsAccessKeyId = dummyAWSAccessKey, awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f579c2c3a67999bca41304f32dbf419542f6d430 --- /dev/null +++ 
b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.util.Utils + +class SparkAWSCredentialsBuilderSuite extends TestSuiteBase { + private def builder = SparkAWSCredentials.builder + + private val basicCreds = BasicCredentials( + awsAccessKeyId = "a-very-nice-access-key", + awsSecretKey = "a-very-nice-secret-key") + + private val stsCreds = STSCredentials( + stsRoleArn = "a-very-nice-role-arn", + stsSessionName = "a-very-nice-secret-key", + stsExternalId = Option("a-very-nice-external-id"), + longLivedCreds = basicCreds) + + test("should build DefaultCredentials when given no params") { + assert(builder.build() == DefaultCredentials) + } + + test("should build BasicCredentials") { + assertResult(basicCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + } + + test("should build STSCredentials") { + // No external ID, default long-lived creds + assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .build() + } + // Default long-lived creds + assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + // No external ID, basic keypair for long-lived creds + assertResult(stsCreds.copy(stsExternalId = None)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Basic keypair for long-lived creds + assertResult(stsCreds) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Order shouldn't matter + assertResult(stsCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + } + + test("SparkAWSCredentials classes should be serializable") { + assertResult(basicCreds) { + Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds)) + } + assertResult(stsCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsCreds)) + } + // Will also test if DefaultCredentials can be serialized + val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials) + assertResult(stsDefaultCreds) { + 
Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds)) + } + } +}
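The serializability checks above matter because these credential objects travel to executors inside receivers and DStream checkpoints. Since `org.apache.spark.util.Utils` is Spark-internal, a standalone sketch of the same round-trip with plain Java serialization (assumed to sit in the `org.apache.spark.streaming.kinesis` package like the suite, since the trait is `private[kinesis]`; keypair values are placeholders):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

val creds = SparkAWSCredentials.builder
  .basicCredentials("ACCESS_KEY", "SECRET_KEY")
  .build()

// Serialize to a byte array...
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(creds)
oos.close()

// ...and read it back; case classes compare structurally, so equality holds.
val restored = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
  .readObject().asInstanceOf[SparkAWSCredentials]
assert(restored == creds)
```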