Skip to content
Snippets Groups Projects
Commit d4a637cd authored by zero323's avatar zero323 Committed by Joseph K. Bradley
Browse files

[SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items

## What changes were proposed in this pull request?

This commit moved `distinct` in its intended place to avoid duplicated predictions and adds unit test covering the issue.

## How was this patch tested?

Unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17283 from zero323/SPARK-19940.
parent 5e96a57b
No related branches found
No related tags found
No related merge requests found
......@@ -245,10 +245,10 @@ class FPGrowthModel private[ml] (
rule._2.filter(item => !itemset.contains(item))
} else {
Seq.empty
})
}).distinct
} else {
Seq.empty
}.distinct }, dt)
}}, dt)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
......
......@@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
FPGrowthSuite.allParamSettings, checkModelData)
}
test("FPGrowth prediction should not contain duplicates") {
// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("features")
val model = new FPGrowth().fit(dataset)
val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
).first().getAs[Seq[String]]("prediction")
assert(prediction === Seq("3"))
}
}
object FPGrowthSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment