diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index ca900536bc7b8d1e33cfd495db49ab7f4cef1ab6..73f27d1a423d9d32fbae35320c64e9fc8c07310f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -113,12 +113,15 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColName = $(outputCol) - val inputDataTypes = inputColNames.map(name => schema(name).dataType) - inputDataTypes.foreach { - case _: NumericType | BooleanType => - case t if t.isInstanceOf[VectorUDT] => - case other => - throw new IllegalArgumentException(s"Data type $other is not supported.") + val incorrectColumns = inputColNames.flatMap { name => + schema(name).dataType match { + case _: NumericType | BooleanType => None + case t if t.isInstanceOf[VectorUDT] => None + case other => Some(s"Data type $other of column $name is not supported.") + } + } + if (incorrectColumns.nonEmpty) { + throw new IllegalArgumentException(incorrectColumns.mkString("\n")) } if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 46cced3a9a6e5d501a1f1d10d59ce8fd119092d6..6aef1c683702587b48c74f60877dbd909aa9aa66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -79,7 +79,10 @@ class VectorAssemblerSuite val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "Data type StringType is not supported") + assert(thrown.getMessage contains + "Data type StringType of column a is not supported.\n" + + "Data type StringType of column b is not supported.\n" + + "Data type StringType of column c is not supported.") } test("ML attributes") {