diff --git a/.rat-excludes b/.rat-excludes index 236c2db05367c0386b55927c7fac3f9b1333c937..72771465846b8561552fa09adf697c6c52769e66 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -93,3 +93,4 @@ INDEX .lintr gen-java.* .*avpr +org.apache.spark.sql.sources.DataSourceRegister diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000000000000000000000000000000..cc32d4b72748ef07548dd88aa8530d6ef752a301 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.jdbc.DefaultSource +org.apache.spark.sql.json.DefaultSource +org.apache.spark.sql.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0cdb407ad57b92c46aab3c6cf5dc3016d49c4ad1..8c2f297e42458d0e7214a7752779115d56f2f4c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.execution.datasources +import java.util.ServiceLoader + +import scala.collection.Iterator +import scala.collection.JavaConversions._ import scala.language.{existentials, implicitConversions} +import scala.util.{Failure, Success, Try} import scala.util.matching.Regex import org.apache.hadoop.fs.Path @@ -190,37 +195,32 @@ private[sql] class DDLParser( } } -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) +private[sql] object ResolvedDataSource extends Logging { /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val provider2 = s"$provider.DefaultSource" val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider", error) } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 41d0ecb4bbfbf9bf9f15033b8e6f581b294ec6d8..48d97ced9ca0aaa45e7e070822bfcdcecae40c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -77,7 +77,10 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider { +private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { + + def format(): String = "jdbc" + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 10f1367e6984ced8268747e0bd40c34881160293..b34a272ec547fe0b7e4aa184c4e0f0fe340f26a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "json" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 48009b2fd007d8fb4f410d2d42466756e11b778d..b6db71b5b8a62ca9a81d4092d98d3c4ef34f0c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "parquet" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c5b7ee73eb78433f616bb3648715aab9abccdc95..4aafec0e2df271328ae59f932912562dbe4a1749 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration +/** + * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * ex: parquet.DefaultSource.format = "parquet". + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source, + * ex: override def format(): String = "parquet" + */ + def format(): String +} + /** * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000000000000000000000000000000..cfd7889b4ac2ca180bcbc8c0cbf528ec66866dd5 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1a4d41b02ca686991bd887f8b61a9e47e8f85849 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def format(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("Loading Orc") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000000000000000000000000000000..4a774fbf1fdf83be03f4f4e66b268c18ad01f2b4 --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7c8704b47f286baecdad5e38fddbc75823e38630..0c344c63fde3f48a29d144c0c9f7c9a7b076f3e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -47,7 +47,10 @@ import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "orc" + def createRelation( sqlContext: SQLContext, paths: Array[String],