diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 4b13ba6f9cea3b41cb599fcef2ea3dd17a042010..7f568f4e0db4e73b184b3f23a4fe00f1420cde35 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -29,7 +29,6 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.recommendation.ALS; import org.apache.spark.ml.recommendation.ALSModel; -import org.apache.spark.sql.types.DataTypes; // $example off$ public class JavaALSExample { @@ -109,10 +108,7 @@ public class JavaALSExample { ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data - Dataset<Row> rawPredictions = model.transform(test); - Dataset<Row> predictions = rawPredictions - .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) - .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); + Dataset<Row> predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator() .setMetricName("rmse") diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index ff0829b0dd45a6b50ff4dfe6152340917beb017f..1a979ff5b5be287bbc3d934f00be42f062d0f841 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -48,12 +48,9 @@ if __name__ == "__main__": model = als.fit(training) # Evaluate the model by computing the RMSE on the test data - rawPredictions = model.transform(test) - predictions = rawPredictions\ - .withColumn("rating", rawPredictions.rating.cast("double"))\ - .withColumn("prediction", rawPredictions.prediction.cast("double")) - evaluator =\ - RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction") + predictions = model.transform(test) + evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", + predictionCol="prediction") rmse = evaluator.evaluate(predictions) print("Root-mean-square error = " + str(rmse)) # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 7c1cfe293717aa0d6becacbe8ee201eee1c3a838..6b151a622e2677dbef13c0738cda84a0fad160e1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -23,10 +23,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.recommendation.ALS // $example off$ import org.apache.spark.sql.SparkSession -// $example on$ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType -// $example off$ object ALSExample { @@ -65,8 +61,6 @@ object ALSExample { // Evaluate the model by computing the RMSE on the test data val predictions = model.transform(test) - .withColumn("rating", col("rating").cast(DoubleType)) - .withColumn("prediction", col("prediction").cast(DoubleType)) val evaluator = new RegressionEvaluator() .setMetricName("rmse")