Commit 5a3533e7 authored by Dongjoon Hyun, committed by Reynold Xin

[SPARK-15696][SQL] Improve `crosstab` to have a consistent column order

## What changes were proposed in this pull request?

Currently, `crosstab` returns a DataFrame whose columns appear in a **random order**, since they are obtained by a plain `distinct`. The documentation of `crosstab`, however, shows the result in sorted order, which differs from the current implementation. This PR explicitly constructs the columns in sorted order to improve the user experience and to make the output match the documentation.

**Before**
```scala
scala> spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  3|  2|  1|
+---------+---+---+---+
|        2|  1|  0|  2|
|        1|  0|  1|  1|
|        3|  1|  1|  0|
+---------+---+---+---+

scala> spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "a"), (2, "a"), (2, "c"), (3, "b"), (3, "c"))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  c|  a|  b|
+---------+---+---+---+
|        2|  1|  2|  0|
|        1|  0|  1|  1|
|        3|  1|  0|  1|
+---------+---+---+---+
```

**After**
```scala
scala> spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  1|  2|  3|
+---------+---+---+---+
|        2|  2|  0|  1|
|        1|  1|  1|  0|
|        3|  0|  1|  1|
+---------+---+---+---+

scala> spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "a"), (2, "a"), (2, "c"), (3, "b"), (3, "c"))).toDF("key", "value").stat.crosstab("key", "value").show()
+---------+---+---+---+
|key_value|  a|  b|  c|
+---------+---+---+---+
|        2|  2|  0|  1|
|        1|  1|  1|  0|
|        3|  0|  1|  1|
+---------+---+---+---+
```
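
Before this fix, a user who needed a deterministic layout had to reorder the columns manually. A minimal sketch of that workaround (assuming a `SparkSession` named `spark`, as in the `spark-shell` examples above):

```scala
import org.apache.spark.sql.functions.col

val ct = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)))
  .toDF("key", "value")
  .stat.crosstab("key", "value")

// Keep the `key_value` column first and sort the remaining value columns by name.
val ordered = ct.select((ct.columns.head +: ct.columns.tail.sorted).map(col): _*)
ordered.show()
```

With this PR the sorted order is produced directly, so the extra `select` is no longer needed.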

## How was this patch tested?

Pass the Jenkins tests with updated test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13436 from dongjoon-hyun/SPARK-15696.
parent 6c5fd977
```diff
@@ -423,9 +423,9 @@ private[sql] object StatFunctions extends Logging {
     def cleanElement(element: Any): String = {
       if (element == null) "null" else element.toString
     }
-    // get the distinct values of column 2, so that we can make them the column names
+    // get the distinct sorted values of column 2, so that we can make them the column names
     val distinctCol2: Map[Any, Int] =
-      counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
+      counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
```
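
Note that the sort is applied to the string form produced by `cleanElement`, so the resulting column order is lexicographic on the stringified values. A minimal sketch of the ordering logic, with `values` as a hypothetical stand-in for the cleaned values of column 2:

```scala
// Hypothetical stand-in for the cleaned (stringified) values of column 2.
val values: Seq[String] = Seq("3", "10", "2", "3")

// Before: `distinct` alone keeps the encounter order of `values`,
// which in `crosstab` depends on the order of the collected rows.
val before: Map[String, Int] = values.distinct.zipWithIndex.toMap

// After: `distinct.sorted` yields a stable, lexicographic column order.
val after: Map[String, Int] = values.distinct.sorted.zipWithIndex.toMap
// after == Map("10" -> 0, "2" -> 1, "3" -> 2)
```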
```diff
@@ -246,8 +246,8 @@ public class JavaDataFrameSuite {
     Dataset<Row> crosstab = df.stat().crosstab("a", "b");
     String[] columnNames = crosstab.schema().fieldNames();
     Assert.assertEquals("a_b", columnNames[0]);
-    Assert.assertEquals("2", columnNames[1]);
-    Assert.assertEquals("1", columnNames[2]);
+    Assert.assertEquals("1", columnNames[1]);
+    Assert.assertEquals("2", columnNames[2]);
     List<Row> rows = crosstab.collectAsList();
     Collections.sort(rows, crosstabRowComparator);
     Integer count = 1;
```
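
The updated Java test asserts the new column order on the schema. A comparable check in Scala (a sketch, reusing the second toy dataset from the description above and assuming a `SparkSession` named `spark`):

```scala
val ct = spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "a"), (2, "a"), (2, "c"), (3, "b"), (3, "c")))
  .toDF("key", "value")
  .stat.crosstab("key", "value")

// The key column comes first, followed by the distinct values of `value` in sorted order.
assert(ct.schema.fieldNames.toSeq == Seq("key_value", "a", "b", "c"))
```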