-
-
Save NDiiong/8e6b2c8ff2eb11cd618ed9dccc134111 to your computer and use it in GitHub Desktop.
Task extensions for parallelism
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Threading; | |
using System.Threading.Tasks; | |
namespace Gist
{
    /// <summary>
    /// Extension methods that wait for a sequence of tasks while limiting how many of them are in flight at once.
    /// </summary>
    public static class TaskExtensions
    {
        /// <summary>
        /// Creates a task that will complete when all of the <see cref="Task{TResult}"/> objects in an enumerable collection have completed.
        /// At most <paramref name="maxConcurrency"/> tasks pulled from the sequence are left running at the same time.
        /// </summary>
        /// <typeparam name="TResult">The type of the result produced by each task.</typeparam>
        /// <param name="tasks">The tasks to wait on for completion.</param>
        /// <param name="maxConcurrency">The maximum number of tasks to run at the same time.</param>
        /// <returns>The results of the non-null tasks, in the order in which the sequence produced them.</returns>
        /// <remarks>
        /// <para>
        /// Throttling relies on <paramref name="tasks"/> being lazily evaluated (for example
        /// <c>items.Select(i =&gt; Task.Run(...))</c>). The sequence is only advanced once a slot is free, and
        /// advancing a lazy sequence is what creates the next task. Tasks that were already started before
        /// the call (e.g. an array or a list of tasks) run regardless of <paramref name="maxConcurrency"/>;
        /// they are merely awaited.
        /// </para>
        /// <para>A null in <paramref name="tasks"/> is skipped instead of causing an exception.</para>
        /// <para>If <paramref name="maxConcurrency"/> is less than 1, every task is awaited without throttling.</para>
        /// <para>
        /// Faults and cancellations of the individual tasks are surfaced exactly as
        /// <see cref="Task.WhenAll{TResult}(IEnumerable{Task{TResult}})"/> surfaces them.
        /// </para>
        /// </remarks>
        /// <exception cref="ArgumentNullException">The tasks argument was null.</exception>
        public static async Task<IEnumerable<TResult>> WhenAll<TResult>(this IEnumerable<Task<TResult>> tasks, int maxConcurrency)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            if (maxConcurrency < 1)
            {
                return await Task.WhenAll(tasks.Where(t => t != null)).ConfigureAwait(false);
            }

            var started = await StartThrottledAsync(tasks, maxConcurrency).ConfigureAwait(false);

            // Await the original tasks (not wrapper continuations) so exceptions are not double-wrapped
            // in AggregateException and cancellation stays cancellation.
            return (await Task.WhenAll(started).ConfigureAwait(false)).ToList();
        }

        /// <summary>
        /// Creates a task that will complete when all of the <see cref="Task"/> objects in an enumerable collection have completed.
        /// At most <paramref name="maxConcurrency"/> tasks pulled from the sequence are left running at the same time.
        /// </summary>
        /// <param name="tasks">The tasks to wait on for completion.</param>
        /// <param name="maxConcurrency">The maximum number of tasks to run at the same time.</param>
        /// <remarks>
        /// <para>
        /// Throttling relies on <paramref name="tasks"/> being lazily evaluated (for example
        /// <c>items.Select(i =&gt; Task.Run(...))</c>). The sequence is only advanced once a slot is free, and
        /// advancing a lazy sequence is what creates the next task. Tasks that were already started before
        /// the call (e.g. an array or a list of tasks) run regardless of <paramref name="maxConcurrency"/>;
        /// they are merely awaited.
        /// </para>
        /// <para>A null in <paramref name="tasks"/> is skipped instead of causing an exception.</para>
        /// <para>If <paramref name="maxConcurrency"/> is less than 1, every task is awaited without throttling.</para>
        /// <para>
        /// Faults and cancellations of the individual tasks are surfaced exactly as
        /// <see cref="Task.WhenAll(IEnumerable{Task})"/> surfaces them; they are not swallowed.
        /// </para>
        /// </remarks>
        /// <exception cref="ArgumentNullException">The tasks argument was null.</exception>
        public static async Task WhenAll(this IEnumerable<Task> tasks, int maxConcurrency)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            if (maxConcurrency < 1)
            {
                await Task.WhenAll(tasks.Where(t => t != null)).ConfigureAwait(false);
                return;
            }

            var started = await StartThrottledAsync(tasks, maxConcurrency).ConfigureAwait(false);

            // Await the original tasks so a faulted or canceled task is reported to the caller.
            await Task.WhenAll(started).ConfigureAwait(false);
        }

        /// <summary>
        /// Enumerates <paramref name="tasks"/> (skipping nulls). It advances the sequence only while fewer than
        /// <paramref name="maxConcurrency"/> of the tasks pulled so far are still incomplete.
        /// </summary>
        /// <typeparam name="TTask">The task type produced by the sequence.</typeparam>
        /// <param name="tasks">The (ideally lazy) sequence of tasks.</param>
        /// <param name="maxConcurrency">The maximum number of incomplete tasks allowed at once; must be at least 1.</param>
        /// <returns>Every non-null task pulled from the sequence, in enumeration order.</returns>
        private static async Task<List<TTask>> StartThrottledAsync<TTask>(IEnumerable<TTask> tasks, int maxConcurrency)
            where TTask : Task
        {
            var started = new List<TTask>();

            // Deliberately not disposed. Release continuations may still run after this method returns,
            // or after the source sequence throws, and Release() on a disposed SemaphoreSlim throws
            // ObjectDisposedException. A SemaphoreSlim only owns a disposable wait handle once
            // AvailableWaitHandle is read, which never happens here.
            var throttle = new SemaphoreSlim(maxConcurrency);

            using (var enumerator = tasks.Where(t => t != null).GetEnumerator())
            {
                while (true)
                {
                    // Take a slot *before* advancing: for a lazy sequence, MoveNext is what creates and starts
                    // the next task. No ConfigureAwait(false) here, so the caller's sequence keeps being
                    // evaluated on the caller's context as before.
                    await throttle.WaitAsync();
                    if (!enumerator.MoveNext())
                    {
                        break;
                    }

                    var task = enumerator.Current;
                    started.Add(task);
                    ReleaseOnCompletion(task, throttle);
                }
            }

            return started;
        }

        /// <summary>
        /// Returns a slot to <paramref name="throttle"/> once <paramref name="task"/> finishes, whatever its outcome.
        /// </summary>
        /// <param name="task">The task occupying the slot.</param>
        /// <param name="throttle">The semaphore whose slot is released.</param>
        private static void ReleaseOnCompletion(Task task, SemaphoreSlim throttle)
        {
            // The continuation only frees the slot; the task itself is awaited by the caller, so its result,
            // exception(s) or cancellation are observed there unchanged. An explicit TaskScheduler.Default keeps
            // the continuation off whatever scheduler happens to be TaskScheduler.Current (CA2008).
            task.ContinueWith(
                completed => { throttle.Release(); },
                CancellationToken.None,
                TaskContinuationOptions.None,
                TaskScheduler.Default);
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
using System.Linq; | |
using System.Threading.Tasks; | |
using Xunit; | |
namespace Gist.Tests
{
    /// <summary>
    /// Tests for the throttled <c>WhenAll</c> extension methods.
    /// </summary>
    public class TaskExtensionsTest
    {
        [Theory]
        [InlineData(1)]
        [InlineData(5)]
        [InlineData(10)]
        public async Task WhenAll_runs_only_specific_number_of_tasks_at_a_time(int maxConcurrency)
        {
            var running = 0;
            // Several tasks record concurrently, so the collection must be thread-safe (List<T>.Add is not).
            var results = new System.Collections.Concurrent.ConcurrentQueue<int>();
            // Lazy sequence: a task is only created (and started) when WhenAll advances the enumerator,
            // which is what allows WhenAll to throttle it.
            var tasks = Enumerable.Range(0, maxConcurrency * 5).Select(i => Task.Run(async () =>
            {
                // Interlocked: plain ++/-- from several thread-pool threads can lose updates.
                results.Enqueue(System.Threading.Interlocked.Increment(ref running));
                await Task.Delay(15);
                results.Enqueue(System.Threading.Volatile.Read(ref running));
                System.Threading.Interlocked.Decrement(ref running);
            }));

            await tasks.WhenAll(maxConcurrency);

            Assert.All(results, r => Assert.True(r <= maxConcurrency));
        }

        [Theory]
        [InlineData(1)]
        [InlineData(5)]
        [InlineData(10)]
        public async Task WhenAll_with_result_runs_only_specific_number_of_tasks_at_a_time(int maxConcurrency)
        {
            var running = 0;
            // Lazy sequence, see WhenAll_runs_only_specific_number_of_tasks_at_a_time.
            var tasks = Enumerable.Range(0, maxConcurrency * 5).Select(i => Task.Run(async () =>
            {
                System.Threading.Interlocked.Increment(ref running);
                await Task.Delay(15);
                var observed = System.Threading.Volatile.Read(ref running);
                System.Threading.Interlocked.Decrement(ref running);
                return observed;
            }));

            var results = await tasks.WhenAll(maxConcurrency);

            Assert.All(results, r => Assert.True(r <= maxConcurrency));
        }

        [Fact]
        public async Task WhenAll_ignores_null_task()
        {
            var count = 0;
            // Both tasks increment concurrently, so the increment must be atomic for the count to be reliable.
            var tasks = new Task[]
            {
                Task.Run(() => System.Threading.Interlocked.Increment(ref count)),
                Task.Run(() => System.Threading.Interlocked.Increment(ref count)),
                null,
            };

            await tasks.WhenAll(3);

            Assert.Equal(2, count);
        }

        [Fact]
        public async Task WhenAll_with_result_ignores_null_task()
        {
            var tasks = new Task<int>[]
            {
                Task.Run(() => 1),
                Task.Run(() => 1),
                null,
            };

            var results = await tasks.WhenAll(3);

            // Without the count check, an empty result would satisfy Assert.All vacuously.
            Assert.Equal(2, results.Count());
            Assert.All(results, r => Assert.Equal(1, r));
        }

        [Fact]
        public async Task WhenAll_has_guard_clause()
        {
            Task[] tasks = null;

            await Assert.ThrowsAsync<System.ArgumentNullException>(
                () => tasks.WhenAll(3));
        }

        [Fact]
        public async Task WhenAll_with_result_has_guard_clause()
        {
            Task<int>[] tasks = null;

            await Assert.ThrowsAsync<System.ArgumentNullException>(
                () => tasks.WhenAll(3));
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment