Skip to content
Snippets Groups Projects
Commit edb23f9e authored by Xiangrui Meng's avatar Xiangrui Meng Committed by Yanbo Liang
Browse files

[SPARK-15946][MLLIB] Conversion between old/new vector columns in a DataFrame (Python)

## What changes were proposed in this pull request?

This PR implements python wrappers for #13662 to convert old/new vector columns in a DataFrame.

## How was this patch tested?

doctest in Python

cc: yanboliang

Author: Xiangrui Meng <meng@databricks.com>

Closes #13731 from mengxr/SPARK-15946.
parent af2a4b08
No related branches found
No related tags found
No related merge requests found
......@@ -1201,6 +1201,20 @@ private[python] class PythonMLLibAPI extends Serializable {
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
spark.createDataFrame(blockMatrix.blocks)
}
/**
* Python-friendly version of [[MLUtils.convertVectorColumnsToML()]].
*/
def convertVectorColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = {
MLUtils.convertVectorColumnsToML(dataset, cols.asScala: _*)
}
/**
* Python-friendly version of [[MLUtils.convertVectorColumnsFromML()]]
*/
def convertVectorColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = {
MLUtils.convertVectorColumnsFromML(dataset, cols.asScala: _*)
}
}
/**
......
......@@ -26,6 +26,7 @@ if sys.version > '3':
from pyspark import SparkContext, since
from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.sql import DataFrame
class MLUtils(object):
......@@ -200,6 +201,86 @@ class MLUtils(object):
"""
return callMLlibFunc("loadVectors", sc, path)
@staticmethod
@since("2.0.0")
def convertVectorColumnsToML(dataset, *cols):
"""
Converts vector columns in an input DataFrame from the
:py:class:`pyspark.mllib.linalg.Vector` type to the new
:py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml`
package.
:param dataset:
input dataset
:param cols:
a list of vector columns to be converted.
New vector columns will be ignored. If unspecified, all old
vector columns will be converted excepted nested ones.
:return:
the input dataset with old vector columns converted to the
new vector type
>>> import pyspark
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.mllib.util import MLUtils
>>> df = spark.createDataFrame(
... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))],
... ["id", "x", "y"])
>>> r1 = MLUtils.convertVectorColumnsToML(df).first()
>>> isinstance(r1.x, pyspark.ml.linalg.SparseVector)
True
>>> isinstance(r1.y, pyspark.ml.linalg.DenseVector)
True
>>> r2 = MLUtils.convertVectorColumnsToML(df, "x").first()
>>> isinstance(r2.x, pyspark.ml.linalg.SparseVector)
True
>>> isinstance(r2.y, pyspark.mllib.linalg.DenseVector)
True
"""
if not isinstance(dataset, DataFrame):
raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset)))
return callMLlibFunc("convertVectorColumnsToML", dataset, list(cols))
@staticmethod
@since("2.0.0")
def convertVectorColumnsFromML(dataset, *cols):
"""
Converts vector columns in an input DataFrame to the
:py:class:`pyspark.mllib.linalg.Vector` type from the new
:py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml`
package.
:param dataset:
input dataset
:param cols:
a list of vector columns to be converted.
Old vector columns will be ignored. If unspecified, all new
vector columns will be converted except nested ones.
:return:
the input dataset with new vector columns converted to the
old vector type
>>> import pyspark
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.mllib.util import MLUtils
>>> df = spark.createDataFrame(
... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))],
... ["id", "x", "y"])
>>> r1 = MLUtils.convertVectorColumnsFromML(df).first()
>>> isinstance(r1.x, pyspark.mllib.linalg.SparseVector)
True
>>> isinstance(r1.y, pyspark.mllib.linalg.DenseVector)
True
>>> r2 = MLUtils.convertVectorColumnsFromML(df, "x").first()
>>> isinstance(r2.x, pyspark.mllib.linalg.SparseVector)
True
>>> isinstance(r2.y, pyspark.ml.linalg.DenseVector)
True
"""
if not isinstance(dataset, DataFrame):
raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset)))
return callMLlibFunc("convertVectorColumnsFromML", dataset, list(cols))
class Saveable(object):
"""
......@@ -355,6 +436,7 @@ def _test():
.master("local[2]")\
.appName("mllib.util tests")\
.getOrCreate()
globs['spark'] = spark
globs['sc'] = spark.sparkContext
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment