-
-
Save bmatheny/2245792 to your computer and use it in GitHub Desktop.
MySQL JDBC connection pool for Scala + Finagle
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (C) 2012 Benoit Sigoure | |
// Copyright (C) 2012 StumbleUpon, Inc. | |
// This library is free software: you can redistribute it and/or modify it | |
// under the terms of the GNU Lesser General Public License as published by | |
// the Free Software Foundation, either version 2.1 of the License, or (at your | |
// option) any later version. This program is distributed in the hope that it | |
// will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty | |
// of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser | |
// General Public License for more details. You should have received a copy | |
// of the GNU Lesser General Public License along with this program. If not, | |
// see <http://www.gnu.org/licenses/>. | |
package com.stumbleupon.backends | |
import java.sql.Connection | |
import java.sql.DriverManager | |
import java.sql.PreparedStatement | |
import java.sql.ResultSet | |
import java.sql.SQLDataException | |
import java.sql.SQLFeatureNotSupportedException | |
import java.sql.SQLIntegrityConstraintViolationException | |
import java.sql.SQLNonTransientException | |
import java.sql.SQLRecoverableException | |
import java.sql.SQLSyntaxErrorException | |
import java.util.concurrent.ArrayBlockingQueue | |
import java.util.concurrent.Executors | |
import java.util.concurrent.ThreadFactory | |
import java.util.concurrent.atomic.AtomicInteger | |
import java.util.concurrent.TimeUnit.MILLISECONDS | |
import scala.collection.mutable.ArrayBuffer | |
import org.slf4j.LoggerFactory | |
import com.mysql.jdbc.log.Log | |
import com.twitter.conversions.time._ | |
import com.twitter.util.Duration | |
import com.twitter.util.Future | |
import com.twitter.util.FuturePool | |
import com.stumbleupon.common.Counter // This is like an AtomicLong except based on jsr166e.LongAdder | |
/**
 * Configuration for a connection pool.
 * @param servers A list of "ip:port".
 * @param user MySQL username to connect as.
 * @param pass MySQL password for that user.
 * @param schema Name of the DB schema to use.
 */
final case class PoolConfig(servers: Seq[String], user: String, pass: String, schema: String) {
  /** Like `equals' but ignores the order in `servers' in case they were shuffled. */
  def equivalentTo(other: PoolConfig): Boolean =
    user == other.user && pass == other.pass && schema == other.schema &&
      servers.sorted == other.servers.sorted

  /**
   * Same layout as the compiler-generated case-class `toString', except that
   * the password is redacted so that logging a configuration (or anything
   * that embeds one) doesn't leak credentials.
   */
  override def toString: String =
    "PoolConfig(" + servers + "," + user + ",<redacted>," + schema + ")"
}
/**
 * Pairs a JDBC connection with the "host:port" of the server it talks to.
 * JDBC offers no reliable way to recover the server address from a
 * `Connection' object, yet we need that address to reconnect after a
 * failure, so every connection is wrapped together with it.
 * @param server A "host:port" string.
 * @param connection The underlying JDBC connection.
 */
final case class MySQLConnection(server: String, connection: Connection) {
  /** Prepares `sql' on the underlying JDBC connection. */
  def prepareStatement(sql: String): PreparedStatement =
    connection.prepareStatement(sql)

  /** Closes the underlying JDBC connection. */
  def close: Unit = connection.close()
}
/**
 * A connection pool for asynchronous operations.
 * For each connection, there is a dedicated thread, because MySQL doesn't
 * have an asynchronous RPC protocol, and because JDBC doesn't have an
 * asynchronous API.
 * The thread pool is sized from the initial configuration and is not resized
 * by `updateConfig', so after a reconfiguration the number of threads and the
 * number of connections may differ.
 * @param cfg Configuration for this connection pool.
 * @param options A "query string" appended as-is to the JDBC URL, including
 * its leading "?" (see `ConnectionPool.jdbcOptions').
 * @param readonly Whether or not to set the connection read-only mode.
 * @param appName Name of the current app (e.g. "honeybadger"); it is
 * prepended as a SQL comment to every query.
 */
final class ConnectionPool(cfg: PoolConfig,
                           options: String,
                           val readonly: Boolean,
                           appName: String) {
  import ConnectionPool._

  ensureDriverLoaded

  @volatile private[this] var conf = cfg

  private[this] val pool = makePool(cfg.servers.length)

  /**
   * How long (in ms) a query waits for a free connection before giving up.
   * With one thread per connection a connection is normally free right away,
   * but after `updateConfig' reduces the number of servers there are more
   * threads than connections, and threads must queue for one.
   */
  private[this] val connectionWaitMillis = 2000L

  @volatile private[this] var connections: ArrayBlockingQueue[MySQLConnection] = _

  createConnections()  // Populates `connections'.

  /** Returns the current configuration of this pool. */
  def config: PoolConfig = conf

  /**
   * Attempts to apply the new configuration given to this pool.
   * Changes are applied atomically without disrupting ongoing traffic.
   * If successful, this closes all the connections and replaces them all with
   * new connections.
   * If there's an exception thrown, changes are rolled back first and both
   * the configuration and the connection pool will remain unchanged.
   * Note that the size of the thread pool is not changed.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to return.
   * @param newcfg The new configuration to apply to this pool.  The
   * configuration is assumed to be sane.
   * @throws SQLException if something bad happens (e.g. being unable to open
   * a connection to any one of the hosts for whatever reason).
   */
  def updateConfig(newcfg: PoolConfig) {
    // Almost everything we do is thread-safe but in order to guarantee that
    // we can correctly rollback the changes in case of an exception, and in
    // order to ensure that we only attempt to apply one change at a time,
    // it's much safer and easier to make this entire method synchronized.
    synchronized {
      val prevconns = connections  // volatile-read
      val prevcfg = conf  // volatile-read
      try {
        conf = newcfg
        createConnections()  // volatile-write on connections
        // Success!  Now dispose of the previous connections, to not leak them.
        try {
          closeAllConnections(prevcfg, prevconns)
        } catch {
          case e: Exception =>
            log.warn("Uncaught exception while closing an old connection after"
                     + " reloading a new configuration", e)
        }
      } catch {
        case e: Exception =>
          // Roll-back.
          connections = prevconns  // volatile-write
          conf = prevcfg  // volatile-write
          throw e
      }
    }
  }

  /**
   * Opens one connection per configured server and publishes them as the
   * pool's connections.
   * If any connection fails to open, the connections already opened by this
   * call are closed before the exception is rethrown, and `connections' is
   * left untouched.
   */
  private def createConnections() {
    val servers = conf.servers  // volatile-read
    val newconns = new ArrayBlockingQueue[MySQLConnection](servers.length)
    try {
      servers foreach { server =>  // server is already "ip:port".
        newconns.add(newConnection(server))
      }
    } catch {
      case e: Exception =>
        // Don't leak the connections we managed to open before the failure.
        var opened = newconns.poll
        while (opened != null) {
          try {
            opened.close
          } catch {
            case closeError: Exception =>
              log.warn("Failed to close connection to " + opened.server
                       + " after failing to open the connection pool", closeError)
          }
          opened = newconns.poll
        }
        throw e
    }
    connections = newconns  // commit: volatile-write
  }

  /** How many queries did we send to MySQL. */
  private[this] val queries = new Counter
  /** How many exceptions we got from JDBC. */
  private[this] val exceptions = new Counter

  /** Returns the number of queries sent to MySQL. */
  def queriesSent: Long = queries.get

  /** Returns the number of exceptions caught while talking to MySQL. */
  def exceptionsCaught: Long = exceptions.get

  /** Closes all connections and releases all threads. */
  def shutdown() {
    pool.executor.shutdown()
    // Synchronized so that `conf' and `connections' are read consistently
    // and we don't race with an in-progress updateConfig().
    synchronized {
      closeAllConnections(conf, connections)
    }
  }

  /**
   * Executes a SELECT statement on the database.
   * @param f Function called on each row returned by the database.  This
   * function is called with the connection locked, so if this function takes
   * time, it will prevent the connection from being reused for another query.
   * @param sql The SQL statement, e.g. "SELECT foo FROM t WHERE id = ?"
   * @param params The parameters to substitute in the `?' placeholders.
   * These parameters don't need to be escaped as prepared statements are used
   * and they already prevent SQL injections.  A `null' parameter is bound as
   * SQL NULL.
   * @return A future sequence of things returned by `f'.
   * @throws SQLException (async) if something bad happens (sorry I don't know more).
   */
  def select[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, "/*" + appName + "*/ " + sql, params))
  }

  // TODO(tsuna): Provide code for insert, update etc, not just select.
  private def execute[T](f: ResultSet => T, sql: String, params: Seq[Any]): Seq[T] = {
    queries.increment
    val connpool = this.connections  // volatile-read
    // Wait a little rather than failing immediately: there may be more
    // threads than connections after updateConfig() shrank the server list.
    var connection = connpool.poll(connectionWaitMillis, MILLISECONDS)
    if (connection == null) {
      // Either every connection stayed busy for the whole wait, or a thread
      // is leaking a connection, which would be really bad.
      val e = new IllegalStateException("WTF? Couldn't get a connection from the pool.")
      exceptions.increment
      log.error(e.getMessage)
      throw e
    }
    try {
      val statement = connection.prepareStatement(sql)
      try {
        bindParameters(statement, params)
        if (log.isDebugEnabled)
          log.debug(connection.server + ": " + sql
                    + " " + params.mkString("(", ", ", ")"))
        val rs = statement.executeQuery
        try {
          val results = new ArrayBuffer[T]
          while (rs.next) {
            results += f(rs)
          }
          results
        } finally {
          rs.close
        }
      } finally {
        statement.close
      }
    } catch {
      case e: SQLSyntaxErrorException =>
        logAndRethrow(connection, "Syntax error in SQL query",
                      sql, params, e)
      case e: SQLIntegrityConstraintViolationException =>
        logAndRethrow(connection, "Integrity constraint violated by SQL query",
                      sql, params, e)
      case e: SQLFeatureNotSupportedException =>
        logAndRethrow(connection, "Feature not supported in SQL query",
                      sql, params, e)
      case e: SQLDataException =>
        logAndRethrow(connection, "Data exception caused by SQL query",
                      sql, params, e)
      case e @ (_: SQLRecoverableException | _: SQLNonTransientException) =>
        // The remaining kinds of SQLNonTransientException are typically
        // connection-level problems, so let's close this connection and get a
        // new one.
        // For a SQLRecoverableException the JDK javadoc manual says that "the
        // recovery operation must include closing the current connection and
        // getting a new connection".
        // Neither a failed close nor a failed reconnect may mask `e'.
        try {
          connection.close  // If we double-close it's OK, it's a no-op.
        } catch {
          case closeError: Exception =>
            log.warn(connection.server + ": error while closing a broken connection",
                     closeError)
        }
        try {
          // The `finally' block below will put the new connection in the pool.
          connection = newConnection(connection.server)
        } catch {
          case reconnectError: Exception =>
            // Keep the old (closed) connection so that the `finally' block
            // still returns a connection to the pool; presumably the next
            // query using it fails at the connection level and retries the
            // reconnection -- TODO confirm against the driver's behavior.
            log.warn(connection.server + ": failed to reconnect", reconnectError)
        }
        // TODO(tsuna): If we wanted we could retry once here.
        logAndRethrow(connection, "Error on connection when trying to execute",
                      sql, params, e)
      case e: Throwable =>
        // TODO(tsuna): Should we close the connection here?  I'm not sure.
        logAndRethrow(connection, "Uncaught exception", sql, params, e)
    } finally {
      // Always return the connection to the pool.
      connpool.put(connection)
    }
  }

  /**
   * Logs an exception, along with the class and message of every exception
   * in its cause chain, then rethrows it.
   */
  private def logAndRethrow(connection: MySQLConnection, msg: String,
                            sql: String, params: Seq[Any], e: Throwable): Nothing = {
    // This function must never throw a new exception of its own.
    exceptions.increment
    val cause = new StringBuilder
    var exception = e
    // Get names & messages of all exceptions in the chain.
    while (exception != null) {
      cause.append(", caused by ")
        .append(exception.getClass.getName)
        .append(": ")
        .append(exception.getMessage)
      exception = exception.getCause  // previous exception causing this one.
    }
    log.error(connection.server + ": " + msg + ": " + sql
              + " with params " + params.mkString("(", ", ", ")")
              + cause)
    throw e
  }

  /** Binds `params' to the `?' placeholders of `statement', starting at 1. */
  private def bindParameters(statement: PreparedStatement,
                             params: TraversableOnce[Any]) {
    bindParameters(statement, 1, params)
  }

  /**
   * Binds `params' starting at placeholder `startIndex'.  Collections and
   * tuples/products are flattened into consecutive placeholders.
   * @return The index of the next placeholder to bind.
   * @throws IllegalArgumentException if a parameter has an unsupported type.
   */
  private def bindParameters(statement: PreparedStatement,
                             startIndex: Int,
                             params: TraversableOnce[Any]): Int = {
    var index = startIndex
    for (param <- params) {
      param match {
        case null => statement.setNull(index, java.sql.Types.NULL)
        case i: Int => statement.setInt(index, i)
        case l: Long => statement.setLong(index, l)
        case s: String => statement.setString(index, s)
        case l: TraversableOnce[_] =>
          index = bindParameters(statement, index, l) - 1
        case p: Product =>
          index = bindParameters(statement, index, p.productIterator.toList) - 1
        case b: Array[Byte] => statement.setBytes(index, b)
        case b: Boolean => statement.setBoolean(index, b)
        case s: Short => statement.setShort(index, s)
        case f: Float => statement.setFloat(index, f)
        case d: Double => statement.setDouble(index, d)
        case _ =>
          throw new IllegalArgumentException("Unsupported data type "
            + param.asInstanceOf[AnyRef].getClass.getName + ": " + param)
      }
      index += 1
    }
    index
  }

  /**
   * Returns a new MySQL connection, closing it again if it can't be
   * configured.
   * @param server A "host:port" string.
   */
  private def newConnection(server: String): MySQLConnection = {
    val current = conf  // volatile-read, so all fields come from one config.
    val connection =
      DriverManager.getConnection("jdbc:mysql://" + server + "/" + current.schema + options,
                                  current.user, current.pass)
    try {
      connection.setReadOnly(readonly)
    } catch {
      case e: Exception =>
        try {
          connection.close()
        } catch {
          case closeError: Exception =>
            log.warn("Failed to close connection to " + server, closeError)
        }
        throw e
    }
    MySQLConnection(server, connection)
  }

  /** Describes this pool; the password is redacted so it never gets logged. */
  override def toString = "ConnectionPool(" + conf.copy(pass = "<redacted>") + ")"
}
object ConnectionPool {
  private val log = LoggerFactory.getLogger(getClass)

  /**
   * Makes sure the MySQL JDBC driver is registered with the JDBC
   * DriverManager.
   * Merely referencing `classOf[Driver]' loads the class but does not
   * initialize it (JLS 12.4.1), so the driver's static initializer, which
   * registers it with DriverManager, would not necessarily run.
   * `Class.forName' with `initialize = true' forces that initialization.
   * Going through `classOf' keeps a compile-time dependency on the driver.
   * @throws AssertionError if the MySQL JDBC connector is not on the classpath.
   */
  private def ensureDriverLoaded {
    try {
      val driver = classOf[com.mysql.jdbc.Driver]
      Class.forName(driver.getName, true, driver.getClassLoader)
    } catch {
      case e: NoClassDefFoundError =>
        val err = new AssertionError("MySQL JDBC connector missing.")
        err.initCause(e)
        throw err
    }
  }

  /** Default options we use to connect to MySQL */
  val jdbcOptions: String = "?" + {
    val options = Map(
      "connectTimeout" -> 4.seconds,
      "socketTimeout" -> 2.seconds,
      "useServerPrepStmts" -> true,
      "cachePrepStmts" -> true,
      "cacheResultSetMetadata" -> true,
      "cacheServerConfiguration" -> true,
      "logger" -> classOf[MySQLLogger]
    )
    options.toList.map { case (option, value) =>
      option + "=" + (value match {
        case c: Class[_] => c.getName
        case d: Duration => d.inMilliseconds
        case _ => value
      })
    } mkString "&"
  }

  def readOnly(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, true, appName)

  def readWrite(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, false, appName)

  /** Creates a thread-pool to use the given connections. */
  private def makePool(size: Int) = {
    val factory = new ThreadFactory {
      val id = new AtomicInteger(0)
      def newThread(r: Runnable) =
        new Thread(r, "MySQL-" + id.incrementAndGet)
    }
    FuturePool(Executors.newFixedThreadPool(size, factory))
  }

  /**
   * Closes all the connections from the given pool with the given config.
   * Closing is best-effort: a connection that fails to close is logged and
   * the remaining connections are still closed.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to clear up the pool.
   */
  private def closeAllConnections(conf: PoolConfig,
                                  connections: ArrayBlockingQueue[MySQLConnection]) {
    val total = conf.servers.length
    for (i <- 1 to total) {
      // We're not serving a query to an end-user, and our goal is to
      // close the connection but we don't want to wait forever in case
      // the connection is somehow badly stuck.  So allow quite a bit of
      // time to grab a connection.
      val connection = connections.poll(500, MILLISECONDS)
      if (connection == null) {
        log.error("Timeout while trying to get connection #" + i + " / "
                  + total + ", connection will be leaked.")
      } else {
        try {
          connection.close
        } catch {
          case e: Exception =>
            log.warn("Failed to close connection #" + i + " / " + total
                     + " to " + connection.server, e)
        }
      }
    }
  }
}
/**
 * Adapter that routes MySQL Connector/J's internal logging to SLF4J
 * (otherwise it goes to stderr by default).
 * FATAL messages have no SLF4J level, so they are logged at ERROR with a
 * "** FATAL ** " prefix.
 * @param name Name of the SLF4J logger to write to.
 */
private final class MySQLLogger(name: String) extends Log {
  val log = LoggerFactory.getLogger(name)

  def isDebugEnabled: Boolean = log.isDebugEnabled
  def isErrorEnabled: Boolean = log.isErrorEnabled
  def isFatalEnabled: Boolean = log.isErrorEnabled
  def isInfoEnabled: Boolean = log.isInfoEnabled
  def isTraceEnabled: Boolean = log.isTraceEnabled
  def isWarnEnabled: Boolean = log.isWarnEnabled

  /**
   * Renders a message object handed to us by the driver as a String.
   * The `Log' interface takes arbitrary objects, and the driver may pass
   * non-String ones (e.g. profiler events).  A logger must not throw inside
   * the driver, so anything is accepted: `null' renders as "null" and other
   * objects via `toString'.
   */
  private def render(msg: Any): String =
    msg match {
      case m: String => m
      case null => "null"
      case other => other.toString
    }

  def logDebug(msg: Any) {
    log.debug(render(msg))
  }

  def logDebug(msg: Any, e: Throwable) {
    log.debug(render(msg), e)
  }

  def logError(msg: Any) {
    log.error(render(msg))
  }

  def logError(msg: Any, e: Throwable) {
    log.error(render(msg), e)
  }

  def logFatal(msg: Any) {
    log.error("** FATAL ** " + render(msg))  // Keep going anyway.
  }

  def logFatal(msg: Any, e: Throwable) {
    log.error("** FATAL ** " + render(msg), e)  // Keep going anyway.
  }

  def logInfo(msg: Any) {
    log.info(render(msg))
  }

  def logInfo(msg: Any, e: Throwable) {
    log.info(render(msg), e)
  }

  def logTrace(msg: Any) {
    log.trace(render(msg))
  }

  def logTrace(msg: Any, e: Throwable) {
    log.trace(render(msg), e)
  }

  def logWarn(msg: Any) {
    log.warn(render(msg))
  }

  def logWarn(msg: Any, e: Throwable) {
    log.warn(render(msg), e)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment