Skip to content
Snippets Groups Projects
Commit bec6b16c authored by zero323's avatar zero323 Committed by Joseph K. Bradley
Browse files

[SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth

## What changes were proposed in this pull request?

Replaces the `featuresCol` `Param` with `itemsCol`. See [SPARK-19899](https://issues.apache.org/jira/browse/SPARK-19899).

## How was this patch tested?

Manual tests and existing unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17321 from zero323/SPARK-19899.
parent fc755459
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.HasPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
......@@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
/**
* Common params for FPGrowth and FPGrowthModel
*/
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
/**
* Items column name.
* Default: "items"
* @group param
*/
@Since("2.2.0")
val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
setDefault(itemsCol -> "items")
/** @group getParam */
@Since("2.2.0")
def getItemsCol: String = $(itemsCol)
/**
* Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
......@@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
*/
@Since("2.2.0")
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputType = schema($(featuresCol)).dataType
val inputType = schema($(itemsCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
}
}
......@@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (
/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
def setItemsCol(value: String): this.type = set(itemsCol, value)
/** @group setParam */
@Since("2.2.0")
......@@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
}
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
val data = dataset.select($(featuresCol))
val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val data = dataset.select($(itemsCol))
val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
if (isSet(numPartitions)) {
mllibFP.setNumPartitions($(numPartitions))
......@@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
val schema = StructType(Seq(
StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
......@@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (
/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
def setItemsCol(value: String): this.type = set(itemsCol, value)
/** @group setParam */
@Since("2.2.0")
......@@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
.collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
val dt = dataset.schema($(featuresCol)).dataType
val dt = dataset.schema($(itemsCol)).dataType
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[_]) => {
if (items != null) {
......@@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
} else {
Seq.empty
}}, dt)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
}
@Since("2.2.0")
......
......@@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("FPGrowth fit and transform with different data types") {
Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
val data = dataset.withColumn("features", col("features").cast(ArrayType(dt)))
val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
val model = new FPGrowth().setMinSupport(0.5).fit(data)
val generatedRules = model.setMinConfidence(0.5).associationRules
val expectedRules = spark.createDataFrame(Seq(
......@@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "3"), Array(2))
)).toDF("id", "features", "prediction")
.withColumn("features", col("features").cast(ArrayType(dt)))
)).toDF("id", "items", "prediction")
.withColumn("items", col("items").cast(ArrayType(dt)))
.withColumn("prediction", col("prediction").cast(ArrayType(dt)))
assert(expectedTransformed.collect().toSet.equals(
transformed.collect().toSet))
......@@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(1, Array("1", "2", "3", "5")),
(2, Array("1", "2", "3", "4")),
(3, null.asInstanceOf[Array[String]])
)).toDF("id", "features")
)).toDF("id", "items")
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
val prediction = model.transform(df)
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
......@@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("features")
).map(Tuple1(_))).toDF("items")
val model = new FPGrowth().fit(dataset)
val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
).first().getAs[Seq[String]]("prediction")
assert(prediction === Seq("3"))
......@@ -127,7 +127,7 @@ object FPGrowthSuite {
(0, Array("1", "2")),
(0, Array("1", "2")),
(0, Array("1", "3"))
)).toDF("id", "features")
)).toDF("id", "items")
}
/**
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment