Commit fe473595 authored by Doris Xin, committed by Xiangrui Meng

[SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics

For both Scala and Python.

The ser/de util functions were moved out of `PythonMLLibAPI` into their own object, `SerDe`, to avoid having to construct a `PythonMLLibAPI` instance inside `MultivariateStatisticalSummarySerialized`, which is itself returned by a method of `PythonMLLibAPI`.

`MultivariateStatisticalSummarySerialized` was created to serialize the `Vector` fields in `MultivariateStatisticalSummary`.
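
For orientation, the intended end-to-end usage from PySpark looks like the sketch below, mirroring the doctest added in this commit (assumes a local `SparkContext`; the output comments reuse the doctest's values):

from pyspark import SparkContext
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.stat import Statistics

sc = SparkContext("local", "colStats-example")
rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
                      Vectors.dense([4, 5, 0, 3]),
                      Vectors.dense([6, 7, 0, 8])])

# The vectors are serialized to the JVM, summarized there, and each
# getter deserializes one Vector field of the summary back into numpy.
summary = Statistics.colStats(rdd)
summary.mean()      # array([ 4., 4., 0., 3.])
summary.variance()  # array([ 4., 13., 0., 25.])
summary.count()     # 3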

Author: Doris Xin <doris.s.xin@gmail.com>

Closes #1911 from dorx/colStats and squashes the following commits:

77b9924 [Doris Xin] developerAPI tag
de9cbbe [Doris Xin] reviewer comments and moved more ser/de
459faba [Doris Xin] colStats in Statistics for both Scala and Python
parent 2bd81263
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -48,182 +48,7 @@ import org.apache.spark.util.Utils
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
private val DENSE_VECTOR_MAGIC: Byte = 1
private val SPARSE_VECTOR_MAGIC: Byte = 2
private val DENSE_MATRIX_MAGIC: Byte = 3
private val LABELED_POINT_MAGIC: Byte = 4
private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
require(bytes.length - offset >= 5, "Byte array too short")
val magic = bytes(offset)
if (magic == DENSE_VECTOR_MAGIC) {
deserializeDenseVector(bytes, offset)
} else if (magic == SPARSE_VECTOR_MAGIC) {
deserializeSparseVector(bytes, offset)
} else {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
}
private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
require(bytes.length - offset == 8, "Wrong size byte array for Double")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
bb.getDouble
}
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
val length = bb.getInt()
require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
val db = bb.asDoubleBuffer()
val ans = new Array[Double](length.toInt)
db.get(ans)
Vectors.dense(ans)
}
private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 9, "Byte array too short")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
val size = bb.getInt()
val nonZeros = bb.getInt()
require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
val ib = bb.asIntBuffer()
val indices = new Array[Int](nonZeros)
ib.get(indices)
bb.position(bb.position() + 4 * nonZeros)
val db = bb.asDoubleBuffer()
val values = new Array[Double](nonZeros)
db.get(values)
Vectors.sparse(size, indices, values)
}
/**
* Returns an 8-byte array for the input Double.
*
* Note: we currently do not use a magic byte for double for storage efficiency.
* This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
* The corresponding deserializer, deserializeDouble, needs to be modified as well if the
* serialization scheme changes.
*/
private[python] def serializeDouble(double: Double): Array[Byte] = {
val bytes = new Array[Byte](8)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putDouble(double)
bytes
}
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(DENSE_VECTOR_MAGIC)
bb.putInt(len)
val db = bb.asDoubleBuffer()
db.put(doubles)
bytes
}
private def serializeSparseVector(vector: SparseVector): Array[Byte] = {
val nonZeros = vector.indices.length
val bytes = new Array[Byte](9 + 12 * nonZeros)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(SPARSE_VECTOR_MAGIC)
bb.putInt(vector.size)
bb.putInt(nonZeros)
val ib = bb.asIntBuffer()
ib.put(vector.indices)
bb.position(bb.position() + 4 * nonZeros)
val db = bb.asDoubleBuffer()
db.put(vector.values)
bytes
}
private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
case s: SparseVector =>
serializeSparseVector(s)
case _ =>
serializeDenseVector(vector.toArray)
}
private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
val packetLength = bytes.length
if (packetLength < 9) {
throw new IllegalArgumentException("Byte array too short.")
}
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
if (magic != DENSE_MATRIX_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val rows = bb.getInt()
val cols = bb.getInt()
if (packetLength != 9 + 8 * rows * cols) {
throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
}
val db = bb.asDoubleBuffer()
val ans = new Array[Array[Double]](rows.toInt)
for (i <- 0 until rows.toInt) {
ans(i) = new Array[Double](cols.toInt)
db.get(ans(i))
}
ans
}
private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
val rows = doubles.length
var cols = 0
if (rows > 0) {
cols = doubles(0).length
}
val bytes = new Array[Byte](9 + 8 * rows * cols)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(DENSE_MATRIX_MAGIC)
bb.putInt(rows)
bb.putInt(cols)
val db = bb.asDoubleBuffer()
for (i <- 0 until rows) {
db.put(doubles(i))
}
bytes
}
private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
val fb = serializeDoubleVector(p.features)
val bytes = new Array[Byte](1 + 8 + fb.length)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(LABELED_POINT_MAGIC)
bb.putDouble(p.label)
bb.put(fb)
bytes
}
private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
require(bytes.length >= 9, "Byte array too short")
val magic = bytes(0)
if (magic != LABELED_POINT_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
labelBytes.order(ByteOrder.nativeOrder())
val label = labelBytes.asDoubleBuffer().get(0)
LabeledPoint(label, deserializeDoubleVector(bytes, 9))
}
/**
* Loads and serializes labeled points saved with `RDD#saveAsTextFile`.
@@ -236,17 +61,17 @@ class PythonMLLibAPI extends Serializable {
jsc: JavaSparkContext,
path: String,
minPartitions: Int): JavaRDD[Array[Byte]] =
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint)
MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint)
private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(model.weights))
ret.add(SerDe.serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
ret
}
@@ -405,12 +230,12 @@ class PythonMLLibAPI extends Serializable {
def trainNaiveBayes(
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(Vectors.dense(model.labels)))
ret.add(serializeDoubleVector(Vectors.dense(model.pi)))
ret.add(serializeDoubleMatrix(model.theta))
ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels)))
ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi)))
ret.add(SerDe.serializeDoubleMatrix(model.theta))
ret
}
@@ -423,52 +248,13 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
ret
}
/** Unpack a Rating object from an array of bytes */
private def unpackRating(ratingBytes: Array[Byte]): Rating = {
val bb = ByteBuffer.wrap(ratingBytes)
bb.order(ByteOrder.nativeOrder())
val user = bb.getInt()
val product = bb.getInt()
val rating = bb.getDouble()
new Rating(user, product, rating)
}
/** Unpack a tuple of Ints from an array of bytes */
private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
val bb = ByteBuffer.wrap(tupleBytes)
bb.order(ByteOrder.nativeOrder())
val v1 = bb.getInt()
val v2 = bb.getInt()
(v1, v2)
}
/**
* Serialize a Rating object into an array of bytes.
* It can be deserialized using RatingDeserializer().
*
* @param rate the Rating object to serialize
* @return
*/
private[spark] def serializeRating(rate: Rating): Array[Byte] = {
val len = 3
val bytes = new Array[Byte](4 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putInt(len)
val db = bb.asDoubleBuffer()
db.put(rate.user.toDouble)
db.put(rate.product.toDouble)
db.put(rate.rating)
bytes
}
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@@ -481,7 +267,7 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
ALS.train(ratings, rank, iterations, lambda, blocks)
}
@@ -498,7 +284,7 @@ class PythonMLLibAPI extends Serializable {
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
@@ -519,7 +305,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
val algo = Algo.fromString(algoStr)
val impurity = Impurities.fromString(impurityStr)
@@ -545,7 +331,7 @@ class PythonMLLibAPI extends Serializable {
def predictDecisionTreeModel(
model: DecisionTreeModel,
featuresBytes: Array[Byte]): Double = {
val features: Vector = deserializeDoubleVector(featuresBytes)
val features: Vector = SerDe.deserializeDoubleVector(featuresBytes)
model.predict(features)
}
@@ -559,8 +345,17 @@ class PythonMLLibAPI extends Serializable {
def predictDecisionTreeModel(
model: DecisionTreeModel,
dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
model.predict(data).map(serializeDouble)
val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes))
model.predict(data).map(SerDe.serializeDouble)
}
/**
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
* TODO figure out return type.
*/
def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = {
val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_)))
new MultivariateStatisticalSummarySerialized(cStats)
}
/**
@@ -569,17 +364,17 @@ class PythonMLLibAPI extends Serializable {
* pyspark.
*/
def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = {
val inputMatrix = X.rdd.map(deserializeDoubleVector(_))
val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_))
val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method))
serializeDoubleMatrix(to2dArray(result))
SerDe.serializeDoubleMatrix(SerDe.to2dArray(result))
}
/**
* Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String).
*/
def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = {
val xDeser = x.rdd.map(deserializeDouble(_))
val yDeser = y.rdd.map(deserializeDouble(_))
val xDeser = x.rdd.map(SerDe.deserializeDouble(_))
val yDeser = y.rdd.map(SerDe.deserializeDouble(_))
Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method))
}
@@ -588,12 +383,6 @@ class PythonMLLibAPI extends Serializable {
if (method == null) CorrelationNames.defaultCorrName else method
}
// Reformat a Matrix into Array[Array[Double]] for serialization
private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
val values = matrix.toArray
Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows))
}
// Used by the *RDD methods to get default seed if not passed in from pyspark
private def getSeedOrDefault(seed: java.lang.Long): Long = {
if (seed == null) Utils.random.nextLong else seed
@@ -621,7 +410,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble)
RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
}
/**
......@@ -633,7 +422,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble)
RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
}
/**
......@@ -646,7 +435,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble)
RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble)
}
/**
......@@ -659,7 +448,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
}
/**
......@@ -672,7 +461,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
}
/**
......@@ -686,7 +475,256 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
}
}
/**
* :: DeveloperApi ::
* MultivariateStatisticalSummary with Vector fields serialized.
*/
@DeveloperApi
class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary)
extends Serializable {
def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean)
def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance)
def count: Long = summary.count
def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros)
def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max)
def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min)
}
/**
* SerDe utility functions for PythonMLLibAPI.
*/
private[spark] object SerDe extends Serializable {
private val DENSE_VECTOR_MAGIC: Byte = 1
private val SPARSE_VECTOR_MAGIC: Byte = 2
private val DENSE_MATRIX_MAGIC: Byte = 3
private val LABELED_POINT_MAGIC: Byte = 4
private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
require(bytes.length - offset >= 5, "Byte array too short")
val magic = bytes(offset)
if (magic == DENSE_VECTOR_MAGIC) {
deserializeDenseVector(bytes, offset)
} else if (magic == SPARSE_VECTOR_MAGIC) {
deserializeSparseVector(bytes, offset)
} else {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
}
private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
require(bytes.length - offset == 8, "Wrong size byte array for Double")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
bb.getDouble
}
private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
val length = bb.getInt()
require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
val db = bb.asDoubleBuffer()
val ans = new Array[Double](length.toInt)
db.get(ans)
Vectors.dense(ans)
}
private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 9, "Byte array too short")
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
val size = bb.getInt()
val nonZeros = bb.getInt()
require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
val ib = bb.asIntBuffer()
val indices = new Array[Int](nonZeros)
ib.get(indices)
bb.position(bb.position() + 4 * nonZeros)
val db = bb.asDoubleBuffer()
val values = new Array[Double](nonZeros)
db.get(values)
Vectors.sparse(size, indices, values)
}
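
As a cross-check of the two vector layouts above, here is a minimal Python sketch of the same wire formats (illustrative helper names, not part of PySpark's `_common` module; assumes native byte order with standard 4-byte ints and 8-byte doubles, matching `ByteOrder.nativeOrder()` on the Scala side):

import struct

DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2

def pack_dense(values):
    # [magic: 1 byte][length: int32][length doubles] => 5 + 8 * length bytes
    return struct.pack("=bi%dd" % len(values), DENSE_VECTOR_MAGIC, len(values), *values)

def unpack_dense(data):
    magic, length = struct.unpack_from("=bi", data)
    assert magic == DENSE_VECTOR_MAGIC, "wrong magic: %d" % magic
    return struct.unpack_from("=%dd" % length, data, 5)

def pack_sparse(size, indices, values):
    # [magic][size: int32][nnz: int32][nnz int32 indices][nnz doubles] => 9 + 12 * nnz bytes
    nnz = len(indices)
    return struct.pack("=bii%di%dd" % (nnz, nnz),
                       SPARSE_VECTOR_MAGIC, size, nnz, *(list(indices) + list(values)))

unpack_dense(pack_dense([1.0, 2.0]))  # (1.0, 2.0)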
/**
* Returns an 8-byte array for the input Double.
*
* Note: we currently do not use a magic byte for double for storage efficiency.
* This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
* The corresponding deserializer, deserializeDouble, needs to be modified as well if the
* serialization scheme changes.
*/
private[python] def serializeDouble(double: Double): Array[Byte] = {
val bytes = new Array[Byte](8)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putDouble(double)
bytes
}
private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(DENSE_VECTOR_MAGIC)
bb.putInt(len)
val db = bb.asDoubleBuffer()
db.put(doubles)
bytes
}
private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = {
val nonZeros = vector.indices.length
val bytes = new Array[Byte](9 + 12 * nonZeros)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(SPARSE_VECTOR_MAGIC)
bb.putInt(vector.size)
bb.putInt(nonZeros)
val ib = bb.asIntBuffer()
ib.put(vector.indices)
bb.position(bb.position() + 4 * nonZeros)
val db = bb.asDoubleBuffer()
db.put(vector.values)
bytes
}
private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
case s: SparseVector =>
serializeSparseVector(s)
case _ =>
serializeDenseVector(vector.toArray)
}
private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
val packetLength = bytes.length
if (packetLength < 9) {
throw new IllegalArgumentException("Byte array too short.")
}
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
val magic = bb.get()
if (magic != DENSE_MATRIX_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val rows = bb.getInt()
val cols = bb.getInt()
if (packetLength != 9 + 8 * rows * cols) {
throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
}
val db = bb.asDoubleBuffer()
val ans = new Array[Array[Double]](rows.toInt)
for (i <- 0 until rows.toInt) {
ans(i) = new Array[Double](cols.toInt)
db.get(ans(i))
}
ans
}
private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
val rows = doubles.length
var cols = 0
if (rows > 0) {
cols = doubles(0).length
}
val bytes = new Array[Byte](9 + 8 * rows * cols)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(DENSE_MATRIX_MAGIC)
bb.putInt(rows)
bb.putInt(cols)
val db = bb.asDoubleBuffer()
for (i <- 0 until rows) {
db.put(doubles(i))
}
bytes
}
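
The dense matrix layout is analogous, with a rows/cols header and a row-major payload; a hedged sketch of what `serializeDoubleMatrix` produces (illustrative name, same byte-order assumptions as above):

import struct

DENSE_MATRIX_MAGIC = 3

def pack_matrix(rows_2d):
    # [magic: 1 byte][rows: int32][cols: int32][rows * cols doubles, row by row]
    r = len(rows_2d)
    c = len(rows_2d[0]) if r > 0 else 0
    flat = [x for row in rows_2d for x in row]
    return struct.pack("=bii%dd" % (r * c), DENSE_MATRIX_MAGIC, r, c, *flat)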
private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
val fb = serializeDoubleVector(p.features)
val bytes = new Array[Byte](1 + 8 + fb.length)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.put(LABELED_POINT_MAGIC)
bb.putDouble(p.label)
bb.put(fb)
bytes
}
private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
require(bytes.length >= 9, "Byte array too short")
val magic = bytes(0)
if (magic != LABELED_POINT_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
labelBytes.order(ByteOrder.nativeOrder())
val label = labelBytes.asDoubleBuffer().get(0)
LabeledPoint(label, deserializeDoubleVector(bytes, 9))
}
// Reformat a Matrix into Array[Array[Double]] for serialization
private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
val values = matrix.toArray
Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows))
}
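
The reshaping works because `Matrix.toArray` is column-major, so entry (i, j) lives at flat index i + j * numRows; the sketch below replays the conversion on the 2x3 matrix used in `PythonMLLibAPISuite` further down:

def to_2d(values, num_rows, num_cols):
    # column-major flat array -> row-major list of rows
    return [[values[i + j * num_rows] for j in range(num_cols)]
            for i in range(num_rows)]

to_2d([0, 1.2, 3, 4.56, 7, 8], 2, 3)  # [[0, 3, 7], [1.2, 4.56, 8]]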
/** Unpack a Rating object from an array of bytes */
private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = {
val bb = ByteBuffer.wrap(ratingBytes)
bb.order(ByteOrder.nativeOrder())
val user = bb.getInt()
val product = bb.getInt()
val rating = bb.getDouble()
new Rating(user, product, rating)
}
/** Unpack a tuple of Ints from an array of bytes */
def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
val bb = ByteBuffer.wrap(tupleBytes)
bb.order(ByteOrder.nativeOrder())
val v1 = bb.getInt()
val v2 = bb.getInt()
(v1, v2)
}
/**
* Serialize a Rating object into an array of bytes.
* It can be deserialized using RatingDeserializer().
*
* @param rate the Rating object to serialize
* @return
*/
def serializeRating(rate: Rating): Array[Byte] = {
val len = 3
val bytes = new Array[Byte](4 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putInt(len)
val db = bb.asDoubleBuffer()
db.put(rate.user.toDouble)
db.put(rate.product.toDouble)
db.put(rate.rating)
bytes
}
}
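
Note that the two Rating formats are deliberately asymmetric: `unpackRating` reads what Python sends (two int32s followed by a double), while `serializeRating` emits a length-prefixed array of three doubles for Python's `RatingDeserializer`. A minimal sketch of both sides (illustrative names, native byte order assumed):

import struct

def pack_rating_for_jvm(user, product, rating):
    # read by SerDe.unpackRating: [user: int32][product: int32][rating: double]
    return struct.pack("=iid", user, product, rating)

def unpack_rating_from_jvm(data):
    # written by SerDe.serializeRating: [len=3: int32][user][product][rating], all as doubles
    (length,) = struct.unpack_from("=i", data)
    assert length == 3, "unexpected length: %d" % length
    user, product, rating = struct.unpack_from("=3d", data, 4)
    return int(user), int(product), rating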
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.PythonMLLibAPI
import org.apache.spark.mllib.api.python.SerDe
/**
* Model representing the result of matrix factorization.
@@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] (
*/
@DeveloperApi
def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes))
predict(usersProducts).map(rate => SerDe.serializeRating(rate))
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.correlation.Correlations
@@ -30,6 +31,18 @@ import org.apache.spark.rdd.RDD
@Experimental
object Statistics {
/**
* :: Experimental ::
* Computes column-wise summary statistics for the input RDD[Vector].
*
* @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
* @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics.
*/
@Experimental
def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
new RowMatrix(X).computeColumnSummaryStatistics()
}
/**
* :: Experimental ::
* Compute the Pearson correlation matrix for the input RDD of Vectors.
......
@@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
class PythonMLLibAPISuite extends FunSuite {
val py = new PythonMLLibAPI
test("vector serialization") {
val vectors = Seq(
@@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite {
Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
Vectors.sparse(2, Array(1), Array(-2.0)))
vectors.foreach { v =>
val bytes = py.serializeDoubleVector(v)
val u = py.deserializeDoubleVector(bytes)
val bytes = SerDe.serializeDoubleVector(v)
val u = SerDe.deserializeDoubleVector(bytes)
assert(u.getClass === v.getClass)
assert(u === v)
}
@@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite {
LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])),
LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0))))
points.foreach { p =>
val bytes = py.serializeLabeledPoint(p)
val q = py.deserializeLabeledPoint(bytes)
val bytes = SerDe.serializeLabeledPoint(p)
val q = SerDe.deserializeLabeledPoint(bytes)
assert(q.label === p.label)
assert(q.features.getClass === p.features.getClass)
assert(q.features === p.features)
@@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite {
test("double serialization") {
for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
val bytes = py.serializeDouble(x)
val deser = py.deserializeDouble(bytes)
val bytes = SerDe.serializeDouble(x)
val deser = SerDe.deserializeDouble(bytes)
// We use `equals` here for comparison because we cannot use `==` for NaN
assert(x.equals(deser))
}
@@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite {
test("matrix to 2D array") {
val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
val matrix = Matrices.dense(2, 3, values)
val arr = py.to2dArray(matrix)
val arr = SerDe.to2dArray(matrix)
val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8))
assert(arr === expected)
// Test conversion for empty matrix
val empty = Array[Double]()
val emptyMatrix = Matrices.dense(0, 0, empty)
val empty2D = py.to2dArray(emptyMatrix)
val empty2D = SerDe.to2dArray(emptyMatrix)
assert(empty2D === Array[Array[Double]]())
}
}
@@ -22,11 +22,75 @@ Python package for statistical functions in MLlib.
from pyspark.mllib._common import \
_get_unmangled_double_vector_rdd, _get_unmangled_rdd, \
_serialize_double, _serialize_double_vector, \
_deserialize_double, _deserialize_double_matrix
_deserialize_double, _deserialize_double_matrix, _deserialize_double_vector
class MultivariateStatisticalSummary(object):
"""
Trait for multivariate statistical summary of a data matrix.
"""
def __init__(self, sc, java_summary):
"""
:param sc: Spark context
:param java_summary: Handle to Java summary object
"""
self._sc = sc
self._java_summary = java_summary
def __del__(self):
self._sc._gateway.detach(self._java_summary)
def mean(self):
return _deserialize_double_vector(self._java_summary.mean())
def variance(self):
return _deserialize_double_vector(self._java_summary.variance())
def count(self):
return self._java_summary.count()
def numNonzeros(self):
return _deserialize_double_vector(self._java_summary.numNonzeros())
def max(self):
return _deserialize_double_vector(self._java_summary.max())
def min(self):
return _deserialize_double_vector(self._java_summary.min())
class Statistics(object):
@staticmethod
def colStats(X):
"""
Computes column-wise summary statistics for the input RDD[Vector].
>>> from linalg import Vectors
>>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
... Vectors.dense([4, 5, 0, 3]),
... Vectors.dense([6, 7, 0, 8])])
>>> cStats = Statistics.colStats(rdd)
>>> cStats.mean()
array([ 4., 4., 0., 3.])
>>> cStats.variance()
array([ 4., 13., 0., 25.])
>>> cStats.count()
3L
>>> cStats.numNonzeros()
array([ 3., 2., 0., 3.])
>>> cStats.max()
array([ 6., 7., 0., 8.])
>>> cStats.min()
array([ 2., 0., 0., -2.])
"""
sc = X.ctx
Xser = _get_unmangled_double_vector_rdd(X)
cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd)
return MultivariateStatisticalSummary(sc, cStats)
@staticmethod
def corr(x, y=None, method=None):
"""
......