diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index dd3f28e4197e349e71823b3c16fb6b2848c55c04..e10660793d1622807e114d6cd22bf37c45a8dfad 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -28,4 +28,6 @@ trait Partition extends Serializable {
 
   // A better default implementation of HashCode
   override def hashCode(): Int = index
+
+  override def equals(other: Any): Boolean = super.equals(other)
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 7bc1eb043610a9af5cd01f4233563fc03935aaf8..2381f54ee3f06a619bd28f5dc5e4569416e298df 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -58,10 +58,10 @@ private[spark] case class NarrowCoGroupSplitDep(
  * narrowDeps should always be equal to the number of parents.
  */
 private[spark] class CoGroupPartition(
-    idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
+    override val index: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
   extends Partition with Serializable {
-  override val index: Int = idx
-  override def hashCode(): Int = idx
+  override def hashCode(): Int = index
+  override def equals(other: Any): Boolean = super.equals(other)
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 6b1e15572c03a3e2909eeed24f4d487cf24ebe9c..b22134af45b30c02a723c45e9ad53f00907b6abe 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -53,14 +53,14 @@ import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownH
 /**
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
-private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit)
+private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
   extends Partition {
 
   val inputSplit = new SerializableWritable[InputSplit](s)
 
-  override def hashCode(): Int = 41 * (41 + rddId) + idx
+  override def hashCode(): Int = 31 * (31 + rddId) + index
 
-  override val index: Int = idx
+  override def equals(other: Any): Boolean = super.equals(other)
 
   /**
    * Get any environment variables that should be added to the users environment when running pipes
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index a71c191b318eade2cc6d9441a03ff88613717d6f..ad7c2216a042f9b8a47cad6ef257a562e3a9a58e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -45,7 +45,10 @@ private[spark] class NewHadoopPartition(
   extends Partition {
 
   val serializableHadoopSplit = new SerializableWritable(rawSplit)
-  override def hashCode(): Int = 41 * (41 + rddId) + index
+
+  override def hashCode(): Int = 31 * (31 + rddId) + index
+
+  override def equals(other: Any): Boolean = super.equals(other)
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 0abba15bec9f72a94d2299331372061069112fd0..b6366f3e68df98b679b84b67c687aca68815026f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -31,12 +31,13 @@ import org.apache.spark.util.Utils
 
 private[spark] class PartitionerAwareUnionRDDPartition(
     @transient val rdds: Seq[RDD[_]],
-    val idx: Int
+    override val index: Int
   ) extends Partition {
-  var parents = rdds.map(_.partitions(idx)).toArray
+  var parents = rdds.map(_.partitions(index)).toArray
 
-  override val index = idx
-  override def hashCode(): Int = idx
+  override def hashCode(): Int = index
+
+  override def equals(other: Any): Boolean = super.equals(other)
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 800b42505de1012aa9eeb24f54ac0b26cf10aacb..29d5d74650cdb463d5af9f1b7ba6f1242e7a3e32 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -25,7 +25,10 @@ import org.apache.spark.serializer.Serializer
 
 private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
   override val index: Int = idx
-  override def hashCode(): Int = idx
+
+  override def hashCode(): Int = index
+
+  override def equals(other: Any): Boolean = super.equals(other)
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
index d8d818ceed45f48b79a3f65f4d945383aa1bff7c..838686923767ed038ec6eeb7e07e781655e63bb5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Arrays
+import java.util.Objects
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
@@ -53,6 +54,9 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
     parentPartitionMapping(parent.getPartition(key))
   }
 
+  override def hashCode(): Int =
+    31 * Objects.hashCode(parent) + Arrays.hashCode(partitionStartIndices)
+
   override def equals(other: Any): Boolean = other match {
     case c: CoalescedPartitioner =>
       c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices)
@@ -66,6 +70,8 @@ private[spark] class CustomShuffledRDDPartition(
   extends Partition {
 
   override def hashCode(): Int = index
+
+  override def equals(other: Any): Boolean = super.equals(other)
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 27d063630be9deb2cfb0f91dde29d00b44a0e877..57a82312008e9134c7d28985cc27806f403aa0f7 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -476,6 +476,9 @@ object KryoTest {
 
   class ClassWithNoArgConstructor {
     var x: Int = 0
+
+    override def hashCode(): Int = x
+
     override def equals(other: Any): Boolean = other match {
       case c: ClassWithNoArgConstructor => x == c.x
       case _ => false
@@ -483,6 +486,8 @@ object KryoTest {
   }
 
   class ClassWithoutNoArgConstructor(val x: Int) {
+    override def hashCode(): Int = x
+
     override def equals(other: Any): Boolean = other match {
       case c: ClassWithoutNoArgConstructor => x == c.x
       case _ => false
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 932704c1a365950fceda98e4a75f33b5e9ccb5ba..4920b7ee8bfb4547cdee1c3bcdb87d66b4dd53a8 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -124,6 +124,8 @@ class ClosureCleanerSuite extends SparkFunSuite {
 // A non-serializable class we create in closures to make sure that we aren't
 // keeping references to unneeded variables from our outer closures.
 class NonSerializable(val id: Int = -1) {
+  override def hashCode(): Int = id
+
   override def equals(other: Any): Boolean = {
     other match {
       case o: NonSerializable => id == o.id
diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
index c787b5f066e00ede8baeeac92b3f9e105fb79017..ea22db35555dd987cf8cec22b29302fcd28c6b12 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
@@ -22,4 +22,8 @@ package org.apache.spark.util.collection
  */
 case class FixedHashObject(v: Int, h: Int) extends Serializable {
   override def hashCode(): Int = h
+  override def equals(other: Any): Boolean = other match {
+    case that: FixedHashObject => v == that.v && h == that.h
+    case _ => false
+  }
 }
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
index baa04fb0fd134d5a908205e7b4df2e94b173e7d4..8204b5af02cff205379e9be75b6f669fc2e73c0a 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
@@ -458,6 +458,8 @@ class SparseMatrix (
       rowIndices: Array[Int],
       values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
 
+  override def hashCode(): Int = toBreeze.hashCode()
+
   override def equals(o: Any): Boolean = o match {
     case m: Matrix => toBreeze == m.toBreeze
     case _ => false
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index fd4ce9adb8427e9e6a71bddcf6cce6de342aae71..4275a22ae000a54cacf4db9d90455338371fa0c3 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -476,6 +476,8 @@ class DenseVector (val values: Array[Double]) extends Vector {
     }
   }
 
+  override def equals(other: Any): Boolean = super.equals(other)
+
   override def hashCode(): Int = {
     var result: Int = 31 + size
     var i = 0
@@ -602,6 +604,8 @@ class SparseVector (
     }
   }
 
+  override def equals(other: Any): Boolean = super.equals(other)
+
   override def hashCode(): Int = {
     var result: Int = 31 + size
     val end = values.length
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 9d895b8faca7d60c4c3841088086e4c31a812a13..5d11ed0971dbded39dabc796650b82aa7fd1c737 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tree
 
+import java.util.Objects
+
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
@@ -112,12 +114,15 @@ final class CategoricalSplit private[ml] (
     }
   }
 
-  override def equals(o: Any): Boolean = {
-    o match {
-      case other: CategoricalSplit => featureIndex == other.featureIndex &&
-        isLeft == other.isLeft && categories == other.categories
-      case _ => false
-    }
+  override def hashCode(): Int = {
+    val state = Seq(featureIndex, isLeft, categories)
+    state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+  }
+
+  override def equals(o: Any): Boolean = o match {
+    case other: CategoricalSplit => featureIndex == other.featureIndex &&
+      isLeft == other.isLeft && categories == other.categories
+    case _ => false
   }
 
   override private[tree] def toOld: OldSplit = {
@@ -181,6 +186,11 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
     }
   }
 
+  override def hashCode(): Int = {
+    val state = Seq(featureIndex, threshold)
+    state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+  }
+
   override private[tree] def toOld: OldSplit = {
     OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index bb5d6d9d51d138d1ef65043a0663110246cf15d0..90fa4fbbc604db8ed81547861a1f7b4f2b588427 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -606,6 +606,8 @@ class SparseMatrix @Since("1.3.0") (
     case _ => false
   }
 
+  override def hashCode(): Int = toBreeze.hashCode
+
   private[mllib] def toBreeze: BM[Double] = {
     if (!isTransposed) {
       new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 5ec83e8d5c20675bf7f51caba21d7e27cf3ab659..6e3da6b701cb07491f1d7eec2a48d1f08d443035 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -628,6 +628,8 @@ class DenseVector @Since("1.0.0") (
     }
   }
 
+  override def equals(other: Any): Boolean = super.equals(other)
+
   override def hashCode(): Int = {
     var result: Int = 31 + size
     var i = 0
@@ -775,6 +777,8 @@ class SparseVector @Since("1.0.0") (
     }
   }
 
+  override def equals(other: Any): Boolean = super.equals(other)
+
   override def hashCode(): Int = {
     var result: Int = 31 + size
     val end = values.length
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c98a39dc0c83987a094cd01641029f72780c3721..27838167fdeec67ca87d9a1001e440768e0044eb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -113,6 +113,10 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.api.java.function.FlatMapGroupsFunction.call")
       ) ++
+      Seq(
+        // [SPARK-6429] Implement hashCode and equals together
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals")
+      ) ++
       Seq(
         // SPARK-4819 replace Guava Optional
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"),
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index e39400e2d1840865217cdcca1bec3636553bf113..270104f85b83813b2ca52f93a14c9e8321d6cf09 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -262,7 +262,7 @@ This file is divided into 3 sections:
   </check>
 
   <!-- Should turn this on, but we have a few places that need to be fixed first -->
-  <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="false"></check>
+  <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"></check>
 
   <!-- ================================================================================ -->
   <!-- rules we don't want -->
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 8bdf9b29c9641a11dd52c700feabc33d4875495a..b77f93373e78d9aa9ddfd8aeb39ddd6fc25b9f1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -60,6 +60,8 @@ object AttributeSet {
 class AttributeSet private (val baseSet: Set[AttributeEquals])
   extends Traversable[Attribute] with Serializable {
 
+  override def hashCode: Int = baseSet.hashCode()
+
   /** Returns true if the members of this AttributeSet and other are the same. */
   override def equals(other: Any): Boolean = other match {
     case otherSet: AttributeSet =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 607c7c877cc14c9c12fff37363f62909421da4c3..d0ad7a05a0c3756fe6ab5b30d24b6a2aa8bca9ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -35,7 +35,8 @@ class EquivalentExpressions {
       case other: Expr => e.semanticEquals(other.e)
       case _ => false
     }
-    override val hashCode: Int = e.semanticHash()
+
+    override def hashCode: Int = e.semanticHash()
   }
 
   // For each expression, the set of equivalent expressions.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e9dda588de8ee498cba1fd48b1d1b6278c73f088..7e3683e482df10336a99ee158601f95c1a43deaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.util.Objects
 
 import org.json4s.JsonAST._
 
@@ -170,6 +171,8 @@ case class Literal protected (value: Any, dataType: DataType)
 
   override def toString: String = if (value != null) value.toString else "null"
 
+  override def hashCode(): Int = 31 * (31 * Objects.hashCode(dataType)) + Objects.hashCode(value)
+
   override def equals(other: Any): Boolean = other match {
     case o: Literal =>
       dataType.equals(o.dataType) &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index c083f12724dbb302d8f6db0eb640b921fa8971b5..8b38838537275e66bf74fec949141d1a8db02c34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.util.UUID
+import java.util.{Objects, UUID}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -175,6 +175,11 @@ case class Alias(child: Expression, name: String)(
     exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil
   }
 
+  override def hashCode(): Int = {
+    val state = Seq(name, exprId, child, qualifier, explicitMetadata)
+    state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+  }
+
   override def equals(other: Any): Boolean = other match {
     case a: Alias =>
       name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index fb7251d71b9b49bb7f7bc6aba3702cb375970fcb..71a9b9f8082a462a064940b03a665e8a481745ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.types
 
+import java.util.Objects
+
 import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 
@@ -83,6 +85,8 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
 
   override def sql: String = sqlType.sql
 
+  override def hashCode(): Int = getClass.hashCode()
+
   override def equals(other: Any): Boolean = other match {
     case that: UserDefinedType[_] => this.acceptsType(that)
     case _ => false
@@ -115,7 +119,9 @@ private[sql] class PythonUserDefinedType(
   }
 
   override def equals(other: Any): Boolean = other match {
-    case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
+    case that: PythonUserDefinedType => pyUDT == that.pyUDT
     case _ => false
   }
+
+  override def hashCode(): Int = Objects.hashCode(pyUDT)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 18752014ea908c6a262ac4b4766687ce1ba7821e..c3b20e2cc00a242ef04cf5c77344fbabec70f7b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -35,6 +35,9 @@ import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
 case class NestedArray(a: Array[Array[Int]]) {
+  override def hashCode(): Int =
+    java.util.Arrays.deepHashCode(a.asInstanceOf[Array[AnyRef]])
+
   override def equals(other: Any): Boolean = other match {
     case NestedArray(otherArray) =>
       java.util.Arrays.deepEquals(
@@ -64,15 +67,21 @@ case class SpecificCollection(l: List[Int])
 
 /** For testing Kryo serialization based encoder. */
 class KryoSerializable(val value: Int) {
-  override def equals(other: Any): Boolean = {
-    this.value == other.asInstanceOf[KryoSerializable].value
+  override def hashCode(): Int = value
+
+  override def equals(other: Any): Boolean = other match {
+    case that: KryoSerializable => this.value == that.value
+    case _ => false
   }
 }
 
 /** For testing Java serialization based encoder. */
 class JavaSerializable(val value: Int) extends Serializable {
-  override def equals(other: Any): Boolean = {
-    this.value == other.asInstanceOf[JavaSerializable].value
+  override def hashCode(): Int = value
+
+  override def equals(other: Any): Boolean = other match {
+    case that: JavaSerializable => this.value == that.value
+    case _ => false
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 42891287a3006266a2d69e1fba223ae0030e2f89..e81cd28ea34d100d4e13de859b72a5deeaadd52c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -33,7 +33,10 @@ private final class ShuffledRowRDDPartition(
     val startPreShufflePartitionIndex: Int,
     val endPreShufflePartitionIndex: Int) extends Partition {
   override val index: Int = postShufflePartitionIndex
+
   override def hashCode(): Int = postShufflePartitionIndex
+
+  override def equals(other: Any): Boolean = super.equals(other)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 34db10f8225549fc3fe44437fffa26f192772f46..61ec7ed2b15519fe2b7b50b5f2876716ac201d5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -44,6 +44,8 @@ class DefaultSource extends FileFormat with DataSourceRegister {
 
   override def toString: String = "CSV"
 
+  override def hashCode(): Int = getClass.hashCode()
+
   override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
 
   override def inferSchema(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 7364a1dc0658abf993660f9b96935ae8181ab1cf..7773ff550fe0a8ab205dc452be246bdf14435e72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -154,6 +154,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   }
 
   override def toString: String = "JSON"
+
+  override def hashCode(): Int = getClass.hashCode()
+
   override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index bfe7aefe4100c768a5a4007c1a5ba6ebd8cbae71..38c00849529cf4b6d58377d7af8d9ffa525cc00a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -60,6 +60,8 @@ private[sql] class DefaultSource
 
   override def toString: String = "ParquetFormat"
 
+  override def hashCode(): Int = getClass.hashCode()
+
   override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
 
   override def prepareWrite(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 92c31eac9594665fe02548322f44f27990a494f3..930adabc48ae43e237b003b8a909e8eabbca579a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -82,12 +82,12 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
   override def value: Long = _value
 
   // Needed for SQLListenerSuite
-  override def equals(other: Any): Boolean = {
-    other match {
-      case o: LongSQLMetricValue => value == o.value
-      case _ => false
-    }
+  override def equals(other: Any): Boolean = other match {
+    case o: LongSQLMetricValue => value == o.value
+    case _ => false
   }
+
+  override def hashCode(): Int = _value.hashCode()
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 695a5ad78adc62f58e86b0d10ece24c00c42c329..a73e4272950a458278827a252487a17deb57cbc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types._
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
 private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+
+  override def hashCode(): Int = 31 * (31 * x.hashCode()) + y.hashCode()
+
   override def equals(other: Any): Boolean = other match {
     case that: ExamplePoint => this.x == that.x && this.y == that.y
     case _ => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index acc9f48d7e08f15ac00027fc0701c262ccb1c600..a49aaa8b73386bb84a1d676daa3f13e2e99bab9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -37,9 +37,10 @@ object UDT {
 
   @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
   private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
+    override def hashCode(): Int = java.util.Arrays.hashCode(data)
+
     override def equals(other: Any): Boolean = other match {
-      case v: MyDenseVector =>
-        java.util.Arrays.equals(this.data, v.data)
+      case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
       case _ => false
     }
   }
@@ -63,10 +64,9 @@ object UDT {
 
     private[spark] override def asNullable: MyDenseVectorUDT = this
 
-    override def equals(other: Any): Boolean = other match {
-      case _: MyDenseVectorUDT => true
-      case _ => false
-    }
+    override def hashCode(): Int = getClass.hashCode()
+
+    override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT]
   }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 8119d808ffab332bd947faca0dec87e7912bc2f6..58b7031d5ea6aa447e34701bc098d5bdcd5301d0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -84,15 +84,19 @@ private[streaming] object MapWithStateRDDRecord {
  * RDD, and a partitioned keyed-data RDD
  */
 private[streaming] class MapWithStateRDDPartition(
-    idx: Int,
+    override val index: Int,
     @transient private var prevStateRDD: RDD[_],
     @transient private var partitionedDataRDD: RDD[_]) extends Partition {
 
   private[rdd] var previousSessionRDDPartition: Partition = null
   private[rdd] var partitionedDataRDDPartition: Partition = null
 
-  override def index: Int = idx
-  override def hashCode(): Int = idx
+  override def hashCode(): Int = index
+
+  override def equals(other: Any): Boolean = other match {
+    case that: MapWithStateRDDPartition => index == that.index
+    case _ => false
+  }
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 784c6525e557725510b8488dd9d6fbf5fbdaf32a..6a861d6f66edf56e8dfe41d01ac48af9056551e6 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -85,6 +85,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
 }
 
 class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+  override def hashCode(): Int = 0
   override def equals(other: Any): Boolean = false
 }
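
Note on the pattern applied throughout this patch: the Object contract requires that objects which compare
equal also report equal hash codes, so equals and hashCode should always be overridden as a pair; the newly
enabled scalastyle EqualsHashCodeChecker flags classes that define only one of the two. A minimal sketch of
the contract, using a hypothetical Point class rather than any of the Spark classes touched above:

    // Hypothetical illustration only -- not part of the Spark source tree.
    class Point(val x: Int, val y: Int) {
      override def equals(other: Any): Boolean = other match {
        case p: Point => x == p.x && y == p.y
        case _ => false
      }

      // Equal points must hash identically, otherwise HashMap/HashSet lookups silently miss.
      override def hashCode(): Int = 31 * x + y
    }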