Last active
April 28, 2025 09:31
-
-
Save julik/69066f5a819ac3b38480d42c1351f8ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "digest" | |
require "rack" | |
# Shardine performs a unit of work for one tenant while connected to that tenant's database.
# ActiveRecord makes this _very_ hard to do in a simple manner and clever stuff is required, but it is doable.
#
# The trick is a "misuse" of ActiveRecord database roles: every tenant gets a role of its own.
# If all the tenants are known up front, it can be done roughly like so:
#
#   ActiveRecord::Base.legacy_connection_handling = false if ActiveRecord::Base.respond_to?(:legacy_connection_handling)
#   $databases.each_pair do |n, db_path|
#     config_hash = {
#       "adapter" => 'sqlite3',
#       "database" => db_path,
#       "pool" => 4
#     }
#     ActiveRecord::Base.connection_handler.establish_connection(config_hash, role: "database_#{n}")
#   end
#
#   def named_databases_as_roles_using_connected_to(n, from_database_paths)
#     ActiveRecord::Base.connected_to(role: "database_#{n}") do
#       query_and_compare!(n)
#     end
#   end
#
# Shardine does the same thing lazily, per unit of work:
#
# * There is one connection pool per tenant (that is, per database). It is created on first use
#   under a role named after the config's `database` entry.
# * Queries made during the unit of work go through that tenant's pool.
# * Once the unit of work is over, the connection goes back to the pool.
#
# `connected_to` insists on a block. Under Rack, however, the unit of work only ends when the response
# body is closed, which is after `call` has already returned. To bridge that gap, {#enter!} starts
# `connected_to` inside a Fiber and suspends the Fiber in the middle of the block. {#leave!} resumes it
# so the block, and `connected_to`'s own cleanup, can finish. Consequences:
#
# * {#enter!} and {#leave!} must be called on the same thread, because Ruby refuses to resume a Fiber
#   from a thread other than the one that created it.
# * Nested units of work should be left in the reverse order they were entered; ActiveRecord
#   presumably keeps the active roles on a LIFO stack.
# * This relies on the role pushed inside the Fiber being visible to the code outside of it.
#   NOTE(review): that presumably requires ActiveSupport's execution-state isolation level to be
#   :thread rather than :fiber. Confirm before changing it.
class Shardine
  # Raised by {#enter!} when this instance is already entered and has not been left yet.
  class AlreadyEnteredError < StandardError; end

  # Rack middleware that wraps every request in a Shardine unit of work. The tenant's role stays
  # active until the response body is closed.
  class Middleware
    # @param app [#call] the downstream Rack application
    # @yieldparam env [Hash] the Rack environment of the current request
    # @yieldreturn [Hash] ActiveRecord connection config for that request's tenant (must have a database key)
    # @raise [ArgumentError] if no lookup block is given
    def initialize(app, &database_config_lookup)
      unless database_config_lookup
        raise ArgumentError, "Shardine::Middleware needs a block returning the database config for a Rack env"
      end

      @app = app
      @lookup = database_config_lookup
    end

    # @param env [Hash] Rack environment
    # @return [Array] Rack response whose body leaves the unit of work when closed
    def call(env)
      switcher = Shardine.new(connection_config_hash: @lookup.call(env))
      switcher.enter!
      handed_off = false
      begin
        status, headers, body = @app.call(env)
        # From here on, closing the body is what ends the unit of work.
        proxied_body = Rack::BodyProxy.new(body) { switcher.leave! }
        handed_off = true
        [status, headers, proxied_body]
      ensure
        # `ensure` (not `rescue`) so that exceptions outside StandardError also release the role.
        switcher.leave! unless handed_off
      end
    end
  end

  CONNECTION_MANAGEMENT_MUTEX = Mutex.new

  # @param connection_config_hash [#to_h] ActiveRecord connection config; must contain :database
  # @raise [ArgumentError] when ActiveRecord's legacy connection handling is enabled
  # @raise [KeyError] when the config has no database entry
  def initialize(connection_config_hash:)
    if ActiveRecord::Base.respond_to?(:legacy_connection_handling) && ActiveRecord::Base.legacy_connection_handling
      raise ArgumentError, "ActiveRecord::Base.legacy_connection_handling is enabled (set to `true`) and we can't use roles that way."
    end

    @config = connection_config_hash.to_h.with_indifferent_access
    @role_name = "shardine_#{@config.fetch(:database)}"
  end

  # Runs the block connected to this tenant's database and returns the block's value.
  def with(&blk)
    create_pool_if_none!
    ActiveRecord::Base.connected_to(role: @role_name, &blk)
  end

  # Activates this tenant's role until {#leave!} is called.
  #
  # @return [true]
  # @raise [AlreadyEnteredError] if called again before {#leave!}
  def enter!
    raise AlreadyEnteredError, "#{@role_name} is already entered; call leave! first" if @fiber

    fiber = Fiber.new do
      create_pool_if_none!
      ActiveRecord::Base.connected_to(role: @role_name) do
        # Park here, with the role active, until leave! resumes this fiber.
        Fiber.yield
      end
    end
    # If pool creation or connected_to raises, the exception surfaces from this resume. The dead
    # fiber is then dropped, so a later leave! has nothing to resume.
    fiber.resume
    @fiber = fiber if fiber.alive?
    true
  end

  # Ends the unit of work started by {#enter!}. Safe to call when not entered, or more than once.
  #
  # @return [true]
  def leave!
    suspended, @fiber = @fiber, nil
    suspended.resume if suspended&.alive?
    true
  end

  # Creates the connection pool for this tenant's role unless one already exists.
  def create_pool_if_none!
    # The mutex keeps two threads from both establishing a pool for the same role.
    CONNECTION_MANAGEMENT_MUTEX.synchronize do
      if ActiveRecord::Base.connection_handler.connection_pool_list(@role_name).none?
        ActiveRecord::Base.connection_handler.establish_connection(@config, role: @role_name)
      end
    end
  end
end
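# Use it like so (e.g. in config.ru):
#
#   use Shardine::Middleware do |env|
#     site_name = env["SERVER_NAME"]
#     {adapter: "sqlite3", database: "sites/#{site_name}.sqlite3"}
#   end
#
# NOTE: SERVER_NAME may be derived from the client-supplied Host header. Validate it against the
# list of known tenants before using it to build a file path, or a crafted Host header could point
# the app at an arbitrary SQLite file.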
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "bundler"
Bundler.setup
require "fileutils"
require "logger"
require "active_record"
require "minitest"
require "minitest/autorun"
require "sqlite3"
require_relative "../lib/shardine"
# Integration tests for Shardine against 16 throwaway SQLite databases. Database file N is filled
# with exactly N rows, and @databases maps row count => path. Counting rows therefore tells us which
# database a query actually hit.
class ShardineTest < Minitest::Test
  N_THREADS = 12

  def setup
    @test_dir = "shardine-#{Process.pid}-#{Minitest.seed}"
    FileUtils.mkdir_p(@test_dir)
    # Set up the test databases (without using ActiveRecord)
    16.times do |n|
      file = File.join(@test_dir, "#{n}.sqlite3")
      SQLite3::Database.open(file) do |db|
        db.execute("CREATE TABLE some_values (id INTEGER PRIMARY KEY AUTOINCREMENT, val INTEGER)")
        n.times do
          db.execute("INSERT INTO some_values (val) VALUES (?)", [n])
        end
      end
    end
    @databases = Dir.glob(File.join(@test_dir, "*.sqlite3")).sort.map do |path|
      n = SQLite3::Database.open(path) do |db|
        db.get_first_value("SELECT COUNT(*) FROM some_values")
      end
      [n, path]
    end.to_h
  end

  class SomeValue < ActiveRecord::Base
    self.table_name = "some_values"
  end

  # Only the files are removed here. Connection pools created by Shardine stay registered for the
  # rest of the process under the same role names, which is why every config pins "pool".
  def teardown
    FileUtils.rm_rf(@test_dir)
  end

  def test_fails_with_legacy_connection_handling
    # This is only relevant with Rails versions that still have the legacy handler
    skip unless ActiveRecord::Base.respond_to?(:legacy_connection_handling=)
    previous_setting = ActiveRecord::Base.legacy_connection_handling
    ActiveRecord::Base.legacy_connection_handling = true
    config = {
      "adapter" => 'sqlite3',
      "database" => @databases.fetch(0)
    }
    assert_raises(ArgumentError) do
      Shardine.new(connection_config_hash: config)
    end
  ensure
    # Restore the global flag so that test order cannot matter
    if ActiveRecord::Base.respond_to?(:legacy_connection_handling=)
      ActiveRecord::Base.legacy_connection_handling = previous_setting
    end
  end

  def test_sequential_switching
    disable_legacy_connection_handling!
    rng = Random.new(Minitest.seed)
    16.times do
      n = @databases.keys.sample(random: rng)
      config = {
        "adapter" => 'sqlite3',
        "database" => @databases.fetch(n),
        "pool" => N_THREADS + 1 # Needs to be set because these pools may get reused by `test_threaded_switching`, depending on test order
      }
      ctx = Shardine.new(connection_config_hash: config)
      ctx.with do
        assert_correct_database_used(n)
      end
    end
  end

  def test_enter_and_leave
    disable_legacy_connection_handling!
    config_1 = {
      "adapter" => 'sqlite3',
      "database" => @databases.fetch(1),
      "pool" => N_THREADS + 1
    }
    config_2 = {
      "adapter" => 'sqlite3',
      "database" => @databases.fetch(2),
      "pool" => N_THREADS + 1
    }
    ctx1 = Shardine.new(connection_config_hash: config_1)
    ctx2 = Shardine.new(connection_config_hash: config_2)
    assert ctx1.enter!
    assert_correct_database_used(1)
    assert ctx2.enter!
    assert_correct_database_used(2)
    assert ctx2.leave!
    assert_correct_database_used(1)
    assert ctx1.leave!
    assert_raises(ActiveRecord::ConnectionNotEstablished) do
      assert_correct_database_used(0)
    end
  ensure
    # If an assertion failed midway, don't leave suspended roles behind for the next test.
    # leave! is a no-op once a context has been left.
    ctx2&.leave!
    ctx1&.leave!
  end

  def test_middleware
    disable_legacy_connection_handling!
    rng = Random.new(Minitest.seed)
    app_called_n_times = 0
    8.times do
      n = @databases.keys.sample(random: rng)
      app = ->(_env) {
        app_called_n_times += 1
        assert_correct_database_used(n)
        [200, {}, ["Database #{n}"]]
      }
      middleware = Shardine::Middleware.new(app) do |_env|
        {
          "adapter" => 'sqlite3',
          "database" => @databases.fetch(n),
          "pool" => N_THREADS + 1 # Needs to be set because these pools may get reused by `test_threaded_switching`, depending on test order
        }
      end
      _status, _headers, body = middleware.call({})
      assert body.respond_to?(:close)
      body.close
      # Closing the body must end the unit of work, just like leave! does in test_enter_and_leave
      assert_raises(ActiveRecord::ConnectionNotEstablished) { SomeValue.count }
    end
    assert_equal 8, app_called_n_times
  end

  def test_threaded_switching
    disable_legacy_connection_handling!
    8.times do
      flow_iterations = 32
      threads = N_THREADS.times.map do |thread_index|
        Thread.new do
          # A distinct but reproducible sequence per thread, so threads actually hit different
          # databases at the same time instead of all walking the same sequence
          rng = Random.new(Minitest.seed + thread_index)
          flow_iterations.times do
            n = @databases.keys.sample(random: rng)
            config = {
              "adapter" => 'sqlite3',
              "database" => @databases.fetch(n),
              "pool" => N_THREADS + 1
            }
            ctx = Shardine.new(connection_config_hash: config)
            ctx.with do
              assert_correct_database_used(n)
            end
          end
          :ok
        end
      end
      # Thread#value joins and re-raises any assertion failure from the thread
      assert_equal [:ok], threads.map(&:value).uniq
    end
  end

  def disable_legacy_connection_handling!
    ActiveRecord::Base.legacy_connection_handling = false if ActiveRecord::Base.respond_to?(:legacy_connection_handling=)
  end

  def assert_correct_database_used(n)
    num_rows = SomeValue.count
    assert_equal n, num_rows, "Mismatch: expected to have queried DB #{n} but we queried #{num_rows} instead"
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment