diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 9a56a75b69d0bb88f1b96d360295764d73ad0b79..f6f5281f71a5fcdc142ec1e2b0ae5c7c52c8b9d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -315,9 +315,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def fit(dataset: DataFrame): ALSModel = { import dataset.sqlContext.implicits._ + val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), - col($(ratingCol)).cast(FloatType)) + .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) }