diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b6f7618171224daab49e2831781c0f9475083ca1..f04df1c156898d7fb4973365e9f69b895894c30b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable {
       learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
       data: JavaRDD[LabeledPoint],
       initialWeights: Vector): JList[Object] = {
-    // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-    learner.disableUncachedWarning()
-    val model = learner.run(data.rdd, initialWeights)
-    List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    try {
+      val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+      List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
+  }
+
+  /**
+   * Return the Updater from string
+   */
+  def getUpdaterFromString(regType: String): Updater = {
+    if (regType == "l2") {
+      new SquaredL2Updater
+    } else if (regType == "l1") {
+      new L1Updater
+    } else if (regType == null || regType == "none") {
+      new SimpleUpdater
+    } else {
+      throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+    }
   }
 
   /**
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      lrAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      lrAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      lrAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       lrAlg,
       data,
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      SVMAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      SVMAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       SVMAlg,
       data,
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setStepSize(stepSize)
       .setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       LogRegAlg,
       data,
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable {
       .setRegParam(regParam)
       .setNumCorrections(corrections)
       .setConvergenceTol(tolerance)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
     trainRegressionModel(
       LogRegAlg,
       data,
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable {
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
-      // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-      .disableUncachedWarning()
-    kMeansAlg.run(data.rdd)
+    try {
+      kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
   }
 
   /**
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable {
       numPartitions: Int,
       numIterations: Int,
       seed: Long): Word2VecModelWrapper = {
-    val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
     val word2vec = new Word2Vec()
       .setVectorSize(vectorSize)
       .setLearningRate(learningRate)
       .setNumPartitions(numPartitions)
       .setNumIterations(numIterations)
       .setSeed(seed)
-    val model = word2vec.fit(data)
-    data.unpersist()
-    new Word2VecModelWrapper(model)
+    try {
+      val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+      new Word2VecModelWrapper(model)
+    } finally {
+      dataJRDD.rdd.unpersist(blocking = false)
+    }
   }
 
   private[python] class Word2VecModelWrapper(model: Word2VecModel) {
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable {
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
       minInstancesPerNode = minInstancesPerNode,
       minInfoGain = minInfoGain)
-
-    DecisionTree.train(data.rdd, strategy)
+    try {
+      DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
   }
 
   /**
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable {
       numClassesForClassification = numClasses,
       maxBins = maxBins,
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
-    if (algo == Algo.Classification) {
-      RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
-    } else {
-      RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+    val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    try {
+      if (algo == Algo.Classification) {
+        RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      } else {
+        RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      }
+    } finally {
+      cached.unpersist(blocking = false)
     }
   }
 
@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable {
     def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
       if (obj == this) {
         out.write(Opcodes.GLOBAL)
-        out.write((module + "\n" + name + "\n").getBytes())
+        out.write((module + "\n" + name + "\n").getBytes)
       } else {
         pickler.save(this)  // it will be memorized by Pickler
         saveState(obj, out, pickler)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 7443f232ec3e736a7079dfad037ba3db4cd94fda..34ea0de706f0821108d49adb9e981440fb81e1bf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
     this
   }
 
-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }
-
   /**
    * Train a K-means model on the given set of points; `data` should be cached for high
    * performance, because this is an iterative algorithm.
    */
   def run(data: RDD[Vector]): KMeansModel = {
-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
 
@@ -143,7 +134,7 @@ class KMeans private (
     norms.unpersist()
 
     // Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data was not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 00dfc86c9e0bdaf83358cadc9ef00596cdee1382..0287f04e2c7777a80fa791f639f31d64448bf69a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     this
   }
 
-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }
-
   /**
    * Run the algorithm with the configured parameters on an input
    * RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    */
  def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
 
-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     }
 
     // Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data was not directly cached, which may hurt performance if its"
         + " parent RDDs are also uncached.")
     }
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index fe4c4cc5094d864810e6cffc9ceb1e3d1596266e..e2492eef5bd6a9a1fa70a764c757fe7f87c1b812 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -16,7 +16,7 @@
 #
 
 from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['KMeansModel', 'KMeans']
@@ -80,10 +80,8 @@ class KMeans(object):
     @classmethod
     def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
         """Train a k-means clustering model."""
-        # cache serialized data to avoid objects over head in JVM
-        jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
-        model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
-                              initializationMode)
+        model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+                              runs, initializationMode)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
 
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index c6149fe391ec8e7fdd8cea36db56995fd30042bb..33c49e2399908f7a6ae01e5522bff216f3b8732a 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -54,15 +54,13 @@ _picklable_classes = [
 
 
 # this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd, cache=False):
+def _to_java_object_rdd(rdd):
     """ Return an JavaRDD of Object by unpickling
 
     It will convert each Python object into Java object by Pyrolite, whenever the
     RDD is serialized in batch or not.
     """
     rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
-    if cache:
-        rdd.cache()
     return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
 
 
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 2bcbf2aaf8e3e99b8ac3991de1808fd464479d47..97ec74eda0b713961881f4f2f2f638cb1c127780 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,7 @@ from collections import namedtuple
 
 from pyspark import SparkContext
 from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -110,7 +110,7 @@ class ALS(object):
                 ratings = ratings.map(lambda x: Rating(*x))
             else:
                 raise ValueError("rating should be RDD of Rating or tuple/list")
-        return _to_java_object_rdd(ratings, True)
+        return ratings
 
     @classmethod
     def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index f4f5e615fadc3fd2667f9159d36493c03c0f873c..210060140fd91ac911ad3b6d1d1d2b123e5a0871 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,7 +18,7 @@
 import numpy as np
 from numpy import array
 
-from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 
 __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -129,8 +129,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
     initial_weights = initial_weights or [0.0] * len(data.first().features)
-    weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
-                                    _convert_to_vector(initial_weights))
+    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
     return modelClass(weights, intercept)