Commit 96a4d1d0 authored by Prashant Sharma, committed by Shixiong Zhu

[SPARK-19968][SS] Use a cached instance of `KafkaProducer` instead of creating one every batch.

## What changes were proposed in this pull request?

In summary, the cost of recreating a KafkaProducer for every batch write is high: it starts a lot of threads, opens connections, and then closes them again. The Kafka docs promise that a KafkaProducer instance is thread safe, and reusing a single instance across multiple writer threads is encouraged.

Furthermore, I measured roughly a 10x improvement in latency with this patch.
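
With this change, the write path obtains its producer from a small configuration-keyed cache instead of constructing one per batch. A minimal sketch of the reuse pattern (the params below are illustrative, and the calls assume code in the `org.apache.spark.sql.kafka010` package of a running Spark application):

```scala
import java.{util => ju}

import org.apache.kafka.common.serialization.ByteArraySerializer

// Illustrative producer configuration.
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("bootstrap.servers", "localhost:9092")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)

// Repeated lookups with the same params return the same cached KafkaProducer instance.
val p1 = CachedKafkaProducer.getOrCreate(kafkaParams)
val p2 = CachedKafkaProducer.getOrCreate(kafkaParams)
assert(p1 eq p2)

// Explicitly closing evicts (and closes) the producer cached for exactly these params.
CachedKafkaProducer.close(kafkaParams)
```

Cached producers also expire after 10 minutes without access by default; the timeout is configurable via `spark.kafka.producer.cache.timeout` (e.g. `30m`).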

### addBatch times (ms) without this patch
![with-out_patch](https://cloud.githubusercontent.com/assets/992952/23994612/a9de4a42-0a6b-11e7-9d5b-7ae18775bee4.png)
### addBatch times (ms) with this patch
![with_patch](https://cloud.githubusercontent.com/assets/992952/23994616/ad8c11ec-0a6b-11e7-8634-2266ebb5033f.png)

## How was this patch tested?
Ran distributed benchmarks comparing runs with and without this patch.
Added relevant unit tests.

Author: Prashant Sharma <prashsh1@in.ibm.com>

Closes #17308 from ScrapCodes/cached-kafka-producer.
parent 1c7db00c
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}

import com.google.common.cache._
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.apache.kafka.clients.producer.KafkaProducer
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging

private[kafka010] object CachedKafkaProducer extends Logging {
private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
private lazy val cacheExpireTimeout: Long =
SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m")
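  // Loads a new producer on a cache miss for the given (sorted) kafka params.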
private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
override def load(config: Seq[(String, Object)]): Producer = {
val configMap = config.map(x => x._1 -> x._2).toMap.asJava
createKafkaProducer(configMap)
}
}
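  // Closes the producer whenever its entry is evicted, whether by expiry or explicit invalidation.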
private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
override def onRemoval(
notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
val paramsSeq: Seq[(String, Object)] = notification.getKey
val producer: Producer = notification.getValue
logDebug(
s"Evicting kafka producer $producer params: $paramsSeq, due to ${notification.getCause}")
close(paramsSeq, producer)
}
}
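  // Cache keyed by the sorted kafka params; entries expire after cacheExpireTimeout ms without access.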
private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
.removalListener(removalListener)
.build[Seq[(String, Object)], Producer](cacheLoader)
private def createKafkaProducer(producerConfiguration: ju.Map[String, Object]): Producer = {
val kafkaProducer: Producer = new Producer(producerConfiguration)
logDebug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
kafkaProducer
}
/**
 * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer does not
 * exist, a new one is created. KafkaProducer is thread safe, so it is best to keep a single
 * instance per distinct set of kafkaParams.
*/
private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
val paramsSeq: Seq[(String, Object)] = paramsToSeq(kafkaParams)
try {
guavaCache.get(paramsSeq)
} catch {
case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
if e.getCause != null =>
throw e.getCause
}
}
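  // Sort params by key so that logically identical configurations map to the same cache key.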
private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
paramsSeq
}
/** For explicitly closing the kafka producer. */
private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
val paramsSeq = paramsToSeq(kafkaParams)
guavaCache.invalidate(paramsSeq)
}
/** Auto-close on cache eviction. */
private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
try {
logInfo(s"Closing the KafkaProducer with params: ${paramsSeq.mkString("\n")}.")
producer.close()
} catch {
case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
}
}
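  // Invalidates every cached producer; each is then closed by the removal listener.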
private def clear(): Unit = {
logInfo("Cleaning up guava cache.")
guavaCache.invalidateAll()
}
// Intended for testing purposes only.
private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
}
KafkaSource.scala
@@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     kafkaReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
     sourceOptions: Map[String, String],
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
   extends Source with Logging {
   private val sc = sqlContext.sparkContext
...
KafkaWriteTask.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010
 import java.{util => ju}
-import org.apache.kafka.clients.producer.{KafkaProducer, _}
-import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask(
   * Writes key value data out to topics.
   */
  def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       val projectedRow = projection(currentRow)
@@ -68,10 +67,10 @@ private[kafka010] class KafkaWriteTask(
   }
   def close(): Unit = {
+    checkForErrors()
     if (producer != null) {
-      checkForErrors
-      producer.close()
-      checkForErrors
+      producer.flush()
+      checkForErrors()
       producer = null
     }
   }
@@ -88,7 +87,7 @@ private[kafka010] class KafkaWriteTask(
       case t =>
         throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
           s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
-          s"must be a ${StringType}")
+          "must be a StringType")
     }
     val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
       .getOrElse(Literal(null, BinaryType))
@@ -100,7 +99,7 @@ private[kafka010] class KafkaWriteTask(
     }
     val valueExpression = inputSchema
       .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
-        throw new IllegalStateException(s"Required attribute " +
+        throw new IllegalStateException("Required attribute " +
           s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
     )
     valueExpression.dataType match {
@@ -114,7 +113,7 @@ private[kafka010] class KafkaWriteTask(
         Cast(valueExpression, BinaryType)), inputSchema)
   }
-  private def checkForErrors: Unit = {
+  private def checkForErrors(): Unit = {
     if (failedWrite != null) {
       throw failedWrite
     }
...
KafkaWriter.scala
@@ -21,7 +21,6 @@ import java.{util => ju}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.types.{BinaryType, StringType}
@@ -49,7 +48,7 @@ private[kafka010] object KafkaWriter extends Logging {
       topic: Option[String] = None): Unit = {
     val schema = queryExecution.analyzed.output
     schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
-      if (topic == None) {
+      if (topic.isEmpty) {
         throw new AnalysisException(s"topic option required when no " +
           s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
           s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
...
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentMap

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {
type KP = KafkaProducer[Array[Byte], Array[Byte]]
protected override def beforeEach(): Unit = {
super.beforeEach()
val clear = PrivateMethod[Unit]('clear)
CachedKafkaProducer.invokePrivate(clear())
}
test("Should return the cached instance on calling getOrCreate with same params.") {
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("acks", "0")
// Only the host needs to be resolvable; this test does not need a running Kafka server.
kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
assert(producer == producer2)
val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
val map = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map.size == 1)
}
test("Should close the correct kafka producer for the given kafkaPrams.") {
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("acks", "0")
kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
kafkaParams.put("acks", "1")
val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
// With updated conf, a new producer instance should be created.
assert(producer != producer2)
val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
val map = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map.size == 2)
CachedKafkaProducer.close(kafkaParams)
val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map2.size == 1)
import scala.collection.JavaConverters._
val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
assert(_producer == producer)
}
}