Last active: July 9, 2024 16:43
-
-
Save onionhammer/bab1d9e189db7ee5bbdb7f8509f11614 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Diagnostics; | |
using System.Threading.Channels; | |
using Microsoft.Extensions.Logging; | |
namespace Actors; | |
/// <summary>
/// Grain observer that is told about the outcome of an individual queued item.
/// </summary>
/// <typeparam name="K">Type of the per-item result.</typeparam>
public interface IBatchObserver<K> : IGrainObserver
{
    /// <summary>
    /// Notifies the observer that one item has finished processing.
    /// </summary>
    /// <param name="result">The result for that item; declared <c>K?</c>, so it may be <c>default</c>.</param>
    Task OnItemProcessedAsync(K? result);
}
/// <summary>
/// One unit of work queued for batch processing.
/// </summary>
/// <param name="Context">Activity context of the caller at the time the item was queued, if there was one.</param>
/// <param name="Id">Identifier of the item.</param>
/// <param name="Item">The payload to process.</param>
/// <param name="Observer">Observer that should receive the item's result.</param>
public record BatchItem<TId, T, K>(
    ActivityContext? Context,
    TId Id,
    T Item,
    IBatchObserver<K> Observer);
/// <summary>
/// Abstract batching grain, meant to be run as a stateless worker.
/// </summary>
/// <remarks>
/// Callers enqueue items with <see cref="TryAddBatchItem"/> into a bounded channel. A grain timer
/// repeatedly drains up to <see cref="BatchSize"/> items from that channel, passes them to
/// <see cref="ProcessBatchAsync"/> in a single call, and sends every observer the result returned
/// for its item id. On deactivation the channel is completed and all remaining items are drained.
/// </remarks>
/// <typeparam name="TId">Identifier of an item within a batch.</typeparam>
/// <typeparam name="T">Type of the queued items.</typeparam>
/// <typeparam name="K">Type of the per-item result reported to observers.</typeparam>
public abstract partial class BatchGrain<TId, T, K> : Grain
    where TId : notnull
{
    /// <summary>
    /// The activity source for the batch; when null, no batch activity is started.
    /// </summary>
    protected abstract ActivitySource? ActivitySource { get; }

    /// <summary>
    /// The name of the activity for each batch.
    /// </summary>
    protected abstract string ActivityName { get; }

    /// <summary>
    /// The max number of items to process in a single batch.
    /// </summary>
    protected abstract int BatchSize { get; }

    /// <summary>
    /// The minimum amount of time to wait for more items before processing the batch.
    /// </summary>
    protected abstract int MinBatchDelayMs { get; }

    /// <summary>
    /// Grain logger.
    /// </summary>
    private ILogger Logger { get; }

    /// <summary>
    /// Upper bound, in milliseconds, on how long a single channel read waits before giving up.
    /// </summary>
    private const int MaxBatchDelayMs = 2000;

    /// <summary>
    /// Items of the batch currently being processed, keyed by id.
    /// </summary>
    private readonly Dictionary<TId, T> _items;

    /// <summary>
    /// Every (id, observer) pair of the current batch. A list rather than a dictionary so that
    /// several callers who queued the same id are all notified.
    /// </summary>
    private readonly List<(TId Id, IBatchObserver<K> Observer)> _observers;

    /// <summary>
    /// The bounded channel that queues incoming items.
    /// </summary>
    private readonly Channel<BatchItem<TId, T, K>> _taskChannel;

    /// <summary>
    /// Reusable buffer for reading from the channel.
    /// </summary>
    private readonly List<BatchItem<TId, T, K>> _buffer;

    /// <summary>
    /// Links from the batch activity to the activities of the callers that queued its items.
    /// </summary>
    private readonly List<ActivityLink> _links;

    /// <summary>
    /// Time elapsed since the previous channel read finished.
    /// </summary>
    private readonly Stopwatch _stopwatch = Stopwatch.StartNew();

    /// <summary>
    /// Ensures only one <see cref="ProcessAsync"/> run uses the reusable collections at a time.
    /// </summary>
    private readonly SemaphoreSlim _processLock = new(1, 1);

    /// <summary>
    /// Creates the channel and the reusable buffers, and starts the batch-processing timer.
    /// </summary>
    /// <param name="logger">Logger used for item and batch failures.</param>
    /// <remarks>
    /// NOTE: the abstract <see cref="BatchSize"/> is read here, from the base constructor, before any
    /// derived constructor body runs. Derived implementations must not depend on state assigned in
    /// their own constructor body.
    /// </remarks>
    protected BatchGrain(ILogger logger)
    {
        Logger = logger;

        // Capacity is 1.5x BatchSize; once full, writers wait in TryAddBatchItem.
        _taskChannel = Channel.CreateBounded<BatchItem<TId, T, K>>(BatchSize + (BatchSize / 2));
        _items = new(capacity: BatchSize);
        _observers = new(capacity: BatchSize);
        _buffer = new(capacity: BatchSize);
        _links = new(capacity: BatchSize);

        // Start immediately with a one-tick period so the read loop runs (nearly) continuously.
        RegisterTimer(ProcessAsync, false, TimeSpan.Zero, TimeSpan.FromTicks(1));
    }

    /// <summary>
    /// Read from the channel and invoke the batch processor.
    /// </summary>
    /// <param name="finalBatch">
    /// Timer state. <c>true</c> (passed from <c>OnDeactivateAsync</c>) keeps processing batches until
    /// the channel is empty. Any other value processes at most one successful batch per call.
    /// </param>
    /// <remarks>
    /// If <see cref="ProcessBatchAsync"/> throws, the failure is logged and the batch's observers are
    /// not notified; the loop then reads the next batch. Observers whose id is missing from the
    /// returned results are not notified either.
    /// </remarks>
    protected async Task ProcessAsync(object? finalBatch)
    {
        var isFinalBatch = finalBatch is true;

        // Both the timer callback and OnDeactivateAsync run this method, and Orleans timer callbacks
        // can interleave with other turns at await points. Serialize the runs so one never clears
        // or refills the shared collections while another is mid-batch.
        await _processLock.WaitAsync();
        try
        {
            while (true)
            {
                // Reset all reusable collections, including _links. Otherwise links from every
                // earlier batch would accumulate and be attached to each later batch activity.
                _items.Clear();
                _observers.Clear();
                _buffer.Clear();
                _links.Clear();

                // Whatever remains of MinBatchDelayMs since the previous read finished. The read
                // keeps collecting items until this much time has passed.
                var batchDelay = (int)Math.Max(MinBatchDelayMs - _stopwatch.ElapsedMilliseconds, 0L);

                await ReadManyFromChannel(_taskChannel.Reader, _buffer, BatchSize, batchDelay, MaxBatchDelayMs);

                foreach (var (context, id, item, observer) in _buffer)
                {
                    // Duplicate ids are coalesced: the most recently queued item is processed, and
                    // every caller that queued that id receives its result. This replaces a
                    // Dictionary.Add call, which would throw and lose the whole batch.
                    _items[id] = item;
                    _observers.Add((id, observer));

                    if (context.HasValue)
                        _links.Add(new ActivityLink(context.Value));
                }

                _stopwatch.Restart();

                if (_items.Count == 0)
                    break;

                using var activity = ActivitySource?.StartActivity(
                    ActivityName,
                    ActivityKind.Internal,
                    parentContext: default,
                    links: _links);
                activity?.SetTag("BatchSize", _items.Count);

                try
                {
                    var resultItems = await ProcessBatchAsync(_items);

                    await NotifyObserversAsync(resultItems);

                    // If this is the final batch, continue processing until the channel is empty.
                    if (isFinalBatch && await _taskChannel.Reader.WaitToReadAsync())
                        continue;

                    break;
                }
                catch (Exception ex)
                {
                    activity?.SetStatus(ActivityStatusCode.Error, ex.Message);
                    LogBatchProcessingFailed(Logger, ex, _items.Count);
                }
            }
        }
        finally
        {
            _processLock.Release();
        }
    }

    /// <summary>
    /// Sends every observer of the current batch the result recorded for its id. Observers whose id
    /// has no entry in <paramref name="results"/> are skipped. Failures are logged per item and are
    /// never rethrown.
    /// </summary>
    /// <param name="results">Results returned by <see cref="ProcessBatchAsync"/>, keyed by item id.</param>
    private async Task NotifyObserversAsync(IReadOnlyDictionary<TId, K> results)
    {
        var notifications = new List<Task>(_observers.Count);

        foreach (var (id, observer) in _observers)
        {
            if (results.TryGetValue(id, out var result))
                notifications.Add(NotifyObserverAsync(id, observer, result));
        }

        // The notifications are I/O-bound observer calls, so they are fanned out concurrently.
        await Task.WhenAll(notifications);
    }

    /// <summary>
    /// Delivers one result to one observer, logging (rather than propagating) any failure.
    /// </summary>
    /// <param name="id">Item id, used only for the log message.</param>
    /// <param name="observer">Observer to notify.</param>
    /// <param name="result">Result to deliver.</param>
    private async Task NotifyObserverAsync(TId id, IBatchObserver<K> observer, K? result)
    {
        try
        {
            await observer.OnItemProcessedAsync(result);
        }
        catch (Exception ex)
        {
            LogFailedToProcessItem(Logger, ex, ActivityName, id.ToString());
        }
    }

    /// <summary>
    /// Process a batch of items.
    /// </summary>
    /// <param name="items">The batch, keyed by item id.</param>
    /// <returns>Results keyed by item id; observers are notified only for ids present here.</returns>
    protected abstract Task<IReadOnlyDictionary<TId, K>> ProcessBatchAsync(IReadOnlyDictionary<TId, T> items);

    /// <summary>
    /// Add a task to the queue -- Can be interleaved.
    /// </summary>
    /// <param name="id">Identifier of the item.</param>
    /// <param name="item">The item to process.</param>
    /// <param name="observer">Observer to notify with the item's result.</param>
    /// <returns>
    /// <c>true</c> once the item is queued; <c>false</c> if the channel has been completed
    /// (the grain is deactivating).
    /// </returns>
    protected async Task<bool> TryAddBatchItem(TId id, T item, IBatchObserver<K> observer)
    {
        var context = Activity.Current?.Context;
        var batchItem = new BatchItem<TId, T, K>(context, id, item, observer);

        // Several interleaved callers can be woken for a single free slot, and only one TryWrite
        // then succeeds. Retry until the write succeeds or the channel is completed, instead of
        // reporting a spurious failure.
        while (await _taskChannel.Writer.WaitToWriteAsync())
        {
            if (_taskChannel.Writer.TryWrite(batchItem))
                return true;
        }

        return false;
    }

    /// <summary>
    /// Stops accepting new items, then processes everything still queued before deactivating.
    /// </summary>
    public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken)
    {
        // Complete the writer. After this, pending and future TryAddBatchItem calls return false.
        _taskChannel.Writer.TryComplete();

        // Wait for all batch tasks to complete before deactivating the grain.
        await ProcessAsync(true);
    }

    /// <summary>
    /// Reads up to <paramref name="batchSize"/> items from the channel into <paramref name="buffer"/>.
    /// </summary>
    /// <remarks>
    /// Returns when any of the following happens:
    /// <list type="bullet">
    /// <item><description>the batch is full;</description></item>
    /// <item><description>the channel is completed and empty;</description></item>
    /// <item><description>
    /// <paramref name="maxDelayMs"/> elapses or <paramref name="cancellationToken"/> is cancelled
    /// (items read so far are kept);
    /// </description></item>
    /// <item><description>
    /// after a successful read, at least <paramref name="minDelayMs"/> has elapsed since the call began.
    /// </description></item>
    /// </list>
    /// </remarks>
    /// <param name="reader">The channel reader.</param>
    /// <param name="buffer">List that read items are appended to.</param>
    /// <param name="batchSize">The max number of items to read.</param>
    /// <param name="minDelayMs">How long to keep collecting items before returning a non-full batch.</param>
    /// <param name="maxDelayMs">The max amount of time to wait overall, including for the first item.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    private static async Task ReadManyFromChannel(
        ChannelReader<BatchItem<TId, T, K>> reader,
        List<BatchItem<TId, T, K>> buffer,
        int batchSize,
        int minDelayMs,
        int maxDelayMs,
        CancellationToken cancellationToken = default)
    {
        var stopwatch = Stopwatch.StartNew();
        using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        cancellationTokenSource.CancelAfter(maxDelayMs);
        var token = cancellationTokenSource.Token;
        var count = 0;

        try
        {
            // Stop as soon as the batch is full; waiting for minDelayMs with a full batch would
            // only read past batchSize.
            while (count < batchSize && !token.IsCancellationRequested)
            {
                // Returns false once no data will ever be available from this channel again.
                if (!await reader.WaitToReadAsync(token))
                    break;

                while (count < batchSize && reader.TryRead(out var item))
                {
                    buffer.Add(item);
                    count++;
                }

                // If minDelay has passed, return what we have; otherwise wait for more data.
                if (stopwatch.ElapsedMilliseconds >= minDelayMs)
                    break;
            }
        }
        catch (OperationCanceledException)
        {
            // Timing out is the expected way to end a sparse batch; keep whatever was read.
        }
    }

    [LoggerMessage(EventId = 2, Level = LogLevel.Critical, Message = "{Activity} Failed to process item {Id}")]
    public static partial void LogFailedToProcessItem(ILogger logger, Exception ex, string activity, string? id);

    [LoggerMessage(EventId = 3, Level = LogLevel.Information, Message = "Failed to process batch of {Count} items")]
    public static partial void LogFailedToProcessBatch(ILogger logger, int count);

    /// <summary>
    /// Logs a failed batch together with the exception that caused it.
    /// </summary>
    [LoggerMessage(EventId = 4, Level = LogLevel.Error, Message = "Failed to process batch of {Count} items")]
    private static partial void LogBatchProcessingFailed(ILogger logger, Exception ex, int count);
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.