diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 99321bcc7cf98831edcbaa4aecc2842a809eb0cc..b2dc4fcb61964dbf55bb89777d5f0c1a1aaf6031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -59,6 +59,29 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha @Since("1.6.0") def getHandleInvalid: String = $(handleInvalid) + /** + * Param for how to order labels of string column. The first label after ordering is assigned + * an index of 0. + * Options are: + * - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0) + * - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0) + * - 'alphabetDesc': descending alphabetical order + * - 'alphabetAsc': ascending alphabetical order + * Default is 'frequencyDesc'. + * + * @group param + */ + @Since("2.3.0") + final val stringOrderType: Param[String] = new Param(this, "stringOrderType", + "how to order labels of string column. " + + "The first label after ordering is assigned an index of 0. " + + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", + ParamValidators.inArray(StringIndexer.supportedStringOrderType)) + + /** @group getParam */ + @Since("2.3.0") + def getStringOrderType: String = $(stringOrderType) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) @@ -79,8 +102,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. - * The indices are in [0, numLabels), ordered by label frequencies. - * So the most frequent label gets index 0. + * The indices are in [0, numLabels). By default, this is ordered by label frequencies + * so the most frequent label gets index 0. The ordering behavior is controlled by + * setting `stringOrderType`. * * @see `IndexToString` for the inverse transformation */ @@ -96,6 +120,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setStringOrderType(value: String): this.type = set(stringOrderType, value) + setDefault(stringOrderType, StringIndexer.frequencyDesc) + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -107,11 +136,17 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) - .rdd - .map(_.getString(0)) - .countByValue() - val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val values = dataset.na.drop(Array($(inputCol))) + .select(col($(inputCol)).cast(StringType)) + .rdd.map(_.getString(0)) + val labels = $(stringOrderType) match { + case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) + .map(_._1).toArray + case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) + .map(_._1).toArray + case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) + case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) + } copyValues(new StringIndexerModel(uid, labels).setParent(this)) } @@ -131,6 +166,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + private[feature] val frequencyDesc: String = "frequencyDesc" + private[feature] val frequencyAsc: String = "frequencyAsc" + private[feature] val alphabetDesc: String = "alphabetDesc" + private[feature] val alphabetAsc: String = "alphabetAsc" + private[feature] val supportedStringOrderType: Array[String] = + Array(frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5634d4210f478bde7f660dc33823a410c1d32d81..806a92760c8b6b2ee5fc65d29adef183d0579149 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -291,4 +291,27 @@ class StringIndexerSuite NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") } + + test("StringIndexer order types") { + val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b")) + val df = data.toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + + var idx = 0 + for (orderType <- StringIndexer.supportedStringOrderType) { + val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) + val output = transformed.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + assert(output === expected(idx)) + idx += 1 + } + } }