diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a7a84730a6fd9cd1c8940929835059269cfa19ea..e59a483075c949a808b99acfd1fe068ce41daa97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1892,17 +1892,25 @@ class Dataset[T] private[sql](
   def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
     val resolver = sparkSession.sessionState.analyzer.resolver
     val allColumns = queryExecution.analyzed.output
-    val groupCols = colNames.map { colName =>
-      allColumns.find(col => resolver(col.name, colName)).getOrElse(
+    val groupCols = colNames.flatMap { colName =>
+      // It is possible that there is more than one column with the same name,
+      // so we call filter instead of find.
+      val cols = allColumns.filter(col => resolver(col.name, colName))
+      if (cols.isEmpty) {
         throw new AnalysisException(
-          s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
+          s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+      }
+      cols
     }
     val groupColExprIds = groupCols.map(_.exprId)
     val aggCols = logicalPlan.output.map { attr =>
       if (groupColExprIds.contains(attr.exprId)) {
         attr
       } else {
-        Alias(new First(attr).toAggregateExpression(), attr.name)()
+        // Removing duplicate rows should not change the output attributes. We should keep
+        // the original exprId of each attribute. Otherwise, selecting a column from the
+        // original Dataset would fail with an analysis exception due to an unresolved attribute.
+        Alias(new First(attr).toAggregateExpression(), attr.name)(exprId = attr.exprId)
       }
     }
     Aggregate(groupCols, aggCols, logicalPlan)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3243f352a5337f2500f32729a2a57741054fdbb7..5fce9b4fe97ea668dcbec0847d6e632e41de1651 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -872,6 +872,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 1), ("a", 2), ("b", 1))
   }
 
+  test("dropDuplicates: columns with same column name") {
+    val ds1 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
+    val ds2 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
+    // The joined dataset has two columns with the same name "_2".
+    val joined = ds1.join(ds2, "_1").select(ds1("_2").as[Int], ds2("_2").as[Int])
+    checkDataset(
+      joined.dropDuplicates(),
+      (1, 2), (1, 1), (2, 1), (2, 2))
+  }
+
+  test("dropDuplicates should not change child plan output") {
+    val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
+    checkDataset(
+      ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int]),
+      ("a", 1), ("b", 1))
+  }
+
   test("SPARK-16097: Encoders.tuple should handle null object correctly") {
     val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
     val data = Seq((("a", "b"), "c"), (null, "d"))
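
For context, a rough usage sketch (not part of the patch) of the behavior the new tests exercise follows. The object name, the local SparkSession setup, and the expected output comment are illustrative assumptions; the dropDuplicates-then-select pattern mirrors the "dropDuplicates should not change child plan output" test above.

import org.apache.spark.sql.SparkSession

object DropDuplicatesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()

    // Before this change, dropDuplicates re-aliased the non-grouping columns with fresh
    // exprIds, so referring back to the original Dataset's columns could not be resolved
    // and raised an AnalysisException. With the exprIds preserved, this select works:
    val deduped = ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int])
    deduped.show()  // expected rows, in some order: (a, 1), (b, 1)

    spark.stop()
  }
}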