Created
November 19, 2015 19:24
-
-
Save shashir/c6b768b660fdd73ceed6 to your computer and use it in GitHub Desktop.
Scala JDBC querying (useful for building a basic ORM)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql.Connection | |
import java.sql.DriverManager | |
import java.sql.ResultSet | |
import java.sql.Statement | |
import scala.annotation.tailrec | |
import com.typesafe.scalalogging.slf4j.Logger | |
import org.slf4j.LoggerFactory | |
/**
 * Abstract base class for downloading rows from MySQL in id-ranged batches.
 *
 * Subclasses supply the SQL for a half-open id range ([[queryString]]) and the
 * decoding of a single row ([[parseRow]]). Constructing an instance loads the MySQL
 * JDBC driver and opens a connection immediately; call [[close]] when finished.
 *
 * @param server address substituted into the JDBC URL `jdbc:mysql://<server>/`.
 * @param user database user name.
 * @param password database password.
 * @param retries how many times a failing batch query is retried before that batch
 *                is skipped.
 * @tparam T type of object to parse from the result set.
 */
abstract class MySQLDownloader[T](
  server: String,
  user: String,
  password: String,
  retries: Int
) {
  import MySQLDownloader._

  LOG.info("Loading MySQL driver...")
  Class.forName(JDBC_MYSQL_DRIVER)

  LOG.info("Connecting to database...")
  val connection: Connection = DriverManager.getConnection(
    JDBC_MYSQL_PREFIX_FORMAT.format(server),
    user,
    password
  )

  // Shared by every batch. Re-executing a JDBC Statement closes its previous ResultSet,
  // so each result set must be fully consumed before the next batch query is issued.
  val statement: Statement = connection.createStatement()

  /**
   * Download and parse all rows whose ids fall in `[minId, limitId)`, issuing one query
   * per `step`-wide id window.
   *
   * Batches run eagerly and in order, and the parsed rows are returned as a strict
   * collection. A batch whose query still fails after all retries is skipped, so its
   * rows are missing from the result.
   *
   * @param minId minimum id to download (inclusive).
   * @param limitId maximum id (exclusive) up to which to download.
   * @param step how many ids each query covers; must be positive.
   * @param agg unused; retained only for source compatibility with existing callers.
   * @return parsed rows of every successful batch, in batch order.
   * @throws IllegalArgumentException if `minId >= limitId` or `step <= 0`.
   */
  def apply(
    minId: Long,
    limitId: Long,
    step: Long,
    agg: Seq[T] = Seq()
  ): Seq[T] = {
    require(minId < limitId, "minId must be less than limitId")
    require(step > 0, "Id step must be positive")
    // Materialize eagerly. A lazy view would re-run every query each time the caller
    // traversed the result, and could run queries after close() had been called.
    (minId until limitId by step).iterator.flatMap { id: Long =>
      val upperId = math.min(limitId, id + step)
      LOG.info(s"Querying ids from $id to $upperId...")
      guaranteedExecution(statement, queryString(id, upperId), retries)
        .map(parseAndClose)
        .getOrElse(Seq.empty[T])
    }.toVector
  }

  /**
   * Produce query string.
   *
   * @param minId minimum id to download.
   * @param limitId maximum id (exclusive) up to which to download.
   * @return query string.
   */
  def queryString(minId: Long, limitId: Long): String

  /**
   * Convert result set to sequence of parsed rows.
   *
   * @param resultSet produced by query.
   * @return sequence of parsed rows, in result-set order.
   */
  def parseResultSet(resultSet: ResultSet): Seq[T] = this.recursivelyParseResultSet(
    resultSet,
    Seq.empty[T]
  )

  /**
   * Extract a single row from the result set. The cursor is already positioned on a
   * valid row when this is called.
   *
   * @param resultSet produced by query.
   * @return a row of data.
   */
  def parseRow(resultSet: ResultSet): T

  /**
   * Parse every row of `resultSet`, then close it even if parsing throws.
   *
   * Closing is safe because the shared statement invalidates this result set on the
   * next query anyway.
   */
  private def parseAndClose(resultSet: ResultSet): Seq[T] =
    try parseResultSet(resultSet)
    finally resultSet.close()

  /**
   * Use [[parseRow]] to extract all remaining rows from the result set.
   *
   * Rows are prepended to `agg` and reversed once at the end, so the result is in
   * result-set order.
   *
   * @param resultSet produced by query.
   * @param agg previously gathered rows, most recent first.
   * @return sequence of rows produced by data.
   */
  @tailrec
  private def recursivelyParseResultSet(resultSet: ResultSet, agg: Seq[T]): Seq[T] =
    if (resultSet.next()) {
      recursivelyParseResultSet(resultSet, parseRow(resultSet) +: agg)
    } else {
      LOG.info(s"Retrieved ${agg.size} rows.")
      agg.reverse
    }

  /**
   * Close the shared statement and the database connection.
   *
   * The connection is closed even if closing the statement fails.
   */
  def close(): Unit = {
    LOG.info("Closing database connection...")
    try statement.close()
    finally connection.close()
  }
}
/**
 * Companion holding the shared logger, the MySQL JDBC constants and the retrying query helper.
 */
object MySQLDownloader {
  import java.sql.SQLNonTransientException
  import scala.util.{Failure, Success, Try}

  // Drop everything from the first `$` in the runtime class name, so log lines are
  // attributed to `MySQLDownloader` rather than `MySQLDownloader$`.
  val LOG = Logger(LoggerFactory.getLogger(this.getClass.getName.split("\\$")(0)))
  val JDBC_MYSQL_DRIVER: String = "com.mysql.jdbc.Driver"
  // `%s` is replaced by the server address. The URL names no database, so queries must
  // qualify table names with their schema.
  val JDBC_MYSQL_PREFIX_FORMAT: String = "jdbc:mysql://%s/?autoReconnect=true"

  /**
   * Execute `query` on `statement`, retrying after a failure up to `retries` more times.
   *
   * Only non-fatal errors are caught (see `scala.util.Try`), so e.g.
   * `InterruptedException` propagates instead of being retried.
   *
   * `SQLNonTransientException` failures, such as syntax or permission errors, are not
   * retried, because repeating the same query cannot succeed.
   *
   * @param statement SQL statement context.
   * @param query to execute.
   * @param retries number of retries after the first attempt.
   * @return the result set, or `None` if the query could not be executed.
   */
  @tailrec def guaranteedExecution(
    statement: Statement,
    query: String,
    retries: Int = Int.MaxValue
  ): Option[ResultSet] =
    Try(statement.executeQuery(query)) match {
      case Success(resultSet) =>
        Some(resultSet)
      case Failure(exception: SQLNonTransientException) =>
        LOG.warn(s"Execution of query failed with a non-transient error; skipping: $query", exception)
        None
      case Failure(exception) if retries > 0 =>
        LOG.warn(s"Execution of query failed: $query", exception)
        LOG.info(s"Retrying... ($retries retries left)")
        guaranteedExecution(statement, query, retries - 1)
      case Failure(exception) =>
        LOG.warn(s"Execution of query failed: $query", exception)
        LOG.info("No retries left. Skipping query.")
        None
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment