diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index b35bbaf404cc5edddc319c242dfcdd66b7c62c14..06e159172ab51ad3b2c0fbaf9aa3d00e4453f93c 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -254,6 +254,21 @@ class TopicAndPartition(object):
     def _jTopicAndPartition(self, helper):
         return helper.createTopicAndPartition(self._topic, self._partition)
 
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return (self._topic == other._topic
+                    and self._partition == other._partition)
+        else:
+            return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        # Keep hashing consistent with __eq__; Python 3 sets __hash__ to None
+        # when a class defines __eq__ alone, making instances unusable as keys.
+        return hash((self._topic, self._partition))
+
 
 class Broker(object):
     """
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 2c908daa8b214c537471174ac0746ba9d8813227..f7fa481d50235b6b7c21246d7eaf8f9e9c4fadf4 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -898,6 +898,20 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
 
+    def test_topic_and_partition_equality(self):
+        topic_and_partition_a = TopicAndPartition("foo", 0)
+        topic_and_partition_b = TopicAndPartition("foo", 0)
+        topic_and_partition_c = TopicAndPartition("bar", 0)
+        topic_and_partition_d = TopicAndPartition("foo", 1)
+
+        self.assertEqual(topic_and_partition_a, topic_and_partition_b)
+        self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
+        self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
+
+        # Equal instances must hash equally so they collapse to one dict key.
+        self.assertEqual(hash(topic_and_partition_a), hash(topic_and_partition_b))
+        self.assertEqual(len({topic_and_partition_a: 0, topic_and_partition_b: 1}), 1)
+
 
 class FlumeStreamTests(PySparkStreamingTestCase):
     timeout = 20  # seconds