Skip to content
Snippets Groups Projects
Commit b77a02f4 authored by Vida Ha's avatar Vida Ha Committed by Michael Armbrust
Browse files

[SPARK-3752][SQL]: Add tests for different UDF's

Author: Vida Ha <vida@databricks.com>

Closes #2621 from vidaha/vida/SPARK-3752 and squashes the following commits:

d7fdbbc [Vida Ha] Add tests for different UDF's
parent 73bf3f2e
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution;
import org.apache.hadoop.hive.ql.exec.UDF;
/**
 * Hive UDF that converts an integer to its decimal string representation.
 */
public class UDFIntegerToString extends UDF {
  /**
   * Converts the given integer to a string.
   *
   * @param i the value to convert; may be {@code null}
   * @return {@code i.toString()}, or {@code null} when {@code i} is {@code null}
   */
  public String evaluate(Integer i) {
    // A null argument (presumably how a SQL NULL reaches the UDF) previously
    // caused a NullPointerException; propagate null instead.
    if (i == null) {
      return null;
    }
    return i.toString();
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution;
import org.apache.hadoop.hive.ql.exec.UDF;
import java.util.List;
/**
 * Hive UDF that sums the second field of every struct in an array of structs.
 */
public class UDFListListInt extends UDF {
  /**
   * Adds up the element at index 1 of each inner list.
   *
   * @param obj
   *   SQL schema: {@code array<struct<x: int, y: int, z: int>>}
   *   Java Type: {@code List<List<Integer>>}
   * @return the sum of the index-1 element of every inner list, or 0 when
   *   {@code obj} is {@code null}; null inner lists and null elements
   *   contribute nothing to the sum
   */
  public long evaluate(Object obj) {
    if (obj == null) {
      return 0L;
    }

    // The caller hands the array-of-structs over as an untyped Object.
    @SuppressWarnings("unchecked")
    List<List<Object>> rows = (List<List<Object>>) obj;

    long total = 0L;
    for (List<Object> row : rows) {
      // Skip a missing struct instead of failing on row.get(1).
      if (row == null) {
        continue;
      }
      Integer value = (Integer) row.get(1);
      // Explicit null check replaces the former catch of NullPointerException.
      if (value != null) {
        total += value.longValue();
      }
    }
    return total;
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution;
import org.apache.hadoop.hive.ql.exec.UDF;
import java.util.List;
import org.apache.commons.lang.StringUtils;
/**
 * Hive UDF that joins the elements of a list into one comma-separated string.
 */
public class UDFListString extends UDF {
  /**
   * Joins the elements of the given list with {@code ','}.
   *
   * @param a the list to join, passed as an untyped object; may be {@code null}
   * @return the comma-separated elements, or {@code null} when {@code a} is
   *   {@code null}
   */
  public String evaluate(Object a) {
    if (a != null) {
      @SuppressWarnings("unchecked")
      final List<Object> elements = (List<Object>) a;
      return StringUtils.join(elements, ',');
    }
    return null;
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution;
import org.apache.hadoop.hive.ql.exec.UDF;
/**
 * Hive UDF that concatenates two strings with a single space between them.
 */
public class UDFStringString extends UDF {
  /**
   * Builds {@code s1 + " " + s2}.
   *
   * @param s1 the left-hand string
   * @param s2 the right-hand string
   * @return the two arguments separated by one space; a {@code null} argument
   *   is rendered as the text {@code "null"}
   */
  public String evaluate(String s1, String s2) {
    final StringBuilder joined = new StringBuilder();
    joined.append(s1);
    joined.append(" ");
    joined.append(s2);
    return joined.toString();
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution;
import org.apache.hadoop.hive.ql.exec.UDF;
/**
 * Hive UDF that applies {@link UDFListListInt} to each of two arguments and
 * reports both results in one string.
 */
public class UDFTwoListList extends UDF {
  /**
   * Evaluates {@link UDFListListInt#evaluate(Object)} on both arguments.
   *
   * @param o1 the first argument, forwarded unchanged to {@code UDFListListInt}
   * @param o2 the second argument, forwarded unchanged to {@code UDFListListInt}
   * @return the two results formatted as {@code "<first>, <second>"}
   */
  public String evaluate(Object o1, Object o2) {
    final UDFListListInt delegate = new UDFListListInt();
    final long first = delegate.evaluate(o1);
    final long second = delegate.evaluate(o2);
    return first + ", " + second;
  }
}
...@@ -17,33 +17,37 @@ ...@@ -17,33 +17,37 @@
package org.apache.spark.sql.hive.execution package org.apache.spark.sql.hive.execution
import java.io.{DataOutput, DataInput} import java.io.{DataInput, DataOutput}
import java.util import java.util
import java.util.Properties import java.util.Properties
import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.Row import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.io.Writable
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int)
// Case classes for the custom UDF's.
case class IntegerCaseClass(i: Int)
case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)])
case class StringCaseClass(s: String)
case class ListStringCaseClass(l: Seq[String])
/** /**
* A test suite for Hive custom UDFs. * A test suite for Hive custom UDFs.
*/ */
class HiveUdfSuite extends HiveComparisonTest { class HiveUdfSuite extends QueryTest {
import TestHive._
test("spark sql udf test that returns a struct") { test("spark sql udf test that returns a struct") {
registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
...@@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest { ...@@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest {
} }
test("SPARK-2693 udaf aggregates test") { test("SPARK-2693 udaf aggregates test") {
assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)
}
test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize(
IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
testData.registerTempTable("integerTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
checkAnswer(
sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(),
Seq(Seq("1"), Seq("2")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
TestHive.reset()
}
test("UDFListListInt") {
val testData = TestHive.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
testData.registerTempTable("listListIntTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
checkAnswer(
sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(),
Seq(Seq(0), Seq(2), Seq(13)))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
TestHive.reset()
}
test("UDFListString") {
val testData = TestHive.sparkContext.parallelize(
ListStringCaseClass(Seq("a", "b", "c")) ::
ListStringCaseClass(Seq("d", "e")) :: Nil)
testData.registerTempTable("listStringTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
checkAnswer(
sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(),
Seq(Seq("a,b,c"), Seq("d,e")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
TestHive.reset()
}
test("UDFStringString") {
val testData = TestHive.sparkContext.parallelize(
StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
testData.registerTempTable("stringTable")
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
checkAnswer(
sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(),
Seq(Seq("hello world"), Seq("hello goodbye")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")
TestHive.reset()
}
test("UDFTwoListList") {
val testData = TestHive.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
Nil)
testData.registerTempTable("TwoListTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
checkAnswer(
sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(),
Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
TestHive.reset()
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment