Commit b8f88d32 authored by Reynold Xin

[SPARK-5702][SQL] Allow short names for built-in data sources.

Also took the chance to fix up some style ...

Author: Reynold Xin <rxin@databricks.com>

Closes #4489 from rxin/SPARK-5702 and squashes the following commits:

74f42e3 [Reynold Xin] [SPARK-5702][SQL] Allow short names for built-in data sources.
parent b9691826
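In practice, the change means the USING clause of the data source DDL (resolved through ResolvedDataSource in the diff below) accepts a built-in short name in place of the fully qualified provider package. A minimal sketch, with a hypothetical JSON file path:

// Before: the full provider package had to be spelled out.
sqlContext.sql(
  "CREATE TEMPORARY TABLE people USING org.apache.spark.sql.json " +
  "OPTIONS (path 'examples/src/main/resources/people.json')")

// After: the built-in short names "jdbc", "json" and "parquet" resolve directly.
sqlContext.sql(
  "CREATE TEMPORARY TABLE people2 USING json " +
  "OPTIONS (path 'examples/src/main/resources/people.json')")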
@@ -48,11 +48,6 @@ private[sql] object JDBCRelation {
    * exactly once. The parameters minValue and maxValue are advisory in that
    * incorrect values may cause the partitioning to be poor, but no data
    * will fail to be represented.
-   *
-   * @param column - Column name. Must refer to a column of integral type.
-   * @param numPartitions - Number of partitions
-   * @param minValue - Smallest value of column. Advisory.
-   * @param maxValue - Largest value of column. Advisory.
    */
   def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
     if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -68,12 +63,17 @@ private[sql] object JDBCRelation {
     var currentValue: Long = partitioning.lowerBound
     var ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
       currentValue += stride
-      val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
-      val whereClause = (if (upperBound == null) lowerBound
-        else if (lowerBound == null) upperBound
-        else s"$lowerBound AND $upperBound")
+      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val whereClause =
+        if (upperBound == null) {
+          lowerBound
+        } else if (lowerBound == null) {
+          upperBound
+        } else {
+          s"$lowerBound AND $upperBound"
+        }
       ans += JDBCPartition(whereClause, i)
       i = i + 1
     }
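As a worked example of the loop above (the values are hypothetical): with column "id", a lower bound of 0, an upper bound of 400 and 4 partitions, so a stride of 100, the generated WHERE clauses are:

partition 0: id < 100                  (the first partition has no lower bound)
partition 1: id >= 100 AND id < 200
partition 2: id >= 200 AND id < 300
partition 3: id >= 300                 (the last partition has no upper bound)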
@@ -96,8 +96,7 @@ private[sql] class DefaultSource extends RelationProvider {
     if (driver != null) Class.forName(driver)
 
-    if (
-      partitionColumn != null
+    if (partitionColumn != null
         && (lowerBound == null || upperBound == null || numPartitions == null)) {
       sys.error("Partitioning incompletely specified")
     }
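For reference, the check above only rejects a partial partitioning spec; a partition column together with all three of lowerBound, upperBound and numPartitions passes. A sketch of a complete option map (the connection values, and the "url"/"dbtable" key names, are assumptions for illustration; the partitioning keys mirror the variables tested above):

val jdbcOptions = Map(
  "url" -> "jdbc:postgresql://localhost/testdb",  // hypothetical connection string
  "dbtable" -> "people",                          // hypothetical table name
  "partitionColumn" -> "id",
  "lowerBound" -> "0",
  "upperBound" -> "400",
  "numPartitions" -> "4")
// Keeping partitionColumn but dropping any of the last three entries
// triggers sys.error("Partitioning incompletely specified").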
@@ -119,7 +118,8 @@ private[sql] class DefaultSource extends RelationProvider {
 private[sql] case class JDBCRelation(
     url: String,
     table: String,
-    parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
+    parts: Array[Partition])(@transient val sqlContext: SQLContext)
+  extends PrunedFilteredScan {
 
   override val schema = JDBCRDD.resolveTable(url, table)
@@ -20,6 +20,7 @@ package org.apache.spark.sql.json
 import java.io.IOException
 
 import org.apache.hadoop.fs.Path
+
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -234,65 +234,73 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
       primitiveType
   }
 
-object ResolvedDataSource {
-  def apply(
-      sqlContext: SQLContext,
-      userSpecifiedSchema: Option[StructType],
-      provider: String,
-      options: Map[String, String]): ResolvedDataSource = {
+private[sql] object ResolvedDataSource {
+
+  private val builtinSources = Map(
+    "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
+    "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
+    "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+  )
+
+  /** Given a provider name, look up the data source class definition. */
+  def lookupDataSource(provider: String): Class[_] = {
+    if (builtinSources.contains(provider)) {
+      return builtinSources(provider)
+    }
+
     val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
+    try {
+      loader.loadClass(provider)
+    } catch {
       case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
+        try {
+          loader.loadClass(provider + ".DefaultSource")
+        } catch {
           case cnf: java.lang.ClassNotFoundException =>
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
+  }
+
+  /** Create a [[ResolvedDataSource]] for reading data in. */
+  def apply(
+      sqlContext: SQLContext,
+      userSpecifiedSchema: Option[StructType],
+      provider: String,
+      options: Map[String, String]): ResolvedDataSource = {
+    val clazz: Class[_] = lookupDataSource(provider)
     val relation = userSpecifiedSchema match {
-      case Some(schema: StructType) => {
-        clazz.newInstance match {
-          case dataSource: SchemaRelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
-          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-            sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
-        }
+      case Some(schema: StructType) => clazz.newInstance() match {
+        case dataSource: SchemaRelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+        case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+          sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
       }
-      case None => {
-        clazz.newInstance match {
-          case dataSource: RelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
-          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-            sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
-        }
+
+      case None => clazz.newInstance() match {
+        case dataSource: RelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+        case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+          sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
       }
     }
     new ResolvedDataSource(clazz, relation)
   }
 
+  /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
   def apply(
       sqlContext: SQLContext,
       provider: String,
       mode: SaveMode,
       options: Map[String, String],
       data: DataFrame): ResolvedDataSource = {
-    val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
-      case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
-          case cnf: java.lang.ClassNotFoundException =>
-            sys.error(s"Failed to load class for data source: $provider")
-        }
-    }
-    val relation = clazz.newInstance match {
+    val clazz: Class[_] = lookupDataSource(provider)
+    val relation = clazz.newInstance() match {
       case dataSource: CreatableRelationProvider =>
         dataSource.createRelation(sqlContext, mode, options, data)
       case _ =>
         sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
     }
     new ResolvedDataSource(clazz, relation)
   }
 }
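To summarize the resolution order introduced above: a provider string is first checked against the builtinSources map, then loaded directly as a class, and finally retried with ".DefaultSource" appended. A small sketch of the three paths (callable from within org.apache.spark.sql, as the new ResolvedDataSourceSuite below does):

// Built-in short name: served straight from the map.
ResolvedDataSource.lookupDataSource("parquet")
// Fully qualified provider class: loaded via the context class loader.
ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json.DefaultSource")
// Package name: falls back to loading "<provider>.DefaultSource".
ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json")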
@@ -405,6 +413,5 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St
 /**
  * The exception thrown from the DDL parser.
- * @param message
  */
 protected[sql] class DDLException(message: String) extends Exception(message)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources

import org.scalatest.FunSuite

class ResolvedDataSourceSuite extends FunSuite {
  test("builtin sources") {
    assert(ResolvedDataSource.lookupDataSource("jdbc") ===
      classOf[org.apache.spark.sql.jdbc.DefaultSource])

    assert(ResolvedDataSource.lookupDataSource("json") ===
      classOf[org.apache.spark.sql.json.DefaultSource])

    assert(ResolvedDataSource.lookupDataSource("parquet") ===
      classOf[org.apache.spark.sql.parquet.DefaultSource])
  }
}