diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index ce43a450daad0321cbb4248ba6ccd7770a5e375e..e479f169021d8347ba5871b0396afd8da86bac20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
 import scala.annotation.varargs
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
+import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
 
 /**
  * :: DeveloperApi ::
@@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
    * Creates an [[Attribute]] from a [[StructField]] instance.
    */
   def fromStructField(field: StructField): Attribute = {
-    require(field.dataType == DoubleType)
+    require(field.dataType.isInstanceOf[NumericType])
     val metadata = field.metadata
     val mlAttr = AttributeKeys.ML_ATTR
     if (metadata.contains(mlAttr)) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 72b575d0225473c4b57c8617b16190609c4c0015..c5fd2f9d5a22a3ca6e0f3eeabba295ff5d1532c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
     assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
     val fldWithMeta = new StructField("x", DoubleType, false, metadata)
     assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+    // Attribute.fromStructField should accept any NumericType, not just DoubleType
+    val longFldWithMeta = new StructField("x", LongType, false, metadata)
+    assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
+    val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
+    assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
   }
 }