diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
new file mode 100644
index 0000000000000000000000000000000000000000..46514ae5f0e84fcd8284fe8a3d0f229de8429372
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+/**
+ * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
+ * at most a single one-value. By default, the binary vector has an element for each category, so
+ * with 5 categories, an input value of 2.0 would map to an output vector of
+ * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
+ * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
+ * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
+ * linearly dependent because they sum up to one.
+ */
+@AlphaComponent
+class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
+  with HasInputCol with HasOutputCol {
+
+  /**
+   * Whether to include a component in the encoded vectors for the first category, defaults to true.
+   * @group param
+   */
+  final val includeFirst: BooleanParam =
+    new BooleanParam(this, "includeFirst", "include first category")
+  setDefault(includeFirst -> true)
+
+  private var categories: Array[String] = _
+
+  /** @group setParam */
+  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+
+  /** @group setParam */
+  override def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  override def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    val inputFields = schema.fields
+    val outputColName = $(outputCol)
+    require(inputFields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+
+    val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
+    categories = inputColAttr match {
+      case nominal: NominalAttribute =>
+        nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
+      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+      case _ =>
+        throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+    }
+
+    val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
+    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
+    val outputFields = inputFields :+ attr.toStructField()
+    StructType(outputFields)
+  }
+
+  protected override def createTransformFunc(): (Double) => Vector = {
+    val first = $(includeFirst)
+    val vecLen = if (first) categories.length else categories.length - 1
+    val oneValue = Array(1.0)
+    val emptyValues = Array[Double]()
+    val emptyIndices = Array[Int]()
+    label: Double => {
+      val values = if (first || label != 0.0) oneValue else emptyValues
+      val indices = if (first) {
+        Array(label.toInt)
+      } else if (label != 0.0) {
+        Array(label.toInt - 1)
+      } else {
+        emptyIndices
+      }
+      Vectors.sparse(vecLen, indices, values)
+    }
+  }
+
+  /**
+   * Returns the data type of the output column.
+   */
+  protected def outputDataType: DataType = new VectorUDT
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..92ec407b98d69e1b3d10eabbe0dcbba4747f8b38
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+  private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  def stringIndexed(): DataFrame = {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    indexer.transform(df)
+  }
+
+  test("OneHotEncoder includeFirst = true") {
+    val transformed = stringIndexed()
+    val encoder = new OneHotEncoder()
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    val encoded = encoder.transform(transformed)
+
+    val output = encoded.select("id", "labelVec").map { r =>
+      val vec = r.get(1).asInstanceOf[Vector]
+      (r.getInt(0), vec(0), vec(1), vec(2))
+    }.collect().toSet
+    // a -> 0, b -> 2, c -> 1
+    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
+      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
+    assert(output === expected)
+  }
+
+  test("OneHotEncoder includeFirst = false") {
+    val transformed = stringIndexed()
+    val encoder = new OneHotEncoder()
+      .setIncludeFirst(false)
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    val encoded = encoder.transform(transformed)
+
+    val output = encoded.select("id", "labelVec").map { r =>
+      val vec = r.get(1).asInstanceOf[Vector]
+      (r.getInt(0), vec(0), vec(1))
+    }.collect().toSet
+    // a -> 0, b -> 2, c -> 1
+    val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
+      (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+    assert(output === expected)
+  }
+
+}