Rewriting Spark's JdbcRDD to support custom partition query conditions (repost)

Spark's built-in JdbcRDD only supports Long-typed partition parameters: the partitions must form a Long range. In many cases this simply does not fit.

I rewrote JdbcRDD so that the partition conditions can be fully customized.

The main idea: turn the part that sets the query parameters into a caller-supplied function, so the partition parameters can be set up in any way you like.
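For reference, this is roughly how the constructor changes. The first signature is the standard Spark 1.x JdbcRDD; the second is the rewritten class defined below, where the Long range and partition count are replaced by two caller-supplied functions (one builds the partitions, one binds each partition's parameters onto the PreparedStatement):

// Standard JdbcRDD (Spark 1.x): partitioning is fixed to a Long range
class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)

// Rewritten version: the caller supplies the partitions and the parameter binding
class CustomizedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    getCustomizedPartitions: () => Array[Partition],
    prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
    mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)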
Here is the full code:
//////////////////////////////////////////////////////////////////////////////////
// NOTE: the class needs to live under org.apache.spark (here org.apache.spark.rdd, next to the
// original JdbcRDD) because NextIterator and fakeClassTag are private[spark]; this also matches
// the imports used by the test code below.
package org.apache.spark.rdd

import java.sql.{Connection, PreparedStatement, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.NextIterator
/**
 * A Partition that carries an arbitrary map of query parameters instead of a Long range.
 */
class CustomizedJdbcPartition(idx: Int, parameters: Map[String, Object]) extends Partition {
  override def index = idx
  val partitionParameters = parameters
}
// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private
/**
 * An RDD that executes an SQL query on a JDBC connection and reads results.
 * For usage example, see test case JdbcRDDSuite.
 *
 * @param getConnection a function that returns an open Connection.
 *   The RDD takes care of closing the connection.
 * @param sql the text of the query.
 *   The query may contain ? placeholders whose values are bound per partition by
 *   `prepareStatement`, e.g. "select title, author from books where category = ?".
 * @param getCustomizedPartitions a function that builds the partitions of this RDD.
 *   Each partition is a CustomizedJdbcPartition carrying its own query parameters.
 * @param prepareStatement a function that binds a partition's parameters onto the
 *   PreparedStatement before the query is executed.
 * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 *   This should only call getInt, getString, etc; the RDD takes care of calling next.
 *   The default maps a ResultSet to an array of Object.
 */
class CustomizedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    getCustomizedPartitions: () => Array[Partition],
    prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
    mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    getCustomizedPartitions()
  }
  override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
    context.addTaskCompletionListener { _ => closeIfNeeded() }

    val part = thePart.asInstanceOf[CustomizedJdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    // setFetchSize(Integer.MIN_VALUE) is a MySQL driver specific way to force streaming results,
    // rather than pulling the entire resultset into memory.
    // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
    try {
      if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
        stmt.setFetchSize(Integer.MIN_VALUE)
        logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming")
      }
    } catch {
      case ex: Exception =>
        // ignore: not every driver supports getMetaData.getURL
    }

    prepareStatement(stmt, part)
    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs && !rs.isClosed()) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt && !stmt.isClosed()) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn && !conn.isClosed()) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
object CustomizedJdbcRDD {

  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }
  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
   * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
   *
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query.
   *   The query may contain ? placeholders whose values are bound per partition
   *   by `prepareStatement`.
   * @param getCustomizedPartitions a function that builds the partitions of this RDD.
   * @param prepareStatement a function that binds a partition's parameters onto the
   *   PreparedStatement before the query is executed.
   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
   *   The default maps a ResultSet to an array of Object.
   */
  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new CustomizedJdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      getCustomizedPartitions,
      prepareStatement,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }
  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
   * converted into an `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
   *
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query.
   *   The query may contain ? placeholders whose values are bound per partition
   *   by `prepareStatement`.
   * @param getCustomizedPartitions a function that builds the partitions of this RDD.
   * @param prepareStatement a function that binds a partition's parameters onto the
   *   PreparedStatement before the query is executed.
   */
  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, getCustomizedPartitions, prepareStatement, mapRow)
  }
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Below is a short piece of test code:
package org.apache.spark

import java.sql.{Connection, DriverManager, PreparedStatement}

import org.apache.spark.rdd.{CustomizedJdbcPartition, CustomizedJdbcRDD}

object HiveRDDTest {

  private val driverName = "org.apache.hive.jdbc.HiveDriver"
  private val tableName = "COLLECT_DATA"
  private var connection: Connection = null

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("HiveRDDTest").setMaster("local[2]")
    val sc = new SparkContext(conf)
    Class.forName(driverName)

    val data = new CustomizedJdbcRDD(sc,
      // function that opens the JDBC connection
      () => {
        DriverManager.getConnection("jdbc:hive2://192.168.31.135:10000/default", "spark", "")
      },
      // the query; the ? placeholder is bound per partition
      "select * from collect_data where host=?",
      // function that builds the partitions (a single partition here)
      () => {
        val partitions = new Array[Partition](1)
        var parameters = Map[String, Object]()
        parameters += ("host" -> "172.18.26.11")
        partitions(0) = new CustomizedJdbcPartition(0, parameters)
        partitions
      },
      // bind each partition's parameters onto the statement (matching the SQL above)
      (stmt: PreparedStatement, partition: CustomizedJdbcPartition) => {
        stmt.setString(1, partition.partitionParameters("host").asInstanceOf[String])
        stmt
      }
    )

    println(data.count())
  }
}
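The test above uses a single partition. The same pattern extends naturally to several partitions, one per query condition. Here is a rough sketch (not part of the original post; it assumes the same SparkContext sc and Hive connection as above, and the host addresses are made up for illustration):

// Rough sketch: one partition per host value, so the hosts are queried in parallel
val hosts = Seq("172.18.26.11", "172.18.26.12", "172.18.26.13")

val multiHostData = new CustomizedJdbcRDD(sc,
  () => DriverManager.getConnection("jdbc:hive2://192.168.31.135:10000/default", "spark", ""),
  "select * from collect_data where host=?",
  // one CustomizedJdbcPartition per host; Spark runs one task per partition
  () => hosts.zipWithIndex.map { case (host, idx) =>
    new CustomizedJdbcPartition(idx, Map[String, Object]("host" -> host)): Partition
  }.toArray,
  // bind the current partition's host into the single ? placeholder
  (stmt: PreparedStatement, partition: CustomizedJdbcPartition) => {
    stmt.setString(1, partition.partitionParameters("host").asInstanceOf[String])
    stmt
  })

println(multiHostData.count())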