Skip to content

Instantly share code, notes, and snippets.

@lispyclouds
Last active November 30, 2024 14:24
Show Gist options
  • Save lispyclouds/ba87671c05616f6a1bcd5ae36ce6a4be to your computer and use it in GitHub Desktop.
Save lispyclouds/ba87671c05616f6a1bcd5ae36ce6a4be to your computer and use it in GitHub Desktop.
Go gather tasks with max concurrency control
package main
import (
	"context"
	"encoding/json"
	"fmt"
	"iter"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/spf13/cobra"
	"golang.org/x/sync/semaphore"
)
type (
	// Result carries the outcome of running one Task: the value it produced
	// together with the error it returned.
	Result[T any] struct {
		Val T     // value returned by the task
		Err error // error returned by the task (nil if it succeeded)
	}

	// Task is a unit of work that receives a context and produces a T or an
	// error.
	Task[T any] func(ctx context.Context) (T, error)
)
// worker pulls tasks off in until the channel is closed, runs each one with
// ctx, and sends its value/error pair to out. For every task it handles it
// calls wg.Done exactly once, after the result has been delivered to out.
func worker[T any](ctx context.Context, wg *sync.WaitGroup, in <-chan Task[T], out chan<- Result[T]) {
	for {
		next, open := <-in
		if !open {
			return
		}
		val, err := next(ctx)
		out <- Result[T]{Val: val, Err: err}
		wg.Done()
	}
}
// gatherWorker runs tasks on a fixed pool of worker goroutines and returns an
// iterator over their results. At most maxConcurrency tasks run at once; a
// value <= 0, or one larger than len(tasks), means one worker per task.
//
// Every task produces exactly one Result. All tasks run to completion before
// the first Result is yielded, and Results come out in the order the workers
// delivered them, not in submission order. Each call of the returned iterator
// runs the tasks again.
func gatherWorker[T any](ctx context.Context, maxConcurrency int, tasks ...Task[T]) iter.Seq[Result[T]] {
	return func(yield func(Result[T]) bool) {
		workers := maxConcurrency
		if workers <= 0 || workers > len(tasks) {
			// Zero workers would deadlock on the first send below.
			workers = len(tasks)
		}

		var wg sync.WaitGroup
		// worker calls wg.Done once per task, so every task must be
		// registered before it can be picked up; adding after the send
		// races with Done and can drive the counter negative (panic).
		wg.Add(len(tasks))

		send := make(chan Task[T])
		// Buffered to len(tasks) so workers never block on delivery and
		// wg.Wait below cannot deadlock waiting for a reader.
		recv := make(chan Result[T], len(tasks))
		for range workers {
			go worker(ctx, &wg, send, recv)
		}
		for _, task := range tasks {
			send <- task
		}
		close(send) // all tasks handed off; lets the workers exit
		wg.Wait()
		close(recv)

		for result := range recv {
			if !yield(result) {
				return
			}
		}
	}
}
// gatherSem runs each task in its own goroutine, using a weighted semaphore to
// cap how many run at once, and returns an iterator over their results. A
// maxConcurrency <= 0, or one larger than len(tasks), means no cap.
//
// Every task produces exactly one Result. If ctx is done before a task could
// acquire a slot, that task is not started and its Result carries the error
// from Acquire instead. All started tasks finish before the first Result is
// yielded; Results come out in the order they were delivered, not in
// submission order. Each call of the returned iterator runs the tasks again.
func gatherSem[T any](ctx context.Context, maxConcurrency int, tasks ...Task[T]) iter.Seq[Result[T]] {
	return func(yield func(Result[T]) bool) {
		limit := maxConcurrency
		if limit <= 0 || limit > len(tasks) {
			// A zero-weight semaphore would make every Acquire block
			// until ctx is done.
			limit = len(tasks)
		}
		sem := semaphore.NewWeighted(int64(limit))
		// Buffered to len(tasks) so no sender ever blocks.
		recv := make(chan Result[T], len(tasks))
		var wg sync.WaitGroup

		for _, task := range tasks {
			if err := sem.Acquire(ctx, 1); err != nil {
				// A failed Acquire holds nothing, so there is nothing to
				// release; report this task as failed instead of running
				// it (releasing here would panic).
				recv <- Result[T]{Err: err}
				continue
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer sem.Release(1) // runs before wg.Done
				res, err := task(ctx)
				recv <- Result[T]{Val: res, Err: err}
			}()
		}
		wg.Wait()
		close(recv)

		for result := range recv {
			if !yield(result) {
				return
			}
		}
	}
}
// playMain is the cobra Run handler for the root command. It builds --tasks
// tasks. Each task fetches one random number from randomnumberapi.com, prints
// it, and then sleeps for --task-delay seconds. The tasks run through
// gatherSem with at most --max-concurrency in flight (0 or negative means
// unlimited). It prints the sum of the numbers, or stops at the first error
// it receives and prints that instead.
func playMain(cmd *cobra.Command, _ []string) {
	const apiURL = "http://www.randomnumberapi.com/api/v1.0/random?min=1&max=100" // returns an array of ints

	// These flags are registered by main with these names and int types, so
	// the lookup errors are not expected here.
	taskCount, _ := cmd.Flags().GetInt("tasks")
	maxConcurrency, _ := cmd.Flags().GetInt("max-concurrency")
	taskDelay, _ := cmd.Flags().GetInt("task-delay")

	tasks := make([]Task[int64], 0, max(taskCount, 0))
	for range taskCount {
		tasks = append(tasks, func(ctx context.Context) (int64, error) {
			n, err := fetchRandomNumber(ctx, apiURL)
			if err != nil {
				return 0, err
			}
			fmt.Printf("Adding %d\n", n)
			time.Sleep(time.Duration(taskDelay) * time.Second)
			return n, nil
		})
	}

	// Treat 0 and negative values as "unlimited", as the flag help says.
	if maxConcurrency <= 0 {
		maxConcurrency = len(tasks)
	}

	var sum int64
	for result := range gatherSem(context.Background(), maxConcurrency, tasks...) {
		if err := result.Err; err != nil {
			fmt.Printf("Error: %s\n", err.Error())
			return
		}
		sum += result.Val
	}
	fmt.Printf("Sum is: %d\n", sum)
}

// fetchRandomNumber sends a GET request to endpoint and decodes the response
// body as a JSON array of integers. It returns the first element. A non-200
// status, a decode failure, or an empty array is reported as an error.
func fetchRandomNumber(ctx context.Context, endpoint string) (int64, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return 0, err
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return 0, err
	}
	// The body must be closed, or the underlying connection is leaked.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("GET %s: unexpected status %s", endpoint, res.Status)
	}
	var nums []int64
	if err := json.NewDecoder(res.Body).Decode(&nums); err != nil {
		return 0, fmt.Errorf("decoding response from %s: %w", endpoint, err)
	}
	if len(nums) == 0 {
		return 0, fmt.Errorf("GET %s: response contained no numbers", endpoint)
	}
	return nums[0], nil
}
// main builds the goplay root command, registers its flags, and runs it.
// It exits with status 1 if the command fails.
func main() {
	cmd := &cobra.Command{
		Use:   "goplay",
		Short: "goplay CLI",
		Long:  "CLI for GoPlay",
		Run:   playMain,
	}
	cmd.Flags().IntP("tasks", "t", 1, "number of tasks to spawn")
	cmd.Flags().IntP("max-concurrency", "m", 0, "limit number of tasks to run concurrently, skip for unlimited")
	cmd.Flags().IntP("task-delay", "d", 1, "added delay in seconds to each task")

	// cobra prints the error itself by default (SilenceErrors is not set),
	// so only the exit status needs to report the failure to the shell.
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment