diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index 917861309c573836be40148893d92446a8faabb6..37f173bc20469539bac9c85745776b3befa43cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -27,17 +27,7 @@ import org.apache.spark.sql.types._ */ private[spark] class VectorUDT extends UserDefinedType[Vector] { - override def sqlType: StructType = { - // type: 0 = sparse, 1 = dense - // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse - // vectors. The "values" field is nullable because we might want to add binary vectors later, - // which uses "size" and "indices", but not "values". - StructType(Seq( - StructField("type", ByteType, nullable = false), - StructField("size", IntegerType, nullable = true), - StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), - StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) - } + override final def sqlType: StructType = _sqlType override def serialize(obj: Vector): InternalRow = { obj match { @@ -94,4 +84,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def typeName: String = "vector" private[spark] override def asNullable: VectorUDT = this + + private[this] val _sqlType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala new file mode 100644 index 0000000000000000000000000000000000000000..7e408b9dbd13aba7841a9b3440480a7a5562beb6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -0,0 +1,596 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import java.io._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ + +/** + * A builder object that provides summary statistics about a given column. + * + * Users should not directly create such builders, but instead use one of the methods in + * [[Summarizer]]. + */ +@Experimental +@Since("2.3.0") +sealed abstract class SummaryBuilder { + /** + * Returns an aggregate object that contains the summary of the column with the requested metrics. + * @param featuresCol a column that contains the feature Vector objects. + * @param weightCol a column that contains the weight values. + * @return an aggregate column that contains the statistics. The exact content of this + * structure is determined during the creation of the builder. + */ + @Since("2.3.0") + def summary(featuresCol: Column, weightCol: Column): Column + + @Since("2.3.0") + def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0)) +} + +/** + * Tools for vectorized statistics on MLlib Vectors. + * + * The methods in this package provide various statistics for Vectors contained inside DataFrames. + * + * This class lets users pick the statistics they would like to extract for a given column. Here is + * an example in Scala: + * {{{ + * val dataframe = ... // Some dataframe containing a feature column + * val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features")) + * val Row(Row(min_, max_)) = allStats.first() + * }}} + * + * If one wants to get a single metric, shortcuts are also available: + * {{{ + * val meanDF = dataframe.select(Summarizer.mean($"features")) + * val Row(mean_) = meanDF.first() + * }}} + * + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD + * interface. + */ +@Experimental +@Since("2.3.0") +object Summarizer extends Logging { + + import SummaryBuilderImpl._ + + /** + * Given a list of metrics, provides a builder that in turn computes metrics from a column. + * + * See the documentation of [[Summarizer]] for an example. + * + * The following metrics are accepted (case sensitive): + * - mean: a vector that contains the coefficient-wise mean. + * - variance: a vector that contains the coefficient-wise variance. + * - count: the count of all vectors seen. + * - numNonZeros: a vector with the number of non-zeros for each coefficient. + * - max: the maximum for each coefficient. + * - min: the minimum for each coefficient. + * - normL2: the Euclidean norm for each coefficient. + * - normL1: the L1 norm of each coefficient (sum of the absolute values). + * @param firstMetric the metric being provided + * @param metrics additional metrics that can be provided. + * @return a builder. + * @throws IllegalArgumentException if one of the metric names is not understood. + * + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD + * interface.
+ */ + @Since("2.3.0") + def metrics(firstMetric: String, metrics: String*): SummaryBuilder = { + val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics) + new SummaryBuilderImpl(typedMetrics, computeMetrics) + } + + @Since("2.3.0") + def mean(col: Column): Column = getSingleMetric(col, "mean") + + @Since("2.3.0") + def variance(col: Column): Column = getSingleMetric(col, "variance") + + @Since("2.3.0") + def count(col: Column): Column = getSingleMetric(col, "count") + + @Since("2.3.0") + def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros") + + @Since("2.3.0") + def max(col: Column): Column = getSingleMetric(col, "max") + + @Since("2.3.0") + def min(col: Column): Column = getSingleMetric(col, "min") + + @Since("2.3.0") + def normL1(col: Column): Column = getSingleMetric(col, "normL1") + + @Since("2.3.0") + def normL2(col: Column): Column = getSingleMetric(col, "normL2") + + private def getSingleMetric(col: Column, metric: String): Column = { + val c1 = metrics(metric).summary(col) + c1.getField(metric).as(s"$metric($col)") + } +} + +private[ml] class SummaryBuilderImpl( + requestedMetrics: Seq[SummaryBuilderImpl.Metric], + requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetric] + ) extends SummaryBuilder { + + override def summary(featuresCol: Column, weightCol: Column): Column = { + + val agg = SummaryBuilderImpl.MetricsAggregate( + requestedMetrics, + requestedCompMetrics, + featuresCol.expr, + weightCol.expr, + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + + new Column(AggregateExpression(agg, mode = Complete, isDistinct = false)) + } +} + +private[ml] object SummaryBuilderImpl extends Logging { + + def implementedMetrics: Seq[String] = allMetrics.map(_._1).sorted + + @throws[IllegalArgumentException]("When the list is empty or not a subset of known metrics") + def getRelevantMetrics(requested: Seq[String]): (Seq[Metric], Seq[ComputeMetric]) = { + val all = requested.map { req => + val (_, metric, _, deps) = allMetrics.find(_._1 == req).getOrElse { + throw new IllegalArgumentException(s"Metric $req cannot be found." + + s" Valid metrics are $implementedMetrics") + } + metric -> deps + } + // Do not sort, otherwise the user has to look at the schema to see the order in + // which the values are going to be returned. + val metrics = all.map(_._1) + val computeMetrics = all.flatMap(_._2).distinct.sortBy(_.toString) + metrics -> computeMetrics + } + + def structureForMetrics(metrics: Seq[Metric]): StructType = { + val dict = allMetrics.map { case (name, metric, dataType, _) => + (metric, (name, dataType)) + }.toMap + val fields = metrics.map(dict.apply).map { case (name, dataType) => + StructField(name, dataType, nullable = false) + } + StructType(fields) + } + + private val arrayDType = ArrayType(DoubleType, containsNull = false) + private val arrayLType = ArrayType(LongType, containsNull = false) + + /** + * All the metrics that can be currently computed by Spark for vectors. + * + * This list associates the user-facing name, the internal (typed) name, and the list of + * computation metrics that need to be computed internally to get the final result.
+ */ + private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq( + ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)), + ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), + ("count", Count, LongType, Seq()), + ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)), + ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)), + ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)), + ("normL2", NormL2, arrayDType, Seq(ComputeM2)), + ("normL1", NormL1, arrayDType, Seq(ComputeL1)) + ) + + /** + * The metrics that are currently implemented. + */ + sealed trait Metric extends Serializable + private[stat] case object Mean extends Metric + private[stat] case object Variance extends Metric + private[stat] case object Count extends Metric + private[stat] case object NumNonZeros extends Metric + private[stat] case object Max extends Metric + private[stat] case object Min extends Metric + private[stat] case object NormL2 extends Metric + private[stat] case object NormL1 extends Metric + + /** + * The running metrics that are going to be computed. + * + * There is a bipartite graph between the metrics and the computed metrics. + */ + sealed trait ComputeMetric extends Serializable + private[stat] case object ComputeMean extends ComputeMetric + private[stat] case object ComputeM2n extends ComputeMetric + private[stat] case object ComputeM2 extends ComputeMetric + private[stat] case object ComputeL1 extends ComputeMetric + private[stat] case object ComputeWeightSum extends ComputeMetric + private[stat] case object ComputeNNZ extends ComputeMetric + private[stat] case object ComputeMax extends ComputeMetric + private[stat] case object ComputeMin extends ComputeMetric + + private[stat] class SummarizerBuffer( + requestedMetrics: Seq[Metric], + requestedCompMetrics: Seq[ComputeMetric] + ) extends Serializable { + + private var n = 0 + private var currMean: Array[Double] = null + private var currM2n: Array[Double] = null + private var currM2: Array[Double] = null + private var currL1: Array[Double] = null + private var totalCnt: Long = 0 + private var totalWeightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 + private var weightSum: Array[Double] = null + private var nnz: Array[Long] = null + private var currMax: Array[Double] = null + private var currMin: Array[Double] = null + + def this() { + this( + Seq(Mean, Variance, Count, NumNonZeros, Max, Min, NormL2, NormL1), + Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1, + ComputeWeightSum, ComputeNNZ, ComputeMax, ComputeMin) + ) + } + + /** + * Add a new sample to this summarizer, and update the statistical summary. 
+ */ + def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, $weight has to be >= 0.0") + if (weight == 0.0) return this + + if (n == 0) { + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size + + if (requestedCompMetrics.contains(ComputeMean)) { currMean = Array.ofDim[Double](n) } + if (requestedCompMetrics.contains(ComputeM2n)) { currM2n = Array.ofDim[Double](n) } + if (requestedCompMetrics.contains(ComputeM2)) { currM2 = Array.ofDim[Double](n) } + if (requestedCompMetrics.contains(ComputeL1)) { currL1 = Array.ofDim[Double](n) } + if (requestedCompMetrics.contains(ComputeWeightSum)) { weightSum = Array.ofDim[Double](n) } + if (requestedCompMetrics.contains(ComputeNNZ)) { nnz = Array.ofDim[Long](n) } + if (requestedCompMetrics.contains(ComputeMax)) { + currMax = Array.fill[Double](n)(Double.MinValue) + } + if (requestedCompMetrics.contains(ComputeMin)) { + currMin = Array.fill[Double](n)(Double.MaxValue) + } + } + + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") + + val localCurrMean = currMean + val localCurrM2n = currM2n + val localCurrM2 = currM2 + val localCurrL1 = currL1 + val localWeightSum = weightSum + val localNumNonzeros = nnz + val localCurrMax = currMax + val localCurrMin = currMin + instance.foreachActive { (index, value) => + if (value != 0.0) { + if (localCurrMax != null && localCurrMax(index) < value) { + localCurrMax(index) = value + } + if (localCurrMin != null && localCurrMin(index) > value) { + localCurrMin(index) = value + } + + if (localWeightSum != null) { + if (localCurrMean != null) { + val prevMean = localCurrMean(index) + val diff = value - prevMean + localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight) + + if (localCurrM2n != null) { + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + } + } + localWeightSum(index) += weight + } + + if (localCurrM2 != null) { + localCurrM2(index) += weight * value * value + } + if (localCurrL1 != null) { + localCurrL1(index) += weight * math.abs(value) + } + + if (localNumNonzeros != null) { + localNumNonzeros(index) += 1 + } + } + } + + totalWeightSum += weight + weightSquareSum += weight * weight + totalCnt += 1 + this + } + + def add(instance: Vector): this.type = add(instance, 1.0) + + /** + * Merge another SummarizerBuffer, and update the statistical summary. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other MultivariateOnlineSummarizer to be merged. + */ + def merge(other: SummarizerBuffer): this.type = { + if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) { + require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + + s"Expecting $n but got ${other.n}.") + totalCnt += other.totalCnt + totalWeightSum += other.totalWeightSum + weightSquareSum += other.weightSquareSum + var i = 0 + while (i < n) { + if (weightSum != null) { + val thisWeightSum = weightSum(i) + val otherWeightSum = other.weightSum(i) + val totalWeightSum = thisWeightSum + otherWeightSum + + if (totalWeightSum != 0.0) { + if (currMean != null) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherWeightSum / totalWeightSum + + if (currM2n != null) { + // merge m2n together + currM2n(i) += other.currM2n(i) + + deltaMean * deltaMean * thisWeightSum * otherWeightSum / totalWeightSum + } + } + } + weightSum(i) = totalWeightSum + } + + // merge m2 together + if (currM2 != null) { currM2(i) += other.currM2(i) } + // merge l1 together + if (currL1 != null) { currL1(i) += other.currL1(i) } + // merge max and min + if (currMax != null) { currMax(i) = math.max(currMax(i), other.currMax(i)) } + if (currMin != null) { currMin(i) = math.min(currMin(i), other.currMin(i)) } + if (nnz != null) { nnz(i) = nnz(i) + other.nnz(i) } + i += 1 + } + } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) { + this.n = other.n + if (other.currMean != null) { this.currMean = other.currMean.clone() } + if (other.currM2n != null) { this.currM2n = other.currM2n.clone() } + if (other.currM2 != null) { this.currM2 = other.currM2.clone() } + if (other.currL1 != null) { this.currL1 = other.currL1.clone() } + this.totalCnt = other.totalCnt + this.totalWeightSum = other.totalWeightSum + this.weightSquareSum = other.weightSquareSum + if (other.weightSum != null) { this.weightSum = other.weightSum.clone() } + if (other.nnz != null) { this.nnz = other.nnz.clone() } + if (other.currMax != null) { this.currMax = other.currMax.clone() } + if (other.currMin != null) { this.currMin = other.currMin.clone() } + } + this + } + + /** + * Sample mean of each dimension. + */ + def mean: Vector = { + require(requestedMetrics.contains(Mean)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + val realMean = Array.ofDim[Double](n) + var i = 0 + while (i < n) { + realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) + i += 1 + } + Vectors.dense(realMean) + } + + /** + * Unbiased estimate of sample variance of each dimension. + */ + def variance: Vector = { + require(requestedMetrics.contains(Variance)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + val realVariance = Array.ofDim[Double](n) + + val denominator = totalWeightSum - (weightSquareSum / totalWeightSum) + + // Sample variance is computed, if the denominator is less than 0, the variance is just 0. + if (denominator > 0.0) { + val deltaMean = currMean + var i = 0 + val len = currM2n.length + while (i < len) { + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * + (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator + i += 1 + } + } + Vectors.dense(realVariance) + } + + /** + * Sample size. + */ + def count: Long = totalCnt + + /** + * Number of nonzero elements in each dimension. + * + */ + def numNonzeros: Vector = { + require(requestedMetrics.contains(NumNonZeros)) + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + Vectors.dense(nnz.map(_.toDouble)) + } + + /** + * Maximum value of each dimension. 
+ */ + def max: Vector = { + require(requestedMetrics.contains(Max)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + i += 1 + } + Vectors.dense(currMax) + } + + /** + * Minimum value of each dimension. + */ + def min: Vector = { + require(requestedMetrics.contains(Min)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + i += 1 + } + Vectors.dense(currMin) + } + + /** + * L2 (Euclidian) norm of each dimension. + */ + def normL2: Vector = { + require(requestedMetrics.contains(NormL2)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + val realMagnitude = Array.ofDim[Double](n) + + var i = 0 + val len = currM2.length + while (i < len) { + realMagnitude(i) = math.sqrt(currM2(i)) + i += 1 + } + Vectors.dense(realMagnitude) + } + + /** + * L1 norm of each dimension. + */ + def normL1: Vector = { + require(requestedMetrics.contains(NormL1)) + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + + Vectors.dense(currL1) + } + } + + private case class MetricsAggregate( + requestedMetrics: Seq[Metric], + requestedComputeMetrics: Seq[ComputeMetric], + featuresExpr: Expression, + weightExpr: Expression, + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends TypedImperativeAggregate[SummarizerBuffer] { + + override def eval(state: SummarizerBuffer): InternalRow = { + val metrics = requestedMetrics.map { + case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray) + case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray) + case Count => state.count + case NumNonZeros => UnsafeArrayData.fromPrimitiveArray( + state.numNonzeros.toArray.map(_.toLong)) + case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray) + case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray) + case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray) + case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray) + } + InternalRow.apply(metrics: _*) + } + + override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil + + override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { + val features = udt.deserialize(featuresExpr.eval(row)) + val weight = weightExpr.eval(row).asInstanceOf[Double] + state.add(features, weight) + state + } + + override def merge(state: SummarizerBuffer, + other: SummarizerBuffer): SummarizerBuffer = { + state.merge(other) + } + + override def nullable: Boolean = false + + override def createAggregationBuffer(): SummarizerBuffer + = new SummarizerBuffer(requestedMetrics, requestedComputeMetrics) + + override def serialize(state: SummarizerBuffer): Array[Byte] = { + // TODO: Use ByteBuffer to optimize + val bos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(bos) + oos.writeObject(state) + bos.toByteArray + } + + override def deserialize(bytes: Array[Byte]): SummarizerBuffer = { + // TODO: Use ByteBuffer to optimize + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) + ois.readObject().asInstanceOf[SummarizerBuffer] + } + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): MetricsAggregate = { + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } + + override def 
withNewInputAggBufferOffset(newInputAggBufferOffset: Int): MetricsAggregate = { + copy(inputAggBufferOffset = newInputAggBufferOffset) + } + + override lazy val dataType: DataType = structureForMetrics(requestedMetrics) + + override def prettyName: String = "aggregate_metrics" + + } + + private[this] val udt = new VectorUDT + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..dfb733ff6e761388cd2c7f4290314484c8af7ead --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema + +class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import testImplicits._ + import Summarizer._ + import SummaryBuilderImpl._ + + private case class ExpectedMetrics( + mean: Seq[Double], + variance: Seq[Double], + count: Long, + numNonZeros: Seq[Long], + max: Seq[Double], + min: Seq[Double], + normL2: Seq[Double], + normL1: Seq[Double]) + + /** + * The input is expected to be either a sparse vector, a dense vector or an array of doubles + * (which will be converted to a dense vector) + * The expected is the list of all the known metrics. + * + * The tests take an list of input vectors and a list of all the summary values that + * are expected for this input. They currently test against some fixed subset of the + * metrics, but should be made fuzzy in the future. + */ + private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = { + + def inputVec: Seq[Vector] = input.map { + case x: Array[Double @unchecked] => Vectors.dense(x) + case x: Seq[Double @unchecked] => Vectors.dense(x.toArray) + case x: Vector => x + case x => throw new Exception(x.toString) + } + + val summarizer = { + val _summarizer = new MultivariateOnlineSummarizer + inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v))) + _summarizer + } + + // Because the Spark context is reset between tests, we cannot hold a reference onto it. 
+ def wrappedInit() = { + val df = inputVec.map(Tuple1.apply).toDF("features") + val col = df.col("features") + (df, col) + } + + registerTest(s"$name - mean only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean)) + } + + registerTest(s"$name - mean only (direct)") { + val (df, c) = wrappedInit() + compare(df.select(mean(c)), Seq(exp.mean)) + } + + registerTest(s"$name - variance only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("variance").summary(c), variance(c)), + Seq(Row(exp.variance), summarizer.variance)) + } + + registerTest(s"$name - variance only (direct)") { + val (df, c) = wrappedInit() + compare(df.select(variance(c)), Seq(summarizer.variance)) + } + + registerTest(s"$name - count only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("count").summary(c), count(c)), + Seq(Row(exp.count), exp.count)) + } + + registerTest(s"$name - count only (direct)") { + val (df, c) = wrappedInit() + compare(df.select(count(c)), + Seq(exp.count)) + } + + registerTest(s"$name - numNonZeros only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), + Seq(Row(exp.numNonZeros), exp.numNonZeros)) + } + + registerTest(s"$name - numNonZeros only (direct)") { + val (df, c) = wrappedInit() + compare(df.select(numNonZeros(c)), + Seq(exp.numNonZeros)) + } + + registerTest(s"$name - min only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("min").summary(c), min(c)), + Seq(Row(exp.min), exp.min)) + } + + registerTest(s"$name - max only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("max").summary(c), max(c)), + Seq(Row(exp.max), exp.max)) + } + + registerTest(s"$name - normL1 only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("normL1").summary(c), normL1(c)), + Seq(Row(exp.normL1), exp.normL1)) + } + + registerTest(s"$name - normL2 only") { + val (df, c) = wrappedInit() + compare(df.select(metrics("normL2").summary(c), normL2(c)), + Seq(Row(exp.normL2), exp.normL2)) + } + + registerTest(s"$name - all metrics at once") { + val (df, c) = wrappedInit() + compare(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c), + mean(c), variance(c), count(c), numNonZeros(c)), + Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros), + exp.mean, exp.variance, exp.count, exp.numNonZeros)) + } + } + + private def denseData(input: Seq[Seq[Double]]): DataFrame = { + input.map(_.toArray).map(Vectors.dense).map(Tuple1.apply).toDF("features") + } + + private def compare(df: DataFrame, exp: Seq[Any]): Unit = { + val coll = df.collect().toSeq + val Seq(row) = coll + val res = row.toSeq + val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" } + assert(res.size === exp.size, (res.size, exp.size)) + for (((x1, x2), name) <- res.zip(exp).zip(names)) { + compareStructures(x1, x2, name) + } + } + + // Compares structured content. 
+ private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match { + case (y1: Seq[Double @unchecked], v1: OldVector) => + compareStructures(y1, v1.toArray.toSeq, name) + case (d1: Double, d2: Double) => + assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name) + case (r1: GenericRowWithSchema, r2: Row) => + assert(r1.size === r2.size, (r1, r2)) + for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) { + compareStructures(x1, x2, s"$name.$fname") + } + case (r1: Row, r2: Row) => + assert(r1.size === r2.size, (r1, r2)) + for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) } + case (v1: Vector, v2: Vector) => + assert2(v1 ~== v2 absTol 1e-4, name) + case (l1: Long, l2: Long) => assert(l1 === l2) + case (s1: Seq[_], s2: Seq[_]) => + assert(s1.size === s2.size, s"$name ${(s1, s2)}") + for (((x1, idx), x2) <- s1.zipWithIndex.zip(s2)) { + compareStructures(x1, x2, s"$name.$idx") + } + case (arr1: Array[_], arr2: Array[_]) => + assert(arr1.toSeq === arr2.toSeq) + case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2") + } + + private def assert2(x: => Boolean, hint: String): Unit = { + try { + assert(x, hint) + } catch { + case tfe: TestFailedException => + throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1) + } + } + + test("debugging test") { + val df = denseData(Nil) + val c = df.col("features") + val c1 = metrics("mean").summary(c) + val res = df.select(c1) + intercept[SparkException] { + compare(res, Seq.empty) + } + } + + test("basic error handling") { + val df = denseData(Nil) + val c = df.col("features") + val res = df.select(metrics("mean").summary(c), mean(c)) + intercept[SparkException] { + compare(res, Seq.empty) + } + } + + test("no element, working metrics") { + val df = denseData(Nil) + val c = df.col("features") + val res = df.select(metrics("count").summary(c), count(c)) + compare(res, Seq(Row(0L), 0L)) + } + + val singleElem = Seq(0.0, 1.0, 2.0) + testExample("single element", Seq(singleElem), ExpectedMetrics( + mean = singleElem, + variance = Seq(0.0, 0.0, 0.0), + count = 1, + numNonZeros = Seq(0, 1, 1), + max = singleElem, + min = singleElem, + normL1 = singleElem, + normL2 = singleElem + )) + + testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics( + mean = Seq(0.0, 0.0, 0.0), + // TODO: I have a doubt about these values, they are not normalized. 
+ variance = Seq(0.0, 2.0, 8.0), + count = 2, + numNonZeros = Seq(0, 2, 2), + max = Seq(0.0, 1.0, 2.0), + min = Seq(0.0, -1.0, -2.0), + normL1 = Seq(0.0, 2.0, 4.0), + normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0) + )) + + testExample("dense vector input", + Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)), + ExpectedMetrics( + mean = Seq(1.0, -1.5, 3.0), + variance = Seq(8.0, 4.5, 18.0), + count = 2, + numNonZeros = Seq(2, 1, 1), + max = Seq(3.0, 0.0, 6.0), + min = Seq(-1.0, -3, 0.0), + normL1 = Seq(4.0, 3.0, 6.0), + normL2 = Seq(math.sqrt(10), 3, 6.0) + ) + ) + + test("summarizer buffer basic error handing") { + val summarizer = new SummarizerBuffer + + assert(summarizer.count === 0, "should be zero since nothing is added.") + + withClue("Getting numNonzeros from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.numNonzeros + } + } + + withClue("Getting variance from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.variance + } + } + + withClue("Getting mean from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.mean + } + } + + withClue("Getting max from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.max + } + } + + withClue("Getting min from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.min + } + } + + summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0)))) + + withClue("Adding a new dense sample with different array size should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.add(Vectors.dense(3.0, 1.0)) + } + } + + withClue("Adding a new sparse sample with different array size should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0)))) + } + } + + val summarizer2 = (new SummarizerBuffer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0)) + withClue("Merging a new summarizer with different dimensions should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.merge(summarizer2) + } + } + } + + test("summarizer buffer dense vector input") { + // For column 2, the maximum will be 0.0, and it's not explicitly added since we ignore all + // the zeros; it's a case we need to test. For column 3, the minimum will be 0.0 which we + // need to test as well. 
+ val summarizer = (new SummarizerBuffer) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)) + + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") + assert(summarizer.count === 2) + } + + test("summarizer buffer sparse vector input") { + val summarizer = (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) + .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) + + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") + assert(summarizer.count === 2) + } + + test("summarizer buffer mixing dense and sparse vector input") { + val summarizer = (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3)))) + .add(Vectors.dense(0.0, -1.0, -3.0)) + .add(Vectors.sparse(3, Seq((1, -5.1)))) + .add(Vectors.dense(3.8, 0.0, 1.9)) + .add(Vectors.dense(1.7, -0.6, 0.0)) + .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) + + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") + + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") + + assert(summarizer.count === 6) + } + + test("summarizer buffer merging two summarizers") { + val summarizer1 = (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3)))) + .add(Vectors.dense(0.0, -1.0, -3.0)) + + val summarizer2 = (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((1, -5.1)))) + .add(Vectors.dense(3.8, 0.0, 1.9)) + .add(Vectors.dense(1.7, -0.6, 0.0)) + .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) + + val summarizer = summarizer1.merge(summarizer2) + + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") + + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") + assert(summarizer.count === 6) + } + + test("summarizer buffer merging summarizer with empty summarizer") { + // If one of two is non-empty, this should return the non-empty summarizer. + // If both of them are empty, then just return the empty summarizer. 
+ val summarizer1 = (new SummarizerBuffer) + .add(Vectors.dense(0.0, -1.0, -3.0)).merge(new SummarizerBuffer) + assert(summarizer1.count === 1) + + val summarizer2 = (new SummarizerBuffer) + .merge((new SummarizerBuffer).add(Vectors.dense(0.0, -1.0, -3.0))) + assert(summarizer2.count === 1) + + val summarizer3 = (new SummarizerBuffer).merge(new SummarizerBuffer) + assert(summarizer3.count === 0) + + assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") + assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") + assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") + assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") + assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") + assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") + assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") + assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") + assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") + } + + test("summarizer buffer merging summarizer when one side has zero mean (SPARK-4355)") { + val s0 = new SummarizerBuffer() + .add(Vectors.dense(2.0)) + .add(Vectors.dense(2.0)) + val s1 = new SummarizerBuffer() + .add(Vectors.dense(1.0)) + .add(Vectors.dense(-1.0)) + s0.merge(s1) + assert(s0.mean(0) ~== 1.0 absTol 1e-14) + } + + test("summarizer buffer merging summarizer with weighted samples") { + val summarizer = (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new SummarizerBuffer) + .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. 
+ assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(3.0, 4.0, 3.0)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } + + test("summarizer buffer test min/max with weighted samples") { + val summarizer1 = new SummarizerBuffer() + .add(Vectors.dense(10.0, -10.0), 1e10) + .add(Vectors.dense(0.0, 0.0), 1e-7) + + val summarizer2 = new SummarizerBuffer() + summarizer2.add(Vectors.dense(10.0, -10.0), 1e10) + for (i <- 1 to 100) { + summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7) + } + + val summarizer3 = new SummarizerBuffer() + for (i <- 1 to 100) { + summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7) + } + summarizer3.add(Vectors.dense(10.0, -10.0), 1e10) + + assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + } + + ignore("performance test") { + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12 + MacBook Pro (15-inch, 2016) CPU 2.9 GHz Intel Core i7 + Use 2 partitions. 
tries out times= 20, warm up times = 10 + + The unit of test results is records/milliseconds (higher is better) + + Vector size/records number: 1/1E7 10/1E6 100/1E6 1E3/1E5 1E4/1E4 + ----------------------------------------------------------------------------- + DataFrame 15149 7441 2118 224 21 + RDD from DataFrame 4992 4440 2328 320 33 + Raw RDD 53931 20683 3966 528 53 + */ + import scala.util.Random + val rand = new Random() + + val genArr = (dim: Int) => { + Array.fill(dim)(rand.nextDouble()) + } + + val numPartitions = 2 + for ( (n, dim) <- Seq( + (10000000, 1), (1000000, 10), (1000000, 100), (100000, 1000), (10000, 10000)) + ) { + val rdd1 = sc.parallelize(1 to n, numPartitions).map { idx => + OldVectors.dense(genArr(dim)) + } + // scalastyle:off println + println(s"records number = $n, vector size = $dim, partition = ${rdd1.getNumPartitions}") + // scalastyle:on println + + val numOfTry = 20 + val numOfWarmUp = 10 + rdd1.cache() + rdd1.count() + val rdd2 = sc.parallelize(1 to n, numPartitions).map { idx => + Vectors.dense(genArr(dim)) + } + rdd2.cache() + rdd2.count() + val df = rdd2.map(Tuple1.apply).toDF("features") + df.cache() + df.count() + + def print(name: String, l: List[Long]): Unit = { + def f(z: Long) = (1e6 * n.toDouble) / z + val min = f(l.max) + val max = f(l.min) + val med = f(l.sorted.drop(l.size / 2).head) + // scalastyle:off println + println(s"$name = [$min ~ $med ~ $max] records / milli") + // scalastyle:on println + } + + var timeDF: List[Long] = Nil + val x = df.select( + metrics("mean", "variance", "count", "numNonZeros", "max", "min", "normL1", + "normL2").summary($"features")) + for (i <- 1 to numOfTry) { + val start = System.nanoTime() + x.head() + val end = System.nanoTime() + if (i > numOfWarmUp) timeDF ::= (end - start) + } + + var timeRDD: List[Long] = Nil + for (i <- 1 to numOfTry) { + val start = System.nanoTime() + Statistics.colStats(rdd1) + val end = System.nanoTime() + if (i > numOfWarmUp) timeRDD ::= (end - start) + } + + var timeRDDFromDF: List[Long] = Nil + val rddFromDf = df.rdd.map { case Row(v: Vector) => OldVectors.fromML(v) } + for (i <- 1 to numOfTry) { + val start = System.nanoTime() + Statistics.colStats(rddFromDf) + val end = System.nanoTime() + if (i > numOfWarmUp) timeRDDFromDF ::= (end - start) + } + + print("DataFrame : ", timeDF) + print("RDD :", timeRDD) + print("RDD from DataFrame : ", timeRDDFromDF) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 7c57025f995d6e94e49dbae1b0de0b94b14da2c1..64b94f0a2c103114114d386e6930de287b23ff1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -101,6 +101,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu /** * A projection that returns UnsafeRow. + * + * CAUTION: the returned projection object should *not* be assumed to be thread-safe. */ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow @@ -110,11 +112,15 @@ object UnsafeProjection { /** * Returns an UnsafeProjection for given StructType. + * + * CAUTION: the returned projection object is *not* thread-safe. */ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) /** * Returns an UnsafeProjection for given Array of DataTypes. 
+ * + * CAUTION: the returned projection object is *not* thread-safe. */ def create(fields: Array[DataType]): UnsafeProjection = { create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 7af49014358570c4a3b26f00519bb532fcb5b13a..19abce01a26cf162a5b5339c3170310735825707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -511,6 +511,12 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { * Generates the final aggregation result value for current key group with the aggregation buffer * object. * + * Developer note: the only return types accepted by Spark are: + * - primitive types + * - InternalRow and subclasses + * - ArrayData + * - MapData + * * @param buffer aggregation buffer object. * @return The aggregation result of current key group */
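To make the new Summarizer API concrete, here is a minimal usage sketch. It is not part of the patch; it assumes a SparkSession named `spark` with its implicits imported, and the data and column names are illustrative only.

{{{
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row

import spark.implicits._

// Toy DataFrame with a vector column named "features".
val df = Seq(
  Vectors.dense(0.0, 1.0, 2.0),
  Vectors.dense(0.0, -1.0, -2.0)
).map(Tuple1.apply).toDF("features")

// Several metrics at once: the result is a single struct column whose fields keep
// the order in which the metrics were requested (see getRelevantMetrics above).
val allStats = df.select(
  Summarizer.metrics("mean", "variance", "count").summary($"features"))
val Row(Row(mean, variance, count)) = allStats.first()

// Single-metric shortcut; in this patch the per-coefficient metrics come back as
// arrays of doubles (see allMetrics), not as Vector objects.
val Row(meanOnly) = df.select(Summarizer.mean($"features")).first()
}}}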
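The mean-merging logic in SummarizerBuffer.merge can be checked by hand against the SPARK-4355 regression test above. A small arithmetic sketch, not part of the patch:

{{{
// Hand-check of the mean merge formula used in SummarizerBuffer.merge:
//   mergedMean = thisMean + (otherMean - thisMean) * otherWeight / totalWeight
// s0 holds {2.0, 2.0} (mean 2.0, weight 2); s1 holds {1.0, -1.0} (mean 0.0, weight 2).
val (thisMean, thisW) = (2.0, 2.0)
val (otherMean, otherW) = (0.0, 2.0)
val merged = thisMean + (otherMean - thisMean) * otherW / (thisW + otherW)
assert(merged == 1.0) // the merged summarizer reports mean(0) == 1.0, as the test asserts
}}}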
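The thread-safety caution added to UnsafeProjection can also be illustrated with a short sketch (again not part of the patch; the object and names are hypothetical): each task or thread should create its own projection rather than sharing one instance, and the returned UnsafeRow is typically reused between calls, so copy it if rows are retained.

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

object ProjectionPerTaskExample {
  val schema: StructType = StructType(Seq(
    StructField("a", IntegerType),
    StructField("b", DoubleType)))

  // Anti-pattern: a single projection shared across threads.
  // val shared = UnsafeProjection.create(schema)

  // Safe pattern: build one projection per task/thread, and copy rows that outlive
  // the current call, since the projection typically reuses its output buffer.
  def projectPartition(rows: Iterator[InternalRow]): Iterator[InternalRow] = {
    val proj = UnsafeProjection.create(schema)
    rows.map(row => proj(row).copy())
  }
}
}}}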