1. Preface
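This article is a bit of false advertising: what it describes actually has little to do with Spark Streaming.
In the previous article, http://www.jianshu.com/p/a73c0c95d2fe, we showed how to insert data into a database from Spark Streaming. You may have noticed that the records were inserted one at a time, which is inefficient. So how can we make database inserts faster?
Writing to a database is an I/O-bound task, so parallelism alone does not necessarily make the writes faster. Here we focus on two approaches: batch commit and Bulk Copy Insert.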
2. Batch Commit
Batch commit means using JDBC's Statement.executeBatch. Let's go straight to the code.
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.kafka010.ConsumerStrategies._
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.slf4j.LoggerFactory

/**
 * Read data from Kafka and write it to the database in batches.
 * ConnectionPool is the connection-pool helper from the previous article.
 */
object KafkaToDB {
  val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {
    // Validate arguments
    if (args.length < 2) {
      System.err.println(
        s"""
           |Usage: KafkaToDB <brokers> <topics>
           | <brokers> is a list of one or more Kafka brokers
           | <topics> is a list of one or more kafka topics to consume from
           |""".stripMargin)
      System.exit(1)
    }
    // Parse arguments
    val Array(brokers, topics) = args
    // Topics are separated by ","
    val topicSet: Set[String] = topics.split(",").toSet
    val kafkaParams: Map[String, Object] = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )
    // Create the context; each 2-second interval of data forms one batch
    val sparkConf = new SparkConf().setAppName("KafkaToDB")
    val streamingContext = new StreamingContext(sparkConf, Seconds(2))
    // 1. Create the input stream. Stream operations are based on DStream; InputDStream extends DStream
    val stream = KafkaUtils.createDirectStream[String, String](
      streamingContext,
      PreferConsistent,
      Subscribe[String, String](topicSet, kafkaParams)
    )
    // 2. Transformations on the DStream:
    // take the message value, split it on commas, and turn it into a Tuple3
    val values = stream.map(_.value.split(","))
      .filter(x => x.length == 3)
      .map(x => (x(0), x(1), x(2)))
    // Print the first 10 records to the console for debugging
    values.print()
    // 3. Save to the database via foreachRDD
    val sql = "insert into kafka_message(timeseq, timeseq2, thread, message) values (?,?,?,?)"
    values.foreachRDD(rdd => {
      val count = rdd.count()
      println("-----------------count:" + count)
      if (count > 0) {
        rdd.foreachPartition(partitionOfRecords => {
          val conn = ConnectionPool.getConnection.orNull
          if (conn != null) {
            val ps = conn.prepareStatement(sql)
            try {
              // Disable auto-commit so the whole partition is committed as one transaction
              conn.setAutoCommit(false)
              partitionOfRecords.foreach(data => {
                ps.setString(1, data._1)
                ps.setString(2, System.currentTimeMillis().toString)
                ps.setString(3, data._2)
                ps.setString(4, data._3)
                ps.addBatch()
              })
              ps.executeBatch()
              conn.commit()
            } catch {
              case e: Exception =>
                // Roll back so the connection does not go back to the pool with an open transaction
                conn.rollback()
                logger.error("Error in execution of insert. " + e.getMessage)
            } finally {
              ps.close()
              ConnectionPool.closeConnection(conn)
            }
          }
        })
      }
    })
    streamingContext.start()            // Start the computation
    streamingContext.awaitTermination() // Wait for the computation to terminate
  }
}
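The loop above adds every row of a partition to a single JDBC batch, so a large partition is held in the PreparedStatement until the one executeBatch call. A minimal sketch, assuming you want to bound that, calls executeBatch every `chunkSize` rows and still commits once per partition. `ChunkedBatchInsert`, `insertInChunks` and `chunkSize` are hypothetical names chosen for this illustration, and the code only uses the standard `java.sql` interfaces.

```scala
import java.sql.{Connection, PreparedStatement}

object ChunkedBatchInsert {
  /**
   * Hypothetical helper: binds each record with `bind`, calls executeBatch()
   * every `chunkSize` rows, and commits once at the end.
   * On failure it rolls back and rethrows. Returns the number of rows sent.
   */
  def insertInChunks[T](conn: Connection, sql: String, records: Iterator[T], chunkSize: Int)
                       (bind: (PreparedStatement, T) => Unit): Int = {
    require(chunkSize > 0, "chunkSize must be positive")
    val previousAutoCommit = conn.getAutoCommit
    conn.setAutoCommit(false) // one transaction for the whole partition
    val ps = conn.prepareStatement(sql)
    try {
      var sent = 0
      // grouped() yields full chunks, then one final partial chunk
      records.grouped(chunkSize).foreach { chunk =>
        chunk.foreach { record =>
          bind(ps, record)
          ps.addBatch()
        }
        ps.executeBatch() // send this chunk's rows to the server
        sent += chunk.size
      }
      conn.commit()
      sent
    } catch {
      case e: Exception =>
        conn.rollback() // do not leave a half-written transaction open
        throw e
    } finally {
      ps.close()
      conn.setAutoCommit(previousAutoCommit)
    }
  }
}
```

Inside `foreachPartition`, the body would then shrink to something like `ChunkedBatchInsert.insertInChunks(conn, sql, partitionOfRecords, 1000) { (ps, data) => ps.setString(1, data._1); ... }`, where 1000 is an arbitrary chunk size. Returning the connection to the pool is still the caller's job. Unlike the original, this sketch rethrows after rolling back, so Spark sees the task fail and can retry it.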
3. Bulk Copy Insert
We use PostgreSQL, whose JDBC driver provides a Copy Insert API. The main steps are:
- 1. Get a database connection
- 2. Create a CopyManager
- 3. Wrap the streaming data from Spark Streaming in an InputStream
- 4. Execute the Copy Insert
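Step 3 is where the data format matters. The code below joins fields with bare commas. That is only safe because the fields come from splitting on commas and are assumed to contain no quotes or line breaks. If your values can contain those characters, here is a minimal sketch of how rows could be rendered in PostgreSQL's CSV format and wrapped in an InputStream. `CsvPayload` and its methods are hypothetical helpers and are not part of the PostgreSQL driver. The UTF-8 encoding is an assumption: it should match the connection's client encoding.

```scala
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.charset.StandardCharsets

// Hypothetical helper for building the payload of COPY ... FROM STDIN WITH CSV
object CsvPayload {
  // Quote every non-null field and double any embedded quotes. Quoting unconditionally keeps
  // commas, quotes and newlines inside a field intact. A null becomes an unquoted empty field.
  def field(value: String): String =
    if (value == null) "" else "\"" + value.replace("\"", "\"\"") + "\""

  // One CSV row terminated by a newline
  def line(values: Seq[String]): String = values.map(field).mkString(",") + "\n"

  // Step 3: wrap the rendered rows in an InputStream, encoded explicitly as UTF-8
  def toInputStream(rows: Seq[Seq[String]]): InputStream =
    new ByteArrayInputStream(rows.map(line).mkString.getBytes(StandardCharsets.UTF_8))
}
```

In PostgreSQL's CSV format an unquoted empty field is read as NULL and a quoted `""` as an empty string. Quoting every non-null value therefore keeps the two cases distinct.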
import java.io.{ByteArrayInputStream, IOException, InputStream}
import java.nio.charset.StandardCharsets
import java.sql.Connection
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.kafka010.ConsumerStrategies._
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
import org.slf4j.LoggerFactory

// ConnectionPool is the connection-pool helper from the previous article
object CopyInsert {
  val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {
    // Validate arguments
    if (args.length < 4) {
      System.err.println(
        s"""
           |Usage: CopyInsert <brokers> <topics> <duration> <batchsize>
           | <brokers> is a list of one or more Kafka brokers
           | <topics> is a list of one or more kafka topics to consume from
           | <duration> is the batch interval in seconds
           | <batchsize> is the number of rows sent per COPY
           |""".stripMargin)
      System.exit(1)
    }
    // Parse arguments
    val Array(brokers, topics, duration, batchsize) = args
    // Topics are separated by ","
    val topicSet: Set[String] = topics.split(",").toSet
    val kafkaParams: Map[String, Object] = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )
    // Create the context; each <duration>-second interval of data forms one batch
    val sparkConf = new SparkConf().setAppName("CopyInsertIntoPostgreSQL")
    val streamingContext = new StreamingContext(sparkConf, Seconds(duration.toInt))
    // 1. Create the input stream. Stream operations are based on DStream; InputDStream extends DStream
    val stream = KafkaUtils.createDirectStream[String, String](
      streamingContext,
      PreferConsistent,
      Subscribe[String, String](topicSet, kafkaParams)
    )
    // 2. Transformations on the DStream:
    // take the message value, split it on commas, and turn it into a Tuple3
    val values = stream.map(_.value.split(","))
      .filter(x => x.length == 3)
      .map(x => (x(0), x(1), x(2)))
    // Print the first 10 records to the console for debugging
    values.print()
    // 3. Save to the database via foreachRDD
    // http://rostislav-matl.blogspot.jp/2011/08/fast-inserts-to-postgresql-with-jdbc.html
    // Same table and column order as the INSERT used for batch commit
    val copySql = "COPY kafka_message (timeseq, timeseq2, thread, message) FROM STDIN WITH CSV"
    values.foreachRDD(rdd => {
      val count = rdd.count()
      println("-----------------count:" + count)
      if (count > 0) {
        rdd.foreachPartition(partitionOfRecords => {
          val start = System.currentTimeMillis()
          val conn: Connection = ConnectionPool.getConnection.orNull
          if (conn != null) {
            try {
              val batch = batchsize.toInt
              var counter: Int = 0
              val sb: StringBuilder = new StringBuilder()
              // Get the database connection as a PostgreSQL BaseConnection (via the metadata, which can bypass a pool wrapper)
              val baseConnection = conn.getMetaData.getConnection.asInstanceOf[BaseConnection]
              // Create the CopyManager
              val cpManager: CopyManager = new CopyManager(baseConnection)
              partitionOfRecords.foreach(record => {
                counter += 1
                // Plain comma-joined CSV row; assumes fields contain no quotes or newlines
                sb.append(record._1).append(",")
                  .append(System.currentTimeMillis()).append(",")
                  .append(record._2).append(",")
                  .append(record._3).append("\n")
                if (counter == batch) {
                  // Wrap the buffered rows in an InputStream (explicit UTF-8, not the platform default)
                  val in: InputStream = new ByteArrayInputStream(sb.toString().getBytes(StandardCharsets.UTF_8))
                  // Run COPY ... FROM STDIN
                  cpManager.copyIn(copySql, in)
                  println("-----------------batch---------------: " + batch)
                  counter = 0
                  sb.delete(0, sb.length)
                  closeInputStream(in)
                }
              })
              // Copy the remaining rows (fewer than batchsize) into the same table, if there are any
              if (sb.length > 0) {
                val lastIn: InputStream = new ByteArrayInputStream(sb.toString().getBytes(StandardCharsets.UTF_8))
                cpManager.copyIn(copySql, lastIn)
                sb.delete(0, sb.length)
                counter = 0
                closeInputStream(lastIn)
              }
            } finally {
              // Always return the connection to the pool
              ConnectionPool.closeConnection(conn)
            }
            val end = System.currentTimeMillis()
            println("-----------------duration---------------ms :" + (end - start))
          }
        })
      }
    })
    streamingContext.start()            // Start the computation
    streamingContext.awaitTermination() // Wait for the computation to terminate
  }

  def closeInputStream(in: InputStream): Unit = {
    try {
      in.close()
    } catch {
      case e: IOException =>
        logger.error("Error on close InputStream. " + e.getMessage)
    }
  }
}
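The counter-and-remainder bookkeeping in the loop above, including the separate flush after it, can also be expressed with Scala's `Iterator.grouped`. It yields full chunks followed by one final partial chunk and never an empty one. Below is a minimal sketch under that idea. `GroupedCopy`, `copyInBatches` and the `copy` callback are hypothetical names. In the article's setting, `copy` would be something like `in => cpManager.copyIn("COPY kafka_message FROM STDIN WITH CSV", in)`; CopyManager's `copyIn(String, InputStream)` returns the number of rows copied.

```scala
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.charset.StandardCharsets

object GroupedCopy {
  /**
   * Hypothetical helper: renders `records` with `toCsvLine`, hands them to `copy`
   * in chunks of at most `batchSize` rows, and returns the sum of what `copy` reports.
   */
  def copyInBatches[T](records: Iterator[T], batchSize: Int)
                      (toCsvLine: T => String)
                      (copy: InputStream => Long): Long = {
    require(batchSize > 0, "batchSize must be positive")
    var total = 0L
    // No separate "flush the remainder" step: the last chunk may simply be shorter
    records.grouped(batchSize).foreach { chunk =>
      val payload = chunk.map(r => toCsvLine(r) + "\n").mkString
      val in = new ByteArrayInputStream(payload.getBytes(StandardCharsets.UTF_8))
      try total += copy(in) finally in.close()
    }
    total
  }
}
```

Because the same `copy` function handles every chunk, the final partial batch cannot end up in a different table than the full ones. An empty partition never triggers a COPY at all.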
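Other databases probably offer a bulk-load path as well. MySQL, for example, has a setLocalInfileInputStream method in com.mysql.jdbc.Statement. It should work much like the Copy Insert above, but I have not yet written an example to verify it. The documentation describes it as follows, for reference: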
Sets an InputStream instance that will be used to send data to the MySQL server for a “LOAD DATA LOCAL INFILE” statement rather than a FileInputStream or URLInputStream that represents the path given as an argument to the statement.
(End)