diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d29c27c14b0c3ff79829644dae176562980c5fa2..fa09f821fc9977587318cde8910e7aca117b5ad3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -270,6 +270,63 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Returns the specialized code to set a given value in a column vector for a given `DataType`.
+   */
+  def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
+      value: String): String = {
+    val jt = javaType(dataType)
+    dataType match {
+      case _ if isPrimitiveType(jt) =>
+        s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
+      case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
+      case StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
+      case _ =>
+        throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+    }
+  }
+
+  /**
+   * Returns the specialized code to set a given value in a column vector for a given `DataType`,
+   * guarding against the value being null when `nullable` is true.
+   */
+  def updateColumn(
+      batch: String,
+      row: String,
+      dataType: DataType,
+      ordinal: Int,
+      ev: ExprCode,
+      nullable: Boolean): String = {
+    if (nullable) {
+      s"""
+         if (!${ev.isNull}) {
+           ${setValue(batch, row, dataType, ordinal, ev.value)}
+         } else {
+           $batch.column($ordinal).putNull($row);
+         }
+       """
+    } else {
+      setValue(batch, row, dataType, ordinal, ev.value)
+    }
+  }
+
+  /**
+   * Returns the specialized code to access a value from a column vector for a given `DataType`.
+   */
+  def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = {
+    val jt = javaType(dataType)
+    dataType match {
+      case _ if isPrimitiveType(jt) =>
+        s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)"
+      case t: DecimalType =>
+        s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})"
+      case StringType =>
+        s"$batch.column($ordinal).getUTF8String($row)"
+      case _ =>
+        throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+    }
+  }
+
   /**
    * Returns the name used in accessor and setter for a Java primitive type.
    */
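
For orientation, this is roughly how a caller in whole-stage codegen would use the new helpers, together with the Java snippets they return (a sketch only; nothing beyond the `setValue`/`getValue` signatures added above is assumed, and the ordinals and names are illustrative):

  import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
  import org.apache.spark.sql.types.{DecimalType, IntegerType}

  val ctx = new CodegenContext
  // Primitive types dispatch to the typed ColumnVector setter, e.g.
  //   batch.column(0).putInt(row, value);
  val putInt = ctx.setValue("batch", "row", IntegerType, 0, "value")
  // Decimals read back with the column's precision and scale, e.g.
  //   batch.column(1).getDecimal(row, 20, 2)
  val getDec = ctx.getValue("batch", "row", DecimalType(20, 2), 1)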
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d4cef8f310dac34bf61c5c91eaa0b905d253d3c2..5c0fc02861b1e4767e1e31fe33a93d42359706c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{LongType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
@@ -265,11 +265,7 @@ case class TungstenAggregate(
 
   // The name for Vectorized HashMap
   private var vectorizedHashMapTerm: String = _
-
-  // We currently only enable vectorized hashmap for long key/value types and partial aggregates
-  private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled &&
-    (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) &&
-    modes.forall(mode => mode == Partial || mode == PartialMerge)
+  private var isVectorizedHashMapEnabled: Boolean = _
 
   // The name for UnsafeRow HashMap
   private var hashMapTerm: String = _
@@ -447,10 +443,16 @@ case class TungstenAggregate(
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
+    // Enable the vectorized hash map during partial aggregation for primitive, decimal, and
+    // string types (strings are supported as grouping keys but not yet as aggregation buffers)
+    isVectorizedHashMapEnabled = sqlContext.conf.columnarAggregateMapEnabled &&
+      (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
+        f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
+      bufferSchema.forall(!_.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty &&
+      modes.forall(mode => mode == Partial || mode == PartialMerge)
     vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
     val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
-    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName,
-      groupingKeySchema, bufferSchema)
+    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+      vectorizedHashMapClassName, groupingKeySchema, bufferSchema)
     // Create a name for iterator from vectorized HashMap
     val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
     if (isVectorizedHashMapEnabled) {
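
Restated outside the operator, the eligibility check above amounts to the following predicate (a sketch for illustration; `isMapEligible` is a hypothetical helper and the partial-aggregation modes clause is left out):

  import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
  import org.apache.spark.sql.types.{DecimalType, StringType, StructType}

  def isMapEligible(ctx: CodegenContext, keys: StructType, buffers: StructType): Boolean = {
    // Every grouping key and aggregation buffer must be primitive, decimal, or string typed.
    val allSupported = (keys ++ buffers).forall { f =>
      ctx.isPrimitiveType(f.dataType) ||
        f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]
    }
    // String grouping keys are fine, but string-typed buffers are not yet handled by the
    // columnar map, and there must be at least one buffer column (i.e. at least one aggregate).
    allSupported && buffers.forall(!_.dataType.isInstanceOf[StringType]) && buffers.nonEmpty
  }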
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index dd9b2f097e121b5118003dc987a2817814e892c0..61bd6eb3cde66500b23c711d144c7594a35cf000 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
 
 /**
  * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache'
@@ -40,12 +41,32 @@ import org.apache.spark.sql.types.StructType
  */
 class VectorizedHashMapGenerator(
     ctx: CodegenContext,
+    aggregateExpressions: Seq[AggregateExpression],
     generatedClassName: String,
     groupingKeySchema: StructType,
     bufferSchema: StructType) {
-  val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
-  val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
-  val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
+  case class Buffer(dataType: DataType, name: String)
+  val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
+  val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
+  val groupingKeySignature =
+    groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
+  val buffVars: Seq[ExprCode] = {
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      ctx.addMutableState("boolean", isNull, "")
+      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
+      val ev = e.genCode(ctx)
+      val initVars =
+        s"""
+           | $isNull = ${ev.isNull};
+           | $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+  }
 
   def generate(): String = {
     s"""
@@ -67,20 +88,28 @@ class VectorizedHashMapGenerator(
 
   private def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
-      s"""
-         |new org.apache.spark.sql.types.StructType()
-         |${(groupingKeySchema ++ bufferSchema).map(key =>
-          s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
-          .mkString("\n")};
-      """.stripMargin
+      s"new org.apache.spark.sql.types.StructType()" +
+        (groupingKeySchema ++ bufferSchema).map { key =>
+          key.dataType match {
+            case d: DecimalType =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+                  |${d.precision}, ${d.scale}))""".stripMargin
+            case _ =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+          }
+        }.mkString("\n").concat(";")
 
     val generatedAggBufferSchema: String =
-      s"""
-         |new org.apache.spark.sql.types.StructType()
-         |${bufferSchema.map(key =>
-        s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
-        .mkString("\n")};
-      """.stripMargin
+      s"new org.apache.spark.sql.types.StructType()" +
+        bufferSchema.map { key =>
+          key.dataType match {
+            case d: DecimalType =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+                  |${d.precision}, ${d.scale}))""".stripMargin
+            case _ =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+          }
+        }.mkString("\n").concat(";")
 
     s"""
        |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
@@ -122,12 +151,23 @@ class VectorizedHashMapGenerator(
    * }}}
    */
   private def generateHashFunction(): String = {
+    val hash = ctx.freshName("hash")
+
+    def genHashForKeys(groupingKeys: Seq[Buffer]): String = {
+      groupingKeys.map { key =>
+        val result = ctx.freshName("result")
+        s"""
+           |${genComputeHash(ctx, key.name, key.dataType, result)}
+           |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2);
+          """.stripMargin
+      }.mkString("\n")
+    }
+
     s"""
        |private long hash($groupingKeySignature) {
-       |  long h = 0;
-       |  ${groupingKeys.map(key => s"h = (h ^ (0x9e3779b9)) + ${key._2} + (h << 6) + (h >>> 2);")
-            .mkString("\n")}
-       |  return h;
+       |  long $hash = 0;
+       |  ${genHashForKeys(groupingKeys)}
+       |  return $hash;
        |}
      """.stripMargin
   }
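
The per-key mixing step emitted above is a hash_combine-style mix around the golden-ratio constant 0x9e3779b9; written as ordinary Scala the combining function is simply (for illustration only, the patch emits the equivalent Java inline):

  // Fold one key's hash `result` into the running hash `h`.
  def mix(h: Long, result: Long): Long =
    (h ^ 0x9e3779b9L) + result + (h << 6) + (h >>> 2)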
@@ -145,10 +185,17 @@ class VectorizedHashMapGenerator(
    * }}}
    */
   private def generateEquals(): String = {
+
+    def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
+      groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        s"""(${ctx.genEqual(key.dataType, ctx.getValue("batch", "buckets[idx]",
+          key.dataType, ordinal), key.name)})"""
+      }.mkString(" && ")
+    }
+
     s"""
        |private boolean equals(int idx, $groupingKeySignature) {
-       |  return ${groupingKeys.zipWithIndex.map(k =>
-            s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
+       |  return ${genEqualsForKeys(groupingKeys)};
        |}
      """.stripMargin
   }
@@ -187,21 +234,39 @@ class VectorizedHashMapGenerator(
    * }}}
    */
   private def generateFindOrInsert(): String = {
+
+    def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
+      groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        ctx.setValue("batch", "numRows", key.dataType, ordinal, key.name)
+      }
+    }
+
+    def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
+      bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        ctx.updateColumn("batch", "numRows", key.dataType, groupingKeys.length + ordinal,
+          buffVars(ordinal), nullable = true)
+      }
+    }
+
     s"""
        |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
             groupingKeySignature}) {
-       |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
+       |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
        |  int step = 0;
        |  int idx = (int) h & (numBuckets - 1);
        |  while (step < maxSteps) {
        |    // Return bucket index if it's either an empty slot or already contains the key
        |    if (buckets[idx] == -1) {
        |      if (numRows < capacity) {
-       |        ${groupingKeys.zipWithIndex.map(k =>
-                  s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
-       |        ${bufferValues.zipWithIndex.map(k =>
-                  s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);")
-                  .mkString("\n")}
+       |
+       |        // Initialize aggregate keys
+       |        ${genCodeToSetKeys(groupingKeys).mkString("\n")}
+       |
+       |        ${buffVars.map(_.code).mkString("\n")}
+       |
+       |        // Initialize aggregate values
+       |        ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
+       |
        |        buckets[idx] = numRows++;
        |        batch.setNumRows(numRows);
        |        aggregateBufferBatch.setNumRows(numRows);
@@ -210,7 +275,7 @@ class VectorizedHashMapGenerator(
        |        // No more space
        |        return null;
        |      }
-       |    } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
+       |    } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
        |      return aggregateBufferBatch.getRow(buckets[idx]);
        |    }
        |    idx = (idx + 1) & (numBuckets - 1);
@@ -238,4 +303,42 @@ class VectorizedHashMapGenerator(
        |}
      """.stripMargin
   }
+
+  private def genComputeHash(
+      ctx: CodegenContext,
+      input: String,
+      dataType: DataType,
+      result: String): String = {
+    def hashInt(i: String): String = s"int $result = $i;"
+    def hashLong(l: String): String = s"long $result = $l;"
+    def hashBytes(b: String): String = {
+      val hash = ctx.freshName("hash")
+      s"""
+         |int $result = 0;
+         |for (int i = 0; i < $b.length; i++) {
+         |  ${genComputeHash(ctx, s"$b[i]", ByteType, hash)}
+         |  $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2);
+         |}
+       """.stripMargin
+    }
+
+    dataType match {
+      case BooleanType => hashInt(s"$input ? 1 : 0")
+      case ByteType | ShortType | IntegerType | DateType => hashInt(input)
+      case LongType | TimestampType => hashLong(input)
+      case FloatType => hashInt(s"Float.floatToIntBits($input)")
+      case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
+      case d: DecimalType =>
+        if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+          hashLong(s"$input.toUnscaledLong()")
+        } else {
+          val bytes = ctx.freshName("bytes")
+          s"""
+            final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+            ${hashBytes(bytes)}
+          """
+        }
+      case StringType => hashBytes(s"$input.getBytes()")
+      case _ =>
+        throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+    }
+  }
 }
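
As a runtime analogue of the DecimalType branch of `genComputeHash` above (illustration only; the generated Java does the same arithmetic inline instead of calling a helper like this):

  import org.apache.spark.sql.types.Decimal

  def hashDecimal(d: Decimal, precision: Int): Long =
    if (precision <= Decimal.MAX_LONG_DIGITS) {
      // Small-precision decimals fit in a long, so the unscaled long is hashed directly.
      d.toUnscaledLong
    } else {
      // Otherwise mix the bytes of the unscaled value, one byte at a time.
      d.toJavaBigDecimal.unscaledValue.toByteArray.foldLeft(0L) { (h, b) =>
        (h ^ 0x9e3779b9L) + b + (h << 6) + (h >>> 2)
      }
    }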
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a4e82d80f50f5b75e5897b61f81dea26a7115a3f..eb976fbaad3e8d961fc939f7e8be3cc669f17ec7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -482,7 +482,7 @@ object SQLConf {
     .internal()
     .doc("When true, aggregate with keys use an in-memory columnar map to speed up execution.")
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)
 
   val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
     .internal()
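
With the default flipped to true, the vectorized map is on unless explicitly disabled; the benchmarks below toggle it per case with the same key, e.g.:

  // Disable the vectorized aggregate hash map for the current session.
  sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false")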
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 3fb70f2eb6ae3cb881a1baa4bd3129d7a3984a91..7a120b93749c986005e3d304186721ecbab0b7c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -224,6 +224,127 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
   }
 
+  ignore("aggregate with string key") {
+    val N = 20 << 20
+
+    val benchmark = new Benchmark("Aggregate w string key", N)
+    def f(): Unit = sqlContext.range(N).selectExpr("id", "cast(id & 1023 as string) as k")
+      .groupBy("k").count().collect()
+
+    benchmark.addCase(s"codegen = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = T") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true")
+      f()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    Aggregate w string key:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    codegen = F                              3307 / 3376          6.3         157.7       1.0X
+    codegen = T hashmap = F                  2364 / 2471          8.9         112.7       1.4X
+    codegen = T hashmap = T                  1740 / 1841         12.0          83.0       1.9X
+    */
+  }
+
+  ignore("aggregate with decimal key") {
+    val N = 20 << 20
+
+    val benchmark = new Benchmark("Aggregate w decimal key", N)
+    def f(): Unit = sqlContext.range(N).selectExpr("id", "cast(id & 65535 as decimal) as k")
+      .groupBy("k").count().collect()
+
+    benchmark.addCase(s"codegen = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = T") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true")
+      f()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    Aggregate w decimal key:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    codegen = F                              2756 / 2817          7.6         131.4       1.0X
+    codegen = T hashmap = F                  1580 / 1647         13.3          75.4       1.7X
+    codegen = T hashmap = T                   641 /  662         32.7          30.6       4.3X
+    */
+  }
+
+  ignore("aggregate with multiple key types") {
+    val N = 20 << 20
+
+    val benchmark = new Benchmark("Aggregate w multiple keys", N)
+    def f(): Unit = sqlContext.range(N)
+      .selectExpr(
+        "id",
+        "(id & 1023) as k1",
+        "cast(id & 1023 as string) as k2",
+        "cast(id & 1023 as int) as k3",
+        "cast(id & 1023 as double) as k4",
+        "cast(id & 1023 as float) as k5",
+        "id > 1023 as k6")
+      .groupBy("k1", "k2", "k3", "k4", "k5", "k6")
+      .sum()
+      .collect()
+
+    benchmark.addCase(s"codegen = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = F") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T hashmap = T") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true")
+      f()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    Aggregate w multiple keys:           Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    codegen = F                              5885 / 6091          3.6         280.6       1.0X
+    codegen = T hashmap = F                  3625 / 4009          5.8         172.8       1.6X
+    codegen = T hashmap = T                  3204 / 3271          6.5         152.8       1.8X
+    */
+  }
+
   ignore("broadcast hash join") {
     val N = 20 << 20
     val M = 1 << 16