Skip to content

Instantly share code, notes, and snippets.

@onionhammer
Last active July 9, 2024 16:43
Show Gist options
  • Save onionhammer/bab1d9e189db7ee5bbdb7f8509f11614 to your computer and use it in GitHub Desktop.
Save onionhammer/bab1d9e189db7ee5bbdb7f8509f11614 to your computer and use it in GitHub Desktop.
using System.Diagnostics;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
namespace Actors;
/// <summary>
/// Grain observer that receives the per-item results produced by a <see cref="BatchGrain{TId, T, K}"/>.
/// </summary>
/// <typeparam name="K">The result type produced for each processed item.</typeparam>
public interface IBatchObserver<K> : IGrainObserver
{
/// <summary>
/// A single item queued with this observer has been processed
/// </summary>
/// <param name="result">The result the batch processor produced for the item this observer was registered with.</param>
/// <returns>A task that completes once the observer has handled the result.</returns>
Task OnItemProcessedAsync(K? result);
}
/// <summary>
/// One queued unit of work waiting to be picked up by a batching grain.
/// </summary>
/// <param name="Context">Trace context of the caller that queued the item, if an activity was current; used to link the batch activity back to it.</param>
/// <param name="Id">Key identifying the item within its batch.</param>
/// <param name="Item">The payload to process.</param>
/// <param name="Observer">Observer to notify with this item's result.</param>
public record BatchItem<TId, T, K>(
    ActivityContext? Context,
    TId Id,
    T Item,
    IBatchObserver<K> Observer);
/// <summary>
/// Abstract batching grain, meant to be run as a stateless worker.
/// Items queued through <see cref="TryAddBatchItem"/> are buffered in a bounded channel, read in
/// batches of at most <see cref="BatchSize"/> by a timer-driven loop, handed to
/// <see cref="ProcessBatchAsync"/>, and each item's result is delivered to its observer.
/// </summary>
public abstract partial class BatchGrain<TId, T, K> : Grain
    where TId : notnull
{
    /// <summary>
    /// The activity source for the batch
    /// </summary>
    protected abstract ActivitySource? ActivitySource { get; }

    /// <summary>
    /// The name of the activity for each batch
    /// </summary>
    protected abstract string ActivityName { get; }

    /// <summary>
    /// The max number of items to process in a single batch
    /// </summary>
    protected abstract int BatchSize { get; }

    /// <summary>
    /// Minimum spacing, in milliseconds, between batches. Once at least one item is buffered, the
    /// reader keeps collecting more items until this much time has passed since the previous batch
    /// finished reading, then processes what it has.
    /// </summary>
    protected abstract int MinBatchDelayMs { get; }

    /// <summary>
    /// Grain logger
    /// </summary>
    private ILogger Logger { get; }

    /// <summary>
    /// Upper bound, in milliseconds, on how long a single batch read waits for items before
    /// returning (possibly empty) and ending the current iteration.
    /// </summary>
    private const int MaxBatchDelayMs = 2000;

    /// <summary>
    /// Items of the batch currently being processed, keyed by id
    /// </summary>
    private readonly Dictionary<TId, T> _items;

    /// <summary>
    /// Observers of the batch currently being processed, keyed by item id
    /// </summary>
    private readonly Dictionary<TId, IBatchObserver<K>> _observers;

    /// <summary>
    /// The bounded channel that queues up incoming items
    /// </summary>
    private readonly Channel<BatchItem<TId, T, K>> _taskChannel;

    /// <summary>
    /// Reusable buffer for reading from the channel
    /// </summary>
    private readonly List<BatchItem<TId, T, K>> _buffer;

    /// <summary>
    /// Activity links to the callers' activities for the batch currently being processed
    /// </summary>
    private readonly List<ActivityLink> _links;

    /// <summary>
    /// Measures the time since the previous batch finished reading
    /// </summary>
    private readonly Stopwatch _stopwatch = Stopwatch.StartNew();

    protected BatchGrain(ILogger logger)
    {
        Logger = logger;

        // BatchSize is abstract and is read from the constructor, so derived implementations must
        // not depend on their own constructor having run. Read it once so all sizes agree.
        var batchSize = BatchSize;

        _taskChannel = Channel.CreateBounded<BatchItem<TId, T, K>>(batchSize + (batchSize / 2));
        _items = new(capacity: batchSize);
        _observers = new(capacity: batchSize);
        _buffer = new(capacity: batchSize);
        _links = new(capacity: batchSize);

        // NOTE(review): relies on Orleans not scheduling the next tick until the previous
        // ProcessAsync call has completed — confirm for the Orleans version in use.
        RegisterTimer(ProcessAsync, false, TimeSpan.Zero, TimeSpan.FromTicks(1));
    }

    /// <summary>
    /// Read from the channel and invoke the batch processor.
    /// </summary>
    /// <param name="finalBatch">
    /// <c>true</c> to keep processing batches until the channel is drained (used on deactivation);
    /// anything else processes at most one successful batch per call.
    /// </param>
    protected async Task ProcessAsync(object? finalBatch)
    {
        var isFinalBatch = finalBatch is true;

        while (true)
        {
            // Reset all per-batch state, including the activity links gathered for the previous batch
            _items.Clear();
            _observers.Clear();
            _buffer.Clear();
            _links.Clear();

            // Remaining part of the minimum spacing since the previous batch; computed in long so a
            // long-running stopwatch cannot overflow the int cast (the result is <= MinBatchDelayMs)
            var batchDelay = (int)Math.Max(MinBatchDelayMs - _stopwatch.ElapsedMilliseconds, 0L);

            await ReadManyFromChannel(_taskChannel.Reader, _buffer, BatchSize, batchDelay, MaxBatchDelayMs);

            foreach (var (context, id, item, observer) in _buffer)
            {
                // NOTE(review): a duplicate id within one batch throws ArgumentException here,
                // outside the try below, and the items already read from the channel are lost.
                _items.Add(id, item);
                _observers.Add(id, observer);
                if (context.HasValue)
                    _links.Add(new ActivityLink(context.Value));
            }

            _stopwatch.Restart();

            if (_items.Count == 0)
                break;

            // Start an activity for the batch, linked to every caller's activity
            using var activity = ActivitySource?
                .StartActivity(ActivityName, ActivityKind.Internal,
                    parentContext: default,
                    links: _links
                );
            activity?.SetTag("BatchSize", _items.Count);

            try
            {
                var resultItems = await ProcessBatchAsync(_items);

                await NotifyObserversAsync(resultItems);

                // If this is the final batch, continue processing until the channel is empty
                if (isFinalBatch && await _taskChannel.Reader.WaitToReadAsync())
                    continue;

                break;
            }
            catch (Exception ex)
            {
                // The failed batch's observers are not notified; the loop moves on to the next batch
                activity?.SetStatus(ActivityStatusCode.Error, ex.Message);
                LogBatchProcessingError(Logger, ex, ActivityName, _items.Count);
            }
        }
    }

    /// <summary>
    /// Delivers each result to the observer of the matching item. Items for which
    /// <paramref name="resultItems"/> holds no entry are not notified.
    /// </summary>
    private Task NotifyObserversAsync(IReadOnlyDictionary<TId, K> resultItems)
    {
        var notifications = new List<Task>(_observers.Count);
        foreach (var (id, observer) in _observers)
        {
            if (resultItems.TryGetValue(id, out var result))
                notifications.Add(NotifyObserverAsync(id, observer, result));
        }

        return Task.WhenAll(notifications);
    }

    /// <summary>
    /// Notifies a single observer; failures are logged per item and never fail the batch.
    /// </summary>
    private async Task NotifyObserverAsync(TId id, IBatchObserver<K> observer, K? result)
    {
        try
        {
            await observer.OnItemProcessedAsync(result);
        }
        catch (Exception ex)
        {
            LogFailedToProcessItem(Logger, ex, ActivityName, id.ToString());
        }
    }

    /// <summary>
    /// Process a batch of items
    /// </summary>
    /// <param name="items">The batch, keyed by item id.</param>
    /// <returns>Results keyed by item id; ids missing from the result are not reported to their observers.</returns>
    protected abstract Task<IReadOnlyDictionary<TId, K>> ProcessBatchAsync(IReadOnlyDictionary<TId, T> items);

    /// <summary>
    /// Add a task to the queue -- Can be interleaved
    /// </summary>
    /// <returns><c>true</c> if the item was queued; <c>false</c> if the channel has been completed.</returns>
    protected async Task<bool> TryAddBatchItem(TId id, T item, IBatchObserver<K> observer)
    {
        var batchItem = new BatchItem<TId, T, K>(Activity.Current?.Context, id, item, observer);

        // Several interleaved callers can be woken when a single slot frees up and only one
        // TryWrite wins, so retry until written or the channel is completed.
        while (await _taskChannel.Writer.WaitToWriteAsync())
        {
            if (_taskChannel.Writer.TryWrite(batchItem))
                return true;
        }

        return false;
    }

    public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken)
    {
        // Stop accepting new items; waiting and future TryAddBatchItem calls return false
        _taskChannel.Writer.TryComplete();

        // Drain and process everything still queued before the grain deactivates
        await ProcessAsync(true);

        await base.OnDeactivateAsync(reason, cancellationToken);
    }

    /// <summary>
    /// Reads up to <paramref name="batchSize"/> items from the channel into <paramref name="buffer"/>.
    /// Waits up to <paramref name="maxDelayMs"/> for the first item; once at least one item is
    /// buffered, keeps collecting more only until <paramref name="minDelayMs"/> has elapsed since the
    /// call started. Returns early when the batch is full or the channel is completed and drained.
    /// </summary>
    /// <param name="reader">The channel reader</param>
    /// <param name="buffer">Destination list; items are appended to it</param>
    /// <param name="batchSize">The maximum number of items to read</param>
    /// <param name="minDelayMs">How long to keep collecting additional items once at least one is buffered</param>
    /// <param name="maxDelayMs">Upper bound on the total time spent waiting in this call</param>
    /// <param name="cancellationToken">The cancellation token; cancellation ends the read without throwing</param>
    private static async Task ReadManyFromChannel(
        ChannelReader<BatchItem<TId, T, K>> reader,
        List<BatchItem<TId, T, K>> buffer,
        int batchSize,
        int minDelayMs,
        int maxDelayMs,
        CancellationToken cancellationToken = default)
    {
        var stopwatch = Stopwatch.StartNew();
        using var maxDelaySource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        maxDelaySource.CancelAfter(maxDelayMs);

        var count = 0;

        try
        {
            while (count < batchSize)
            {
                var remainingMinDelayMs = minDelayMs - stopwatch.ElapsedMilliseconds;

                // We already hold items and the minimum delay has passed: hand the batch over
                if (count > 0 && remainingMinDelayMs <= 0)
                    break;

                bool hasData;
                if (count == 0)
                {
                    // Nothing buffered yet: wait (bounded by the max delay) for the first item
                    hasData = await reader.WaitToReadAsync(maxDelaySource.Token);
                }
                else
                {
                    // Items buffered: only wait for more until the minimum delay runs out.
                    // remainingMinDelayMs is in (0, minDelayMs] here, so the int cast is safe.
                    using var minDelaySource = CancellationTokenSource.CreateLinkedTokenSource(maxDelaySource.Token);
                    minDelaySource.CancelAfter((int)remainingMinDelayMs);
                    hasData = await reader.WaitToReadAsync(minDelaySource.Token);
                }

                // No data will ever be available to read from this channel again
                if (!hasData)
                    break;

                while (count < batchSize && reader.TryRead(out var item))
                {
                    buffer.Add(item);
                    count++;
                }
            }
        }
        catch (OperationCanceledException)
        {
            // A delay elapsed (or the caller cancelled): return whatever has been buffered
        }
    }

    [LoggerMessage(EventId = 2, Level = LogLevel.Critical, Message = "{Activity} Failed to process item {Id}")]
    public static partial void LogFailedToProcessItem(ILogger logger, Exception ex, string activity, string? id);

    /// <summary>
    /// Retained for existing callers; this class now logs batch failures via
    /// <see cref="LogBatchProcessingError"/>, which includes the exception.
    /// </summary>
    [LoggerMessage(EventId = 3, Level = LogLevel.Information, Message = "Failed to process batch of {Count} items")]
    public static partial void LogFailedToProcessBatch(ILogger logger, int count);

    [LoggerMessage(EventId = 4, Level = LogLevel.Error, Message = "{Activity} Failed to process batch of {Count} items")]
    private static partial void LogBatchProcessingError(ILogger logger, Exception ex, string activity, int count);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment