Skip to content

Instantly share code, notes, and snippets.

@julik
Last active April 28, 2025 09:31
Show Gist options
  • Save julik/69066f5a819ac3b38480d42c1351f8ef to your computer and use it in GitHub Desktop.
Save julik/69066f5a819ac3b38480d42c1351f8ef to your computer and use it in GitHub Desktop.
require "digest"
require "rack"
# This class encapsulates a unit of work done for a particular tenant, connected to that tenant's database.
# ActiveRecord makes it _very_ hard to do in a simple manner and clever stuff is required, but it is knowable.
#
# What this class provides is a "misuse" of the database "roles" of ActiveRecord to have a role per tenant.
# If all the tenants are predefined, it can be done roughly so:
#
# ActiveRecord::Base.legacy_connection_handling = false if ActiveRecord::Base.respond_to?(:legacy_connection_handling)
# $databases.each_pair do |n, db_path|
# config_hash = {
# "adapter" => 'sqlite3',
# "database" => db_path,
# "pool" => 4
# }
# ActiveRecord::Base.connection_handler.establish_connection(config_hash, role: "database_#{n}")
# end
#
# def named_databases_as_roles_using_connected_to(n, from_database_paths)
# ActiveRecord::Base.connected_to(role: "database_#{n}") do
# query_and_compare!(n)
# end
# end
#
# So what we do is this:
#
# * We want one connection pool per tenant (per database, thus)
# * We want to grab a connection from that pool and make sure our queries use that connection
# * Once we are done with our unit of work we want to return the connection to the pool
#
# This also uses a stack of Fibers because `connected_to` in ActiveRecord _wants_ to have a block, but for us
# "leaving" the context of a unit of work can happen in a Rack body close() call.
# A unit of work bound to one tenant's database. Each tenant gets its own
# ActiveRecord "role" (named after the configured database), and with it its
# own connection pool, created the first time it is needed.
#
# Use either the block form (`with`) or the paired `enter!` / `leave!` calls
# when the end of the unit of work cannot be expressed as the end of a block.
class Shardine
  # Rack middleware that enters the tenant context before calling the
  # downstream app and leaves it once the response body is closed.
  class Middleware
    # @param app [#call] the downstream Rack application
    # @yieldparam env [Hash] the Rack environment of the current request
    # @yieldreturn [Hash] connection configuration for the tenant serving this request
    def initialize(app, &database_config_lookup)
      @app = app
      @lookup = database_config_lookup
    end

    # Runs the request inside the tenant context. On success, leaving the
    # context is deferred to the `close` of the returned body. On any other
    # exit from the downstream app - an exception of any class, or a `throw`
    # travelling through the middleware stack - the context is left right
    # away so the thread does not stay bound to this tenant.
    #
    # @param env [Hash] the Rack environment
    # @return [Array] the Rack response triplet, its body wrapped in a Rack::BodyProxy
    def call(env)
      switcher = Shardine.new(connection_config_hash: @lookup.call(env))
      switcher.enter!
      handed_over = false
      begin
        status, headers, body = @app.call(env)
        response = [status, headers, Rack::BodyProxy.new(body) { switcher.leave! }]
        # From here on the body proxy is responsible for calling leave!
        handed_over = true
        response
      ensure
        # `ensure` (rather than `rescue`) also covers `throw` and non-StandardError exits
        switcher.leave! unless handed_over
      end
    end
  end

  # Serializes the check-then-establish sequence in create_pool_if_none!
  CONNECTION_MANAGEMENT_MUTEX = Mutex.new

  # @param connection_config_hash [#to_h] connection configuration; must contain
  #   a `database` entry (string or symbol key), which is used to derive the role name
  # @raise [ArgumentError] if ActiveRecord legacy connection handling is enabled
  # @raise [KeyError] if the configuration has no `database` entry
  def initialize(connection_config_hash:)
    if ActiveRecord::Base.respond_to?(:legacy_connection_handling) && ActiveRecord::Base.legacy_connection_handling
      raise ArgumentError, "ActiveRecord::Base.legacy_connection_handling is enabled (set to `true`) and we can't use roles that way."
    end
    @config = connection_config_hash.to_h.with_indifferent_access
    @role_name = "shardine_#{@config.fetch(:database)}"
    # Suspended fibers, one per successful enter!, innermost last
    @fibers = []
  end

  # Runs the block with ActiveRecord connected to the tenant's role.
  #
  # @return whatever `ActiveRecord::Base.connected_to` returns for the block
  def with(&blk)
    create_pool_if_none!
    ActiveRecord::Base.connected_to(role: @role_name, &blk)
  end

  # Enters the tenant context without a block. `connected_to` insists on a
  # block, so the block is opened inside a Fiber which is then suspended;
  # `leave!` resumes it to let the block finish.
  #
  # NOTE(review): this presumes ActiveRecord tracks the `connected_to` stack
  # per thread rather than per fiber - confirm for the Rails version and
  # isolation settings in use.
  #
  # @return [true]
  # @raise whatever pool creation or `connected_to` raises; in that case
  #   nothing is recorded and a later `leave!` is a no-op
  def enter!
    fiber = Fiber.new do
      create_pool_if_none!
      ActiveRecord::Base.connected_to(role: @role_name) do
        Fiber.yield
      end
    end
    fiber.resume
    # Record the fiber only once it is actually suspended inside connected_to,
    # so leave! never tries to resume a fiber that already terminated.
    @fibers.push(fiber)
    true
  end

  # Leaves the most recently entered context of this instance. Does nothing
  # when there is no entered context. Must run on the thread that called
  # `enter!`, since Ruby does not allow resuming a fiber from another thread.
  #
  # @return [true]
  def leave!
    @fibers.pop&.resume
    true
  end

  # Establishes the connection pool for this tenant's role unless the
  # connection handler already lists one for it.
  def create_pool_if_none!
    # Create a connection pool for that tenant if it doesn't exist
    CONNECTION_MANAGEMENT_MUTEX.synchronize do
      handler = ActiveRecord::Base.connection_handler
      if handler.connection_pool_list(@role_name).none?
        handler.establish_connection(@config, role: @role_name)
      end
    end
  end
end
# Use it like so:
# use Shardine::Middleware do |env|
# site_name = env["SERVER_NAME"]
# {adapter: "sqlite3", database: "sites/#{site_name}.sqlite3"}
# end
require "bundler"
Bundler.setup
require "logger"
require "active_record"
require "minitest"
require "minitest/autorun"
require "sqlite3"
require_relative "../lib/shardine"
# Exercises Shardine against a set of throwaway SQLite databases. Database
# number `n` holds exactly `n` rows in `some_values`, so counting the rows
# tells which database a query actually went to.
class ShardineTest < Minitest::Test
  N_THREADS = 12
  N_DATABASES = 16

  # Creates the test databases and builds `@databases`, a Hash mapping the
  # row count of each database to its file path.
  def setup
    @test_dir = "shardine-#{Process.pid}-#{Minitest.seed}"
    FileUtils.mkdir_p(@test_dir)
    # Set up the test databases (without using ActiveRecord)
    N_DATABASES.times do |n|
      file = File.join(@test_dir, "#{n}.sqlite3")
      SQLite3::Database.open(file) do |db|
        db.execute("CREATE TABLE some_values (id INTEGER PRIMARY KEY AUTOINCREMENT, val INTEGER)")
        n.times do
          db.execute("INSERT INTO some_values (val) VALUES (?)", [n])
        end
      end
    end
    @databases = Dir.glob(File.join(@test_dir, "*.sqlite3")).sort.map do |path|
      n = SQLite3::Database.open(path) do |db|
        db.get_first_value("SELECT COUNT(*) FROM some_values")
      end
      [n, path]
    end.to_h
  end

  # Model used to query whichever database is currently selected.
  class SomeValue < ActiveRecord::Base
    self.table_name = "some_values"
  end

  def teardown
    FileUtils.rm_rf(@test_dir)
  end

  # Shardine must refuse to be constructed while legacy connection handling is on.
  def test_fails_with_legacy_connection_handling
    # This is only relevant with Rails 6
    skip unless ActiveRecord::Base.respond_to?(:legacy_connection_handling=)
    previous = ActiveRecord::Base.legacy_connection_handling
    ActiveRecord::Base.legacy_connection_handling = true
    assert_raises(ArgumentError) do
      Shardine.new(connection_config_hash: sqlite_config_for(0))
    end
  ensure
    # Put the global flag back so this test does not leak state into others.
    # `previous` is still nil when the test was skipped before reading it.
    ActiveRecord::Base.legacy_connection_handling = previous unless previous.nil?
  end

  # Switching between randomly chosen databases one after another, block form.
  def test_sequential_switching
    disable_legacy_connection_handling!
    rng = Random.new(Minitest.seed)
    16.times do
      n = @databases.keys.sample(random: rng)
      ctx = Shardine.new(connection_config_hash: sqlite_config_for(n))
      ctx.with do
        assert_correct_database_used(n)
      end
    end
  end

  # Nested enter!/leave! pairs on two contexts unwind in LIFO order, and no
  # connection is available once the outermost context has been left.
  def test_enter_and_leave
    disable_legacy_connection_handling!
    ctx1 = Shardine.new(connection_config_hash: sqlite_config_for(1))
    ctx2 = Shardine.new(connection_config_hash: sqlite_config_for(2))
    assert ctx1.enter!
    assert_correct_database_used(1)
    assert ctx2.enter!
    assert_correct_database_used(2)
    assert ctx2.leave!
    assert_correct_database_used(1)
    assert ctx1.leave!
    assert_raises(ActiveRecord::ConnectionNotEstablished) do
      assert_correct_database_used(0)
    end
  end

  # The middleware selects the database for the downstream app and hands back
  # a body that can be closed.
  def test_middleware
    disable_legacy_connection_handling!
    rng = Random.new(Minitest.seed)
    app_called_n_times = 0
    8.times do
      n = @databases.keys.sample(random: rng)
      app = ->(_env) {
        app_called_n_times += 1
        assert_correct_database_used(n)
        [200, {}, ["Database #{n}"]]
      }
      middleware = Shardine::Middleware.new(app) do
        sqlite_config_for(n)
      end
      _status, _headers, body = middleware.call({})
      assert body.respond_to?(:close)
      body.close
    end
    assert_equal 8, app_called_n_times
  end

  # Many threads switching between databases at the same time.
  def test_threaded_switching
    disable_legacy_connection_handling!
    8.times do
      flow_iterations = 32
      threads = N_THREADS.times.map do |thread_index|
        Thread.new(thread_index) do |index|
          # Seed per thread: with a shared seed every thread would pick the
          # very same sequence of databases. Still reproducible per test run.
          rng = Random.new(Minitest.seed + index)
          flow_iterations.times do
            n = @databases.keys.sample(random: rng)
            ctx = Shardine.new(connection_config_hash: sqlite_config_for(n))
            ctx.with do
              assert_correct_database_used(n)
            end
          end
          :ok
        end
      end
      # Thread#value joins the thread and re-raises anything it raised
      values = threads.map(&:value)
      assert_equal [:ok], values.uniq
    end
  end

  def disable_legacy_connection_handling!
    ActiveRecord::Base.legacy_connection_handling = false if ActiveRecord::Base.respond_to?(:legacy_connection_handling=)
  end

  # Asserts that the currently selected database is the one holding `n` rows.
  def assert_correct_database_used(n)
    num_rows = SomeValue.count
    assert_equal n, num_rows, "Mismatch: expected to have queried DB #{n} but we queried #{num_rows} instead"
  end

  # Connection configuration for test database number `n`.
  def sqlite_config_for(n)
    {
      "adapter" => "sqlite3",
      "database" => @databases.fetch(n),
      # Needs to be set because the pools get reused between tests (the role
      # name depends on the database path only), and `test_threaded_switching`
      # needs a connection per thread
      "pool" => N_THREADS + 1
    }
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment