From 8d08f43d09157b98e559c0be6ce6fd571a35e0d1 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Tue, 6 Sep 2016 10:36:00 +0800
Subject: [PATCH] [SPARK-17279][SQL] better error message for exceptions
 during ScalaUDF execution

## What changes were proposed in this pull request?

If `ScalaUDF` throws an exception while executing user code, it is sometimes hard for users to figure out what went wrong, especially in the Spark shell. An example:

```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 12 in stage 325.0 failed 4 times, most recent failure: Lost task 12.3 in stage 325.0 (TID 35622, 10.0.207.202): java.lang.NullPointerException
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
	...
```

We should catch these exceptions and rethrow them with a better error message, stating that the exception happened inside the Scala UDF.

This PR also does some cleanup of `ScalaUDF` and adds a unit test suite for it.
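To make the pattern concrete, below is a minimal, self-contained sketch of the wrap-and-rethrow approach this PR takes. The helpers `describeUdf` and `runUdf` and the local `SparkException` stand-in are hypothetical, existing only for this illustration; the message format and the cause-preserving rethrow come from the actual changes to `ScalaUDF.eval` and the generated code in the diff below.

```scala
// Minimal sketch of the wrap-and-rethrow pattern used by this patch.
// `describeUdf` and `runUdf` are hypothetical helpers, not Spark APIs;
// the real change lives in ScalaUDF.eval and the generated code.
object UdfErrorWrapping {
  // Stand-in for org.apache.spark.SparkException, for this sketch only.
  class SparkException(message: String, cause: Throwable)
    extends Exception(message, cause)

  // Builds a message shaped like the patch's udfErrorMessage: it names the
  // function class and its input/output types so users can locate the UDF.
  def describeUdf(function: AnyRef, inputTypes: Seq[String], resultType: String): String = {
    val funcCls = function.getClass.getSimpleName
    s"Failed to execute user defined function($funcCls: (${inputTypes.mkString(", ")}) => $resultType)"
  }

  // Runs the user function, turning any exception it throws into a
  // SparkException that carries the descriptive message and keeps the
  // original exception as the cause.
  def runUdf[A, B](f: A => B, arg: A, errorMessage: => String): B = {
    try f(arg) catch {
      case e: Exception => throw new SparkException(errorMessage, e)
    }
  }

  def main(args: Array[String]): Unit = {
    val lower = (s: String) => s.toLowerCase  // throws NPE when s is null
    val msg = describeUdf(lower, Seq("string"), "string")
    try runUdf(lower, null: String, msg) catch {
      case e: SparkException =>
        println(e.getMessage)  // Failed to execute user defined function(...)
        println(e.getCause)    // java.lang.NullPointerException
    }
  }
}
```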
") { val model = als.fit(df) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 21390644bc..6cfdea9fdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType @@ -994,20 +995,15 @@ case class ScalaUDF( ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.references += this - - val scalaUDFClassName = classOf[ScalaUDF].getName + val scalaUDF = ctx.addReferenceObj("scalaUDF", this) val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - val catalystConverterTermIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, catalystConverterTerm, s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter((($scalaUDFClassName)references" + - s"[$catalystConverterTermIdx]).dataType());") + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,10 +1015,8 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - val funcExpressionIdx = ctx.references.size - 1 ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" + - s"[$funcExpressionIdx]).userDefinedFunc());") + s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1039,9 +1033,16 @@ case class ScalaUDF( (convert, argTerm) }.unzip - val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + - s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + - s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val callFunc = + s""" + ${ctx.boxedType(dataType)} $resultTerm = null; + try { + $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + } catch (Exception e) { + throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + } + """ ev.copy(code = s""" $evalCode @@ -1057,5 +1058,20 @@ case class ScalaUDF( private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - override def eval(input: InternalRow): Any = converter(f(input)) + lazy val udfErrorMessage = { + val funcCls = function.getClass.getSimpleName + val inputTypes = children.map(_.dataType.simpleString).mkString(", ") + s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + } + + override def eval(input: InternalRow): Any = { + val result = try { + f(input) + } catch { + case e: Exception => + throw new SparkException(udfErrorMessage, e) + } + + converter(result) + } } diff --git 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
new file mode 100644
index 0000000000..7e45028653
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("basic") {
+    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
+    checkEvaluation(intUdf, 2)
+
+    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
+    checkEvaluation(stringUdf, "ax")
+  }
+
+  test("better error message for NPE") {
+    val udf = ScalaUDF(
+      (s: String) => s.toLowerCase,
+      StringType,
+      Literal.create(null, StringType) :: Nil)
+
+    val e1 = intercept[SparkException](udf.eval())
+    assert(e1.getMessage.contains("Failed to execute user defined function"))
+
+    val e2 = intercept[SparkException] {
+      checkEvalutionWithUnsafeProjection(udf, null)
+    }
+    assert(e2.getMessage.contains("Failed to execute user defined function"))
+  }
+
+}
-- 
GitLab