Skip to content
Snippets Groups Projects
Commit b0cafdb6 authored by Nick Pentreath
Browse files

[MINOR][ML][PYSPARK] ALS example cleanup

Cleans up ALS examples by removing unnecessary casts to double for `rating` and `prediction` columns, since `RegressionEvaluator` now supports `Double` & `Float` input types.

## How was this patch tested?

Manual compile and run with `run-example ml.ALSExample` and `spark-submit examples/src/main/python/ml/als_example.py`.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12892 from MLnick/als-examples-cleanup.
parent df89f1d4
No related branches found
No related tags found
No related merge requests found
......@@ -29,7 +29,6 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.types.DataTypes;
// $example off$
public class JavaALSExample {
......@@ -109,10 +108,7 @@ public class JavaALSExample {
ALSModel model = als.fit(training);
// Evaluate the model by computing the RMSE on the test data
Dataset<Row> rawPredictions = model.transform(test);
Dataset<Row> predictions = rawPredictions
.withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
.withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
......
......@@ -48,12 +48,9 @@ if __name__ == "__main__":
model = als.fit(training)
# Evaluate the model by computing the RMSE on the test data
rawPredictions = model.transform(test)
predictions = rawPredictions\
.withColumn("rating", rawPredictions.rating.cast("double"))\
.withColumn("prediction", rawPredictions.prediction.cast("double"))
evaluator =\
RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
predictions = model.transform(test)
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))
# $example off$
......
......@@ -23,10 +23,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
// $example off$
import org.apache.spark.sql.SparkSession
// $example on$
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
// $example off$
object ALSExample {
......@@ -65,8 +61,6 @@ object ALSExample {
// Evaluate the model by computing the RMSE on the test data
val predictions = model.transform(test)
.withColumn("rating", col("rating").cast(DoubleType))
.withColumn("prediction", col("prediction").cast(DoubleType))
val evaluator = new RegressionEvaluator()
.setMetricName("rmse")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment