diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 1d42d49a8816badbac328fb883db4e5dfb3e51ad..129d7d68f7cbbcb32e49c4bbe16f283bc8f9c83d 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,7 +19,7 @@
 DataFrame-based machine learning APIs to let users quickly assemble and configure practical
 machine learning pipelines.
 """
-from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 
-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
+__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 339e5d6af52a77ef1104f927530e314a9a53f925..a6767cee9bf287ab3537651e1b45b0bf06562c55 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -17,9 +17,14 @@
 
 from abc import ABCMeta, abstractmethod
 
+import copy
+
 from pyspark import since
 from pyspark.ml.param import Params
+from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc
+from pyspark.sql.functions import udf
+from pyspark.sql.types import StructField, StructType, DoubleType
 
 
 @inherit_doc
@@ -116,3 +121,54 @@ class Model(Transformer):
     """
 
     __metaclass__ = ABCMeta
+
+
+@inherit_doc
+class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
+    """
+    Abstract class for transformers that take one input column, apply a transformation,
+    and output the result as a new column.
+
+    .. versionadded:: 2.3.0
+    """
+
+    @abstractmethod
+    def createTransformFunc(self):
+        """
+        Creates the transform function using the given param map. The input param map already
+        takes account of the embedded param map, so the param values should be determined
+        solely by the input param map.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def outputDataType(self):
+        """
+        Returns the data type of the output column.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def validateInputType(self, inputType):
+        """
+        Validates the input type. Throws an exception if it is invalid.
+        """
+        raise NotImplementedError()
+
+    def transformSchema(self, schema):
+        inputType = schema[self.getInputCol()].dataType
+        self.validateInputType(inputType)
+        if self.getOutputCol() in schema.names:
+            raise ValueError("Output column %s already exists." % self.getOutputCol())
+        outputFields = copy.copy(schema.fields)
+        outputFields.append(StructField(self.getOutputCol(),
+                                        self.outputDataType(),
+                                        nullable=False))
+        return StructType(outputFields)
+
+    def _transform(self, dataset):
+        self.transformSchema(dataset.schema)
+        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
+        transformedDataset = dataset.withColumn(self.getOutputCol(),
+                                                transformUDF(dataset[self.getInputCol()]))
+        return transformedDataset
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7ee2c2f3ffe76a508a9f6b443662dfdd9722fe4a..3bd4d3737a056ca9d219386577688339fa502ac5 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -45,7 +45,7 @@ from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
 import inspect
 
 from pyspark import keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
 from pyspark.ml.classification import *
 from pyspark.ml.clustering import *
 from pyspark.ml.common import _java2py, _py2java
@@ -66,6 +66,7 @@
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
+from pyspark.sql.types import DoubleType, IntegerType
 from pyspark.storagelevel import *
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -121,6 +122,36 @@ class MockTransformer(Transformer, HasFake):
         return dataset
 
 
+class MockUnaryTransformer(UnaryTransformer):
+
+    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
+                  "data in a DataFrame",
+                  typeConverter=TypeConverters.toFloat)
+
+    def __init__(self, shiftVal=1):
+        super(MockUnaryTransformer, self).__init__()
+        self._setDefault(shift=1)
+        self._set(shift=shiftVal)
+
+    def getShift(self):
+        return self.getOrDefault(self.shift)
+
+    def setShift(self, shift):
+        self._set(shift=shift)
+
+    def createTransformFunc(self):
+        shiftVal = self.getShift()
+        return lambda x: x + shiftVal
+
+    def outputDataType(self):
+        return DoubleType()
+
+    def validateInputType(self, inputType):
+        if inputType != DoubleType():
+            raise TypeError("Bad input type: {}. ".format(inputType) +
".format(inputType) + + "Requires Integer.") + + class MockEstimator(Estimator, HasFake): def __init__(self): @@ -2008,6 +2039,35 @@ class ChiSquareTestTests(SparkSessionTestCase): self.assertTrue(all(field in fieldNames for field in expectedFields)) +class UnaryTransformerTests(SparkSessionTestCase): + + def test_unary_transformer_validate_input_type(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal)\ + .setInputCol("input").setOutputCol("output") + + # should not raise any errors + transformer.validateInputType(DoubleType()) + + with self.assertRaises(TypeError): + # passing the wrong input type should raise an error + transformer.validateInputType(IntegerType()) + + def test_unary_transformer_transform(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal)\ + .setInputCol("input").setOutputCol("output") + + df = self.spark.range(0, 10).toDF('input') + df = df.withColumn("input", df.input.cast(dataType="double")) + + transformed_df = transformer.transform(df) + results = transformed_df.select("input", "output").collect() + + for res in results: + self.assertEqual(res.input + shiftVal, res.output) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: