A generic (Go 1.18) version of the code produced by the generator https://github.com/vektah/dataloaden: with type parameters, a single Loader implementation serves every key/result type, so no per-type code generation step is needed.
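For example, a loader for any model can be declared directly. A minimal, illustrative sketch, assuming it lives in the same storage package as the Loader below (with "time" imported); the Book type and its fake Fetch body are placeholders, not part of the gist:

type Book struct {
	ID    int
	Title string
}

func NewBookLoader() *Loader[int, Book] {
	return NewLoader(LoaderConfig[int, Book]{
		Fetch: func(ids []int) ([]*Book, []error) {
			// in real code this would be one batched query
			// (e.g. SELECT ... WHERE id IN (ids)) instead of len(ids) lookups
			books := make([]*Book, len(ids))
			errs := make([]error, len(ids))
			for i, id := range ids {
				books[i] = &Book{ID: id, Title: "book"} // placeholder data
			}
			return books, errs
		},
		Wait:     2 * time.Millisecond,
		MaxBatch: 100,
	})
}

// Usage: book, err := NewBookLoader().Load(42)
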
package storage

import (
	"sync"
	"time"
)

// LoaderConfig captures the config to create a new Loader
type LoaderConfig[Key comparable, Result any] struct {
	// Fetch is a method that provides the data for the loader
	Fetch func(keys []Key) ([]*Result, []error)

	// Wait is how long to wait before sending a batch
	Wait time.Duration

	// MaxBatch will limit the maximum number of keys to send in one batch, 0 = no limit
	MaxBatch int
}

// NewLoader creates a new Loader given a fetch, wait, and maxBatch
func NewLoader[Key comparable, Result any](config LoaderConfig[Key, Result]) *Loader[Key, Result] {
	return &Loader[Key, Result]{
		fetch:    config.Fetch,
		wait:     config.Wait,
		maxBatch: config.MaxBatch,
	}
}

// Loader batches and caches requests
type Loader[Key comparable, Result any] struct {
	// this method provides the data for the loader
	fetch func(keys []Key) ([]*Result, []error)

	// how long to wait before sending a batch
	wait time.Duration

	// this will limit the maximum number of keys to send in one batch, 0 = no limit
	maxBatch int

	// INTERNAL

	// lazily created cache
	cache map[Key]*Result

	// the current batch. keys will continue to be collected until the timeout is hit,
	// then everything will be sent to the fetch method and out to the listeners
	batch *loaderBatch[Key, Result]

	// mutex to prevent races
	mu sync.Mutex
}

type loaderBatch[Key comparable, Result any] struct {
	keys    []Key
	data    []*Result
	error   []error
	closing bool
	done    chan struct{}
}

// Load fetches a value by key; batching and caching will be applied automatically
func (l *Loader[Key, Result]) Load(key Key) (*Result, error) {
	return l.LoadThunk(key)()
}

// LoadThunk returns a function that when called will block waiting for a Result.
// This method should be used if you want one goroutine to make requests to many
// different data loaders without blocking until the thunk is called.
func (l *Loader[Key, Result]) LoadThunk(key Key) func() (*Result, error) {
	l.mu.Lock()
	if it, ok := l.cache[key]; ok {
		l.mu.Unlock()
		return func() (*Result, error) {
			return it, nil
		}
	}
	if l.batch == nil {
		l.batch = &loaderBatch[Key, Result]{done: make(chan struct{})}
	}
	batch := l.batch
	pos := batch.keyIndex(l, key)
	l.mu.Unlock()

	return func() (*Result, error) {
		<-batch.done

		var data *Result
		if pos < len(batch.data) {
			data = batch.data[pos]
		}

		var err error
		// it's convenient to be able to return a single error for everything
		if len(batch.error) == 1 {
			err = batch.error[0]
		} else if batch.error != nil {
			err = batch.error[pos]
		}

		if err == nil {
			l.mu.Lock()
			l.unsafeSet(key, data)
			l.mu.Unlock()
		}

		return data, err
	}
}

// LoadAll fetches many keys at once. It will be broken into appropriately sized
// sub-batches depending on how the loader is configured
func (l *Loader[Key, Result]) LoadAll(keys []Key) ([]*Result, []error) {
	thunks := make([]func() (*Result, error), len(keys))
	for i, key := range keys {
		thunks[i] = l.LoadThunk(key)
	}

	results := make([]*Result, len(keys))
	errors := make([]error, len(keys))
	for i, thunk := range thunks {
		results[i], errors[i] = thunk()
	}
	return results, errors
}

// LoadAllThunk returns a function that when called will block waiting for the Results.
// This method should be used if you want one goroutine to make requests to many
// different data loaders without blocking until the thunk is called.
func (l *Loader[Key, Result]) LoadAllThunk(keys []Key) func() ([]*Result, []error) {
	thunks := make([]func() (*Result, error), len(keys))
	for i, key := range keys {
		thunks[i] = l.LoadThunk(key)
	}
	return func() ([]*Result, []error) {
		results := make([]*Result, len(keys))
		errors := make([]error, len(keys))
		for i, thunk := range thunks {
			results[i], errors[i] = thunk()
		}
		return results, errors
	}
}

// Prime the cache with the provided key and value. If the key already exists, no change is made
// and false is returned.
// (To forcefully prime the cache, clear the key first with loader.Clear(key) and then call
// loader.Prime(key, value).)
func (l *Loader[Key, Result]) Prime(key Key, value *Result) bool {
	l.mu.Lock()
	var found bool
	if _, found = l.cache[key]; !found {
		// make a copy when writing to the cache; it's easy to pass a pointer in from a loop var
		// and end up with the whole cache pointing to the same value.
		cpy := *value
		l.unsafeSet(key, &cpy)
	}
	l.mu.Unlock()
	return !found
}

// Clear the value at key from the cache, if it exists
func (l *Loader[Key, Result]) Clear(key Key) {
	l.mu.Lock()
	delete(l.cache, key)
	l.mu.Unlock()
}

func (l *Loader[Key, Result]) unsafeSet(key Key, value *Result) {
	if l.cache == nil {
		l.cache = map[Key]*Result{}
	}
	l.cache[key] = value
}

// keyIndex will return the location of the key in the batch; if it is not found
// it will add the key to the batch
func (b *loaderBatch[Key, Result]) keyIndex(l *Loader[Key, Result], key Key) int {
	for i, existingKey := range b.keys {
		if key == existingKey {
			return i
		}
	}

	pos := len(b.keys)
	b.keys = append(b.keys, key)
	if pos == 0 {
		go b.startTimer(l)
	}

	if l.maxBatch != 0 && pos >= l.maxBatch-1 {
		if !b.closing {
			b.closing = true
			l.batch = nil
			go b.end(l)
		}
	}

	return pos
}

func (b *loaderBatch[Key, Result]) startTimer(l *Loader[Key, Result]) {
	time.Sleep(l.wait)
	l.mu.Lock()

	// we must have hit a batch limit and are already finalizing this batch
	if b.closing {
		l.mu.Unlock()
		return
	}

	l.batch = nil
	l.mu.Unlock()

	b.end(l)
}

func (b *loaderBatch[Key, Result]) end(l *Loader[Key, Result]) {
	b.data, b.error = l.fetch(b.keys)
	close(b.done)
}
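
Because a Loader caches every successful result for its own lifetime, the usual dataloader pattern is to create a fresh loader per incoming request, so batching and caching stay request-scoped. A minimal sketch of that pattern with net/http middleware; the context key, the function names, and the reuse of NewUserLoader and User (defined in the test file below, so they would need to live in non-test code) are illustrative assumptions, not part of the gist:

package storage

import (
	"context"
	"net/http"
)

type ctxKey int

const userLoaderKey ctxKey = 0 // hypothetical context key, not part of the gist

// UserLoaderMiddleware attaches a fresh loader to each request's context,
// so the cache cannot leak data between requests.
func UserLoaderMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), userLoaderKey, NewUserLoader())
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// UserLoaderFromContext retrieves the request-scoped loader inside a handler or resolver.
func UserLoaderFromContext(ctx context.Context) *Loader[string, User] {
	return ctx.Value(userLoaderKey).(*Loader[string, User])
}
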
package storage

import (
	"fmt"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// User is some kind of database backed model
type User struct {
	ID   string
	Name string
}

// NewUserLoader will collect user requests for 2 milliseconds and send them as a single batch
// to the Fetch func. Normally Fetch would be a database call.
func NewUserLoader() *Loader[string, User] {
	return NewLoader(LoaderConfig[string, User]{
		Fetch: func(keys []string) ([]*User, []error) {
			users := make([]*User, len(keys))
			errors := make([]error, len(keys))
			for i, key := range keys {
				users[i] = &User{ID: key, Name: "user " + key}
			}
			return users, errors
		},
		Wait:     2 * time.Millisecond,
		MaxBatch: 100,
	})
}

func TestUserLoader(t *testing.T) {
	var fetches [][]string
	var mu sync.Mutex

	dl := NewLoader(LoaderConfig[string, User]{
		Fetch: func(keys []string) ([]*User, []error) {
			mu.Lock()
			fetches = append(fetches, keys)
			mu.Unlock()

			users := make([]*User, len(keys))
			errors := make([]error, len(keys))
			for i, key := range keys {
				if strings.HasPrefix(key, "E") {
					errors[i] = fmt.Errorf("user not found")
				} else {
					users[i] = &User{ID: key, Name: "user " + key}
				}
			}
			return users, errors
		},
		Wait:     10 * time.Millisecond,
		MaxBatch: 5,
	})

t.Run("fetch concurrent data", func(t *testing.T) { | |
t.Run("load user successfully", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.Load("U1") | |
require.NoError(t, err) | |
require.Equal(t, u.ID, "U1") | |
}) | |
t.Run("load failed user", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.Load("E1") | |
require.Error(t, err) | |
require.Nil(t, u) | |
}) | |
t.Run("load many users", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.LoadAll([]string{"U2", "E2", "E3", "U4"}) | |
require.Equal(t, u[0].Name, "user U2") | |
require.Equal(t, u[3].Name, "user U4") | |
require.Error(t, err[1]) | |
require.Error(t, err[2]) | |
}) | |
t.Run("load thunk", func(t *testing.T) { | |
t.Parallel() | |
thunk1 := dl.LoadThunk("U5") | |
thunk2 := dl.LoadThunk("E5") | |
u1, err1 := thunk1() | |
require.NoError(t, err1) | |
require.Equal(t, "user U5", u1.Name) | |
u2, err2 := thunk2() | |
require.Error(t, err2) | |
require.Nil(t, u2) | |
}) | |
}) | |
t.Run("it sent two batches", func(t *testing.T) { | |
mu.Lock() | |
defer mu.Unlock() | |
require.Len(t, fetches, 2) | |
assert.Len(t, fetches[0], 5) | |
assert.Len(t, fetches[1], 3) | |
}) | |
t.Run("fetch more", func(t *testing.T) { | |
t.Run("previously cached", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.Load("U1") | |
require.NoError(t, err) | |
require.Equal(t, u.ID, "U1") | |
}) | |
t.Run("load many users", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.LoadAll([]string{"U2", "U4"}) | |
require.NoError(t, err[0]) | |
require.NoError(t, err[1]) | |
require.Equal(t, u[0].Name, "user U2") | |
require.Equal(t, u[1].Name, "user U4") | |
}) | |
}) | |
t.Run("no round trips", func(t *testing.T) { | |
mu.Lock() | |
defer mu.Unlock() | |
require.Len(t, fetches, 2) | |
}) | |
t.Run("fetch partial", func(t *testing.T) { | |
t.Run("errors not in cache cache value", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.Load("E2") | |
require.Nil(t, u) | |
require.Error(t, err) | |
}) | |
t.Run("load all", func(t *testing.T) { | |
t.Parallel() | |
u, err := dl.LoadAll([]string{"U1", "U4", "E1", "U9", "U5"}) | |
require.Equal(t, u[0].ID, "U1") | |
require.Equal(t, u[1].ID, "U4") | |
require.Error(t, err[2]) | |
require.Equal(t, u[3].ID, "U9") | |
require.Equal(t, u[4].ID, "U5") | |
}) | |
}) | |
t.Run("one partial trip", func(t *testing.T) { | |
mu.Lock() | |
defer mu.Unlock() | |
require.Len(t, fetches, 3) | |
require.Len(t, fetches[2], 3) // E1 U9 E2 in some random order | |
}) | |
t.Run("primed reads dont hit the fetcher", func(t *testing.T) { | |
dl.Prime("U99", &User{ID: "U99", Name: "Primed user"}) | |
u, err := dl.Load("U99") | |
require.NoError(t, err) | |
require.Equal(t, "Primed user", u.Name) | |
require.Len(t, fetches, 3) | |
}) | |
t.Run("priming in a loop is safe", func(t *testing.T) { | |
users := []User{ | |
{ID: "Alpha", Name: "Alpha"}, | |
{ID: "Omega", Name: "Omega"}, | |
} | |
for _, user := range users { | |
dl.Prime(user.ID, &user) | |
} | |
u, err := dl.Load("Alpha") | |
require.NoError(t, err) | |
require.Equal(t, "Alpha", u.Name) | |
u, err = dl.Load("Omega") | |
require.NoError(t, err) | |
require.Equal(t, "Omega", u.Name) | |
require.Len(t, fetches, 3) | |
}) | |
t.Run("cleared results will go back to the fetcher", func(t *testing.T) { | |
dl.Clear("U99") | |
u, err := dl.Load("U99") | |
require.NoError(t, err) | |
require.Equal(t, "user U99", u.Name) | |
require.Len(t, fetches, 4) | |
}) | |
t.Run("load all thunk", func(t *testing.T) { | |
thunk1 := dl.LoadAllThunk([]string{"U5", "U6"}) | |
thunk2 := dl.LoadAllThunk([]string{"U6", "E6"}) | |
users1, err1 := thunk1() | |
require.NoError(t, err1[0]) | |
require.NoError(t, err1[1]) | |
require.Equal(t, "user U5", users1[0].Name) | |
require.Equal(t, "user U6", users1[1].Name) | |
users2, err2 := thunk2() | |
require.NoError(t, err2[0]) | |
require.Error(t, err2[1]) | |
require.Equal(t, "user U6", users2[0].Name) | |
}) | |
} |