diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
index 85529f6a0aa1ec636682148e2632483663900350..a88a315bf479f887ad631cdc165bc92e0ac639da 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
@@ -165,10 +165,10 @@ public final class FixedLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatc
   protected FixedLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema,
       int maxRows, TaskMemoryManager manager) {
     super(keySchema, valueSchema, maxRows, manager);
-    klen = keySchema.defaultSize()
-      + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
-    vlen = valueSchema.defaultSize()
-      + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
+    int keySize = keySchema.size() * 8; // each fixed-length field is stored in an 8-byte word
+    int valueSize = valueSchema.size() * 8;
+    klen = keySize + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
+    vlen = valueSize + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
     recordLength = klen + vlen + 8;
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bd7efa606e0ceed7b34f4cf1f3780d9a063582f6..59e132dfb252d928f1c9285ee81f75c90ad79056 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -279,9 +280,14 @@ case class HashAggregateExec(
     .map(_.asInstanceOf[DeclarativeAggregate])
   private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
 
-  // The name for Vectorized HashMap
-  private var vectorizedHashMapTerm: String = _
-  private var isVectorizedHashMapEnabled: Boolean = _
+  // The name for Fast HashMap
+  private var fastHashMapTerm: String = _
+  private var isFastHashMapEnabled: Boolean = false
+
+  // Whether a vectorized hashmap is used instead.
+  // We have decided to always use the row-based hashmap, but the vectorized hashmap can
+  // still be switched on for testing and benchmarking purposes.
+  private var isVectorizedHashMapEnabled: Boolean = false
 
   // The name for UnsafeRow HashMap
   private var hashMapTerm: String = _
@@ -307,6 +313,16 @@ case class HashAggregateExec(
     )
   }
 
+  def getTaskMemoryManager(): TaskMemoryManager = {
+    TaskContext.get().taskMemoryManager()
+  }
+
+  def getEmptyAggregationBuffer(): InternalRow = {
+    val initExpr = declFunctions.flatMap(f => f.initialValues)
+    val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+    initialBuffer
+  }
+
   /**
    * This is called by generated Java class, should be public.
    */
@@ -459,52 +475,91 @@ case class HashAggregateExec(
   }
 
   /**
-   * Using the vectorized hash map in HashAggregate is currently supported for all primitive
-   * data types during partial aggregation. However, we currently only enable the hash map for a
-   * subset of cases that've been verified to show performance improvements on our benchmarks
-   * subject to an internal conf that sets an upper limit on the maximum length of the aggregate
-   * key/value schema.
-   *
+   * A required check for any fast hash map implementation (basically the common requirements
+   * for row-based and vectorized).
+   * Currently the fast hash map is supported for primitive data types during partial aggregation.
    * This list of supported use-cases should be expanded over time.
    */
-  private def enableVectorizedHashMap(ctx: CodegenContext): Boolean = {
-    val schemaLength = (groupingKeySchema ++ bufferSchema).length
+  private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
     val isSupported =
       (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
         f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
         bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
 
-    // We do not support byte array based decimal type for aggregate values as
-    // ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
+    // For the vectorized hash map, we do not support byte array based decimal type for aggregate
+    // values, as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
     // updates. Due to this, appending the byte array in the vectorized hash map can turn out to be
     // quite inefficient and can potentially OOM the executor.
+    // For the row-based hash map, while decimal update is supported in UnsafeRow, we will just be
+    // conservative here, due to lack of testing and benchmarking.
     val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
       .forall(!DecimalType.isByteArrayDecimalType(_))
 
-    isSupported && isNotByteArrayDecimalType &&
-      schemaLength <= sqlContext.conf.vectorizedAggregateMapMaxColumns
+    isSupported && isNotByteArrayDecimalType
+  }
+
+  private def enableTwoLevelHashMap(ctx: CodegenContext) = {
+    if (!checkIfFastHashMapSupported(ctx)) {
+      if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
+        logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but"
+          + " the current version of the codegened fast hashmap does not support this aggregate.")
+      }
+    } else {
+      isFastHashMapEnabled = true
+
+      // This is for testing/benchmarking only.
+      // We enforce the first level to be a vectorized hashmap, instead of the default row-based one.
+      sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
+        case "true" => isVectorizedHashMapEnabled = true
+        case null | "" | "false" => None
+      }
+    }
   }
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
-    isVectorizedHashMapEnabled = enableVectorizedHashMap(ctx)
-    vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
-    val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
-    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
-      vectorizedHashMapClassName, groupingKeySchema, bufferSchema)
+    if (sqlContext.conf.enableTwoLevelAggMap) {
+      enableTwoLevelHashMap(ctx)
+    } else {
+      sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
+        case "true" => logWarning("Two level hashmap is disabled but vectorized hashmap is " +
+          "enabled.")
+        case null | "" | "false" => None
+      }
+    }
+    fastHashMapTerm = ctx.freshName("fastHashMap")
+    val fastHashMapClassName = ctx.freshName("FastHashMap")
+    val fastHashMapGenerator =
+      if (isVectorizedHashMapEnabled) {
+        new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema)
+      } else {
+        new RowBasedHashMapGenerator(ctx, aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema)
+      }
+
+    val thisPlan = ctx.addReferenceObj("plan", this)
+
     // Create a name for iterator from vectorized HashMap
-    val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
-    if (isVectorizedHashMapEnabled) {
-      ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm,
-        s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();")
-      ctx.addMutableState(
-        "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
-        iterTermForVectorizedHashMap, "")
+    val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
+    if (isFastHashMapEnabled) {
+      if (isVectorizedHashMapEnabled) {
+        ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
+          s"$fastHashMapTerm = new $fastHashMapClassName();")
+        ctx.addMutableState(
+          "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
+          iterTermForFastHashMap, "")
+      } else {
+        ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
+          s"$fastHashMapTerm = new $fastHashMapClassName(" +
+            s"agg_plan.getTaskMemoryManager(), agg_plan.getEmptyAggregationBuffer());")
+        ctx.addMutableState(
+          "org.apache.spark.unsafe.KVIterator",
+          iterTermForFastHashMap, "")
+      }
     }
 
     // create hashMap
-    val thisPlan = ctx.addReferenceObj("plan", this)
     hashMapTerm = ctx.freshName("hashMap")
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
     ctx.addMutableState(hashMapClassName, hashMapTerm, "")
@@ -518,15 +573,30 @@ case class HashAggregateExec(
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
+
+    def generateGenerateCode(): String = {
+      if (isFastHashMapEnabled) {
+        if (isVectorizedHashMapEnabled) {
+          s"""
+             | ${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()}
+           """.stripMargin
+        } else {
+          s"""
+             | ${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()}
+           """.stripMargin
+        }
+      } else ""
+    }
+
     ctx.addNewFunction(doAgg,
       s"""
-        ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""}
+        ${generateGenerateCode}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
-          ${if (isVectorizedHashMapEnabled) {
-              s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""}
+          ${if (isFastHashMapEnabled) {
+              s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}
 
           $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize);
         }
@@ -542,34 +612,56 @@ case class HashAggregateExec(
     // so `copyResult` should be reset to `false`.
     ctx.copyResult = false
 
+    def outputFromGeneratedMap: String = {
+      if (isFastHashMapEnabled) {
+        if (isVectorizedHashMapEnabled) {
+          outputFromVectorizedMap
+        } else {
+          outputFromRowBasedMap
+        }
+      } else ""
+    }
+
+    def outputFromRowBasedMap: String = {
+      s"""
+       while ($iterTermForFastHashMap.next()) {
+         $numOutput.add(1);
+         UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
+         UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
+         $outputCode
+
+         if (shouldStop()) return;
+       }
+       $fastHashMapTerm.close();
+     """
+    }
+
     // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow
-    def outputFromGeneratedMap: Option[String] = {
-      if (isVectorizedHashMapEnabled) {
-        val row = ctx.freshName("vectorizedHashMapRow")
+    def outputFromVectorizedMap: String = {
+        val row = ctx.freshName("fastHashMapRow")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
         var schema: StructType = groupingKeySchema
         bufferSchema.foreach(i => schema = schema.add(i))
         val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
           .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
-        Option(
-          s"""
-             | while ($iterTermForVectorizedHashMap.hasNext()) {
-             |   $numOutput.add(1);
-             |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
-             |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
-             |     $iterTermForVectorizedHashMap.next();
-             |   ${generateRow.code}
-             |   ${consume(ctx, Seq.empty, {generateRow.value})}
-             |
-             |   if (shouldStop()) return;
-             | }
-             |
-             | $vectorizedHashMapTerm.close();
-           """.stripMargin)
-      } else None
+        s"""
+           | while ($iterTermForFastHashMap.hasNext()) {
+           |   $numOutput.add(1);
+           |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
+           |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
+           |     $iterTermForFastHashMap.next();
+           |   ${generateRow.code}
+           |   ${consume(ctx, Seq.empty, {generateRow.value})}
+           |
+           |   if (shouldStop()) return;
+           | }
+           |
+           | $fastHashMapTerm.close();
+         """.stripMargin
     }
+
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
     s"""
@@ -581,7 +673,7 @@ case class HashAggregateExec(
        }
 
        // output the result
-       ${outputFromGeneratedMap.getOrElse("")}
+       ${outputFromGeneratedMap}
 
        while ($iterTerm.next()) {
          $numOutput.add(1);
@@ -605,11 +697,11 @@ case class HashAggregateExec(
     ctx.currentVars = input
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
-    val vectorizedRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+    val fastRowKeys = ctx.generateExpressions(
+      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
-    val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer")
ctx.freshName("vectorizedAggBuffer") + val fastRowBuffer = ctx.freshName("fastAggBuffer") // only have DeclarativeAggregate val updateExpr = aggregateExpressions.flatMap { e => @@ -639,17 +731,18 @@ case class HashAggregateExec( ("true", "true", "", "") } - // We first generate code to probe and update the vectorized hash map. If the probe is - // successful the corresponding vectorized row buffer will hold the mutable row - val findOrInsertInVectorizedHashMap: Option[String] = { - if (isVectorizedHashMapEnabled) { + // We first generate code to probe and update the fast hash map. If the probe is + // successful the corresponding fast row buffer will hold the mutable row + val findOrInsertFastHashMap: Option[String] = { + if (isFastHashMapEnabled) { Option( s""" + | |if ($checkFallbackForGeneratedHashMap) { - | ${vectorizedRowKeys.map(_.code).mkString("\n")} - | if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert( - | ${vectorizedRowKeys.map(_.value).mkString(", ")}); + | ${fastRowKeys.map(_.code).mkString("\n")} + | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); | } |} """.stripMargin) @@ -658,36 +751,35 @@ case class HashAggregateExec( } } - val updateRowInVectorizedHashMap: Option[String] = { - if (isVectorizedHashMapEnabled) { - ctx.INPUT_ROW = vectorizedRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable, - isVectorized = true) - } - Option( - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(vectorizedRowEvals)} - |// update vectorized row - |${updateVectorizedRow.mkString("\n").trim} - """.stripMargin) - } else None + + def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { + ctx.INPUT_ROW = fastRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized) + } + Option( + s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(fastRowEvals)} + |// update fast row + |${updateFastRow.mkString("\n").trim} + | + """.stripMargin) } // Next, we generate code to probe and update the unsafe row hash map. 
    val findOrInsertInUnsafeRowMap: String = {
      s"""
-        | if ($vectorizedRowBuffer == null) {
+        | if ($fastRowBuffer == null) {
         |   // generate grouping key
         |   ${unsafeRowKeyCode.code.trim}
         |   ${hashEval.code.trim}
@@ -745,17 +837,31 @@ case class HashAggregateExec(
    // Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
    s"""
     UnsafeRow $unsafeRowBuffer = null;
-     org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null;
+     ${
+       if (isVectorizedHashMapEnabled) {
+         s"""
+            | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $fastRowBuffer = null;
+          """.stripMargin
+       } else {
+         s"""
+            | UnsafeRow $fastRowBuffer = null;
+          """.stripMargin
+       }
+     }
 
-     ${findOrInsertInVectorizedHashMap.getOrElse("")}
+     ${findOrInsertFastHashMap.getOrElse("")}
 
     $findOrInsertInUnsafeRowMap
 
     $incCounter
 
-     if ($vectorizedRowBuffer != null) {
-       // update vectorized row
-       ${updateRowInVectorizedHashMap.getOrElse("")}
+     if ($fastRowBuffer != null) {
+       // update fast row
+       ${
+         if (isFastHashMapEnabled) {
+           updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("")
+         } else ""
+       }
    } else {
      // update unsafe row
      $updateRowInUnsafeRowMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 1dea33037c85cc9d20f98b48bc52a22026448f21..a77e178546ef851caa7d9b2d78720fac7985f890 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -141,8 +141,16 @@ class RowBasedHashMapGenerator(
       }
 
     val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
-      s"agg_rowWriter.write(${ordinal}, ${key.name})"}
-      .mkString(";\n")
+      key.dataType match {
+        case t: DecimalType =>
+          s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
+        case t: DataType =>
+          if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
+            throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
+          }
+          s"agg_rowWriter.write(${ordinal}, ${key.name})"
+      }
+    }.mkString(";\n")
 
     s"""
        |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 91988270ada8df15c55c61f7f15dcdb3ed9159e9..d3440a26441636b4eae5a60b7857e7886ca86d08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -509,14 +509,15 @@ object SQLConf {
       .intConf
       .createWithDefault(40)
 
-  val VECTORIZED_AGG_MAP_MAX_COLUMNS =
-    SQLConfigBuilder("spark.sql.codegen.aggregate.map.columns.max")
+  val ENABLE_TWOLEVEL_AGG_MAP =
+    SQLConfigBuilder("spark.sql.codegen.aggregate.map.twolevel.enable")
       .internal()
-      .doc("Sets the maximum width of schema (aggregate keys + values) for which aggregate with" +
-        "keys uses an in-memory columnar map to speed up execution. Setting this to 0 effectively" +
-        "disables the columnar map")
-      .intConf
-      .createWithDefault(3)
+      .doc("Enable two-level aggregate hash map. When enabled, records will first be " +
+        "inserted/looked-up at a 1st-level, small, fast map, and then fall back to a " +
+        "2nd-level, larger, slower map when the 1st level is full or keys cannot be found. " +
+        "When disabled, records go directly to the 2nd level. Defaults to true.")
+      .booleanConf
+      .createWithDefault(true)
 
   val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
     .internal()
@@ -687,7 +688,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  def vectorizedAggregateMapMaxColumns: Int = getConf(VECTORIZED_AGG_MAP_MAX_COLUMNS)
+  def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
 
   def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3e85d95523125740606180daba37f65c65aa7381
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfter
+
+class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+
+  protected override def beforeAll(): Unit = {
+    sparkConf.set("spark.sql.codegen.fallback", "false")
+    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+    super.beforeAll()
+  }
+
+  // Add some checks after each test is run, ensuring that the configs are not changed
+  // in test code.
+  after {
+    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+      "configuration parameter changed in test body")
+    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false",
+      "configuration parameter changed in test body")
+  }
+}
+
+class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+
+  protected override def beforeAll(): Unit = {
+    sparkConf.set("spark.sql.codegen.fallback", "false")
+    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+    super.beforeAll()
+  }
+
+  // Add some checks after each test is run, ensuring that the configs are not changed
+  // in test code.
+  after {
+    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+      "configuration parameter changed in test body")
+    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
+      "configuration parameter changed in test body")
+  }
+}
+
+class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
+BeforeAndAfter {
+
+  protected override def beforeAll(): Unit = {
+    sparkConf.set("spark.sql.codegen.fallback", "false")
+    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+    sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
+    super.beforeAll()
+  }
+
+  // Add some checks after each test is run, ensuring that the configs are not changed
+  // in test code.
+  after {
+    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+      "configuration parameter changed in test body")
+    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
+      "configuration parameter changed in test body")
+    assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
+      "configuration parameter changed in test body")
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 69a3b5f278fd80d4be31f01b3dacc9dbbc21e9aa..427390a90f1e6bce125fc3ec299dca4797103669 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -485,4 +485,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
       Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
   }
+
+  test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") {
+    checkAnswer(
+      decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5))))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index bf3a39c84b3b2e5b684942e76275fb7d289ea656..8a2993bdf4b2803e0bc47649e9c82ace9be183f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -106,13 +106,14 @@ class AggregateBenchmark extends BenchmarkBase {
 
     benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
       f()
     }
 
     benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
       f()
     }
 
@@ -146,13 +147,14 @@ class AggregateBenchmark extends BenchmarkBase {
 
     benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 0)
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
       f()
     }
 
     benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 3)
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
       f()
     }
 
@@ -184,13 +186,14 @@ class AggregateBenchmark extends BenchmarkBase {
 
     benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
       f()
     }
 
     benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
       f()
     }
 
@@ -221,13 +224,14 @@ class AggregateBenchmark extends BenchmarkBase {
 
     benchmark.addCase(s"codegen = T hashmap = F") { iter =>
      sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
       f()
     }
 
     benchmark.addCase(s"codegen = T hashmap = T") { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+      sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
       f()
     }
 
@@ -268,13 +272,14 @@ class AggregateBenchmark extends BenchmarkBase {
 
     benchmark.addCase(s"codegen = T hashmap = F") { iter =>
       sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
-      sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "10") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 2dcf13c02a466c94f0dc1b1a5e4e2ca2b4308ac1..4a8086d7e54006cf0149bb898733114b311c50ab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -998,9 +998,9 @@ class HashAggregationQuerySuite extends AggregationQuerySuite class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - Seq(0, 10).foreach { maxColumnarHashMapColumns => - withSQLConf("spark.sql.codegen.aggregate.map.columns.max" -> - maxColumnarHashMapColumns.toString) { + Seq("true", "false").foreach { enableTwoLevelMaps => + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" -> + enableTwoLevelMaps) { (1 to 3).foreach { fallbackStartsAt => withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {