Commit 3da2305e authored by Cody Koeninger

code cleanup per rxin comments

parent dfac0aa5
@@ -5,23 +5,27 @@ import java.sql.{Connection, ResultSet}
 import spark.{Logging, Partition, RDD, SparkContext, TaskContext}
 import spark.util.NextIterator
 
+private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
+  override def index = idx
+}
+
 /**
-  An RDD that executes an SQL query on a JDBC connection and reads results.
-  @param getConnection a function that returns an open Connection.
-    The RDD takes care of closing the connection.
-  @param sql the text of the query.
-    The query must contain two ? placeholders for parameters used to partition the results.
-    E.g. "select title, author from books where ? <= id and id <= ?"
-  @param lowerBound the minimum value of the first placeholder
-  @param upperBound the maximum value of the second placeholder
-    The lower and upper bounds are inclusive.
-  @param numPartitions the number of partitions.
-    Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
-    the query would be executed twice, once with (1, 10) and once with (11, 20)
-  @param mapRow a function from a ResultSet to a single row of the desired result type(s).
-    This should only call getInt, getString, etc; the RDD takes care of calling next.
-    The default maps a ResultSet to an array of Object.
-*/
+ * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * @param getConnection a function that returns an open Connection.
+ *   The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ *   The query must contain two ? placeholders for parameters used to partition the results.
+ *   E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ *   The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+ *   The default maps a ResultSet to an array of Object.
+ */
 class JdbcRDD[T: ClassManifest](
     sc: SparkContext,
     getConnection: () => Connection,
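For orientation, a minimal usage sketch assembled from the scaladoc above. It assumes a SparkContext named sc and the JdbcRDD class are already in scope; the JDBC URL, database, and bounds are invented for illustration and are not part of this commit.

// Hypothetical example: partition the scaladoc's books table on its id column.
val books = new JdbcRDD(
  sc,
  () => java.sql.DriverManager.getConnection("jdbc:mysql://localhost/bookstore"),  // hypothetical URL
  "select title, author from books where ? <= id and id <= ?",
  lowerBound = 1L,
  upperBound = 1000L,
  numPartitions = 10,
  mapRow = rs => (rs.getString("title"), rs.getString("author")))

books.collect().foreach(println)

Each of the ten tasks runs the query once, with the two ? placeholders bound to that partition's inclusive id range, and the RDD closes the connection when the task finishes.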
@@ -29,26 +33,33 @@ class JdbcRDD[T: ClassManifest](
     lowerBound: Long,
     upperBound: Long,
     numPartitions: Int,
-    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray)
+    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
   extends RDD[T](sc, Nil) with Logging {
 
-  override def getPartitions: Array[Partition] =
-    ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions).
-      filter(! _.isEmpty).
-      zipWithIndex.
-      map(x => new JdbcPartition(x._2, x._1.head, x._1.last)).
-      toArray
+  override def getPartitions: Array[Partition] = {
+    // bounds are inclusive, hence the + 1 here and - 1 on end
+    val length = 1 + upperBound - lowerBound
+    (0 until numPartitions).map(i => {
+      val start = lowerBound + ((i * length) / numPartitions).toLong
+      val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
+      new JdbcPartition(i, start, end)
+    }).toArray
+  }
 
   override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
     context.addOnCompleteCallback{ () => closeIfNeeded() }
     val part = thePart.asInstanceOf[JdbcPartition]
     val conn = getConnection()
     val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
-    // force mysql driver to stream rather than pull entire resultset into memory
+    // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
+    // rather than pulling entire resultset into memory.
+    // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
     if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
       stmt.setFetchSize(Integer.MIN_VALUE)
       logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
     }
     stmt.setLong(1, part.lower)
     stmt.setLong(2, part.upper)
     val rs = stmt.executeQuery()
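The rewritten getPartitions drops the ParallelCollectionRDD.slice call and computes each partition's inclusive bounds directly from its index. A small standalone sketch (plain Scala, no Spark needed) replays that arithmetic for the scaladoc example of lowerBound 1, upperBound 20 and two partitions; it prints the documented (1, 10) and (11, 20), and uneven ranges simply yield partition sizes that differ by at most one.

object PartitionBoundsSketch {
  def main(args: Array[String]): Unit = {
    // same example values as the scaladoc: ids 1 to 20 inclusive, split across 2 partitions
    val lowerBound = 1L
    val upperBound = 20L
    val numPartitions = 2
    // bounds are inclusive, hence the + 1 here and - 1 on end (as in getPartitions above)
    val length = 1 + upperBound - lowerBound
    (0 until numPartitions).foreach { i =>
      val start = lowerBound + ((i * length) / numPartitions).toLong
      val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
      println("partition " + i + ": (" + start + ", " + end + ")")  // (1, 10) then (11, 20)
    }
  }
}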
@@ -81,14 +92,10 @@ class JdbcRDD[T: ClassManifest](
       }
     }
   }
 }
 
-private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
-  override def index = idx
-}
-
 object JdbcRDD {
-  val resultSetToObjectArray = (rs: ResultSet) =>
+  def resultSetToObjectArray(rs: ResultSet) = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
+  }
 }
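One small knock-on effect of turning resultSetToObjectArray from a val into a def shows up in the constructor's default argument, which now reads JdbcRDD.resultSetToObjectArray _ : the trailing underscore explicitly eta-expands the method into the function value that the mapRow parameter expects. A tiny standalone illustration of that conversion, with made-up names and nothing Spark-specific:

object EtaExpansionSketch {
  def addOne(x: Int): Int = x + 1      // a method, analogous to the new def resultSetToObjectArray
  val addOneFn: Int => Int = addOne _  // the trailing _ converts the method into a function value

  def main(args: Array[String]): Unit = {
    println(addOneFn(41))  // prints 42
  }
}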