Created
October 24, 2024 15:36
-
-
Save ruccho/76af66f48fbbe6168ff97a3b3591dea2 to your computer and use it in GitHub Desktop.
Minimal Task-like implementation, for single-threaded use only.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Runtime.CompilerServices; | |
using System; | |
using System.Collections.Generic; | |
using System.Threading; | |
namespace STask | |
{ | |
/// <summary>
/// Awaitable, result-less task-like type. Async methods declared to return it are built by
/// <see cref="STaskMethodBuilder"/> (selected through <c>[AsyncMethodBuilder]</c>).
/// Internally it wraps an <c>STaskCompletionSource&lt;byte&gt;</c> whose value is ignored.
/// </summary>
/// <remarks>
/// <c>default(STaskVoid)</c> carries no completion source, so awaiting it fails.
/// </remarks>
[AsyncMethodBuilder(typeof(STaskMethodBuilder))]
public struct STaskVoid
{
    internal readonly STaskCompletionSource<byte> source;

    internal STaskVoid(STaskCompletionSource<byte> source) => this.source = source;

    /// <summary>Returns the awaiter used by <c>await</c>; it reads from the same completion source.</summary>
    public STaskAwaiter GetAwaiter() => new(source);
}
/// <summary>
/// Awaitable task-like type producing a <typeparamref name="TResult"/>. Async methods declared
/// to return it are built by <c>STaskMethodBuilder&lt;T&gt;</c> (selected through <c>[AsyncMethodBuilder]</c>).
/// </summary>
/// <typeparam name="TResult">Type of the value produced by the awaited operation.</typeparam>
/// <remarks>
/// <c>default(STask&lt;TResult&gt;)</c> carries no completion source, so awaiting it fails.
/// </remarks>
[AsyncMethodBuilder(typeof(STaskMethodBuilder<>))]
public struct STask<TResult>
{
    internal readonly STaskCompletionSource<TResult> source;

    internal STask(STaskCompletionSource<TResult> source) => this.source = source;

    /// <summary>Returns the awaiter used by <c>await</c>; it reads from the same completion source.</summary>
    public STaskAwaiter<TResult> GetAwaiter() => new(source);
}
/// <summary>
/// Awaiter for <see cref="STaskVoid"/>. All members delegate to the underlying
/// <c>STaskCompletionSource&lt;byte&gt;</c>, whose value is discarded.
/// </summary>
/// <remarks>
/// <see cref="OnCompleted"/> and <see cref="UnsafeOnCompleted"/> register the continuation
/// identically; neither captures an <c>ExecutionContext</c>.
/// </remarks>
public readonly struct STaskAwaiter : ICriticalNotifyCompletion
{
    private readonly STaskCompletionSource<byte> source;

    internal STaskAwaiter(STaskCompletionSource<byte> source) => this.source = source;

    /// <summary>Whether the awaited operation has already completed.</summary>
    public bool IsCompleted => source.IsCompleted;

    /// <summary>Observes completion, rethrowing any stored exception.</summary>
    public void GetResult()
    {
        _ = source.GetResult();
    }

    /// <summary>Schedules <paramref name="continuation"/>; same behavior as <see cref="UnsafeOnCompleted"/>.</summary>
    public void OnCompleted(Action continuation) => UnsafeOnCompleted(continuation);

    /// <summary>Schedules <paramref name="continuation"/> to run when the operation completes.</summary>
    public void UnsafeOnCompleted(Action continuation)
    {
        // The Action travels as the opaque context object, so the static lambda allocates nothing.
        source.OnCompleted(static state => ((Action)state)(), continuation);
    }
}
/// <summary>
/// Awaiter for <c>STask&lt;TResult&gt;</c>. All members delegate to the underlying
/// <c>STaskCompletionSource&lt;TResult&gt;</c>.
/// </summary>
/// <typeparam name="TResult">Type of the awaited value.</typeparam>
/// <remarks>
/// <see cref="OnCompleted"/> and <see cref="UnsafeOnCompleted"/> register the continuation
/// identically; neither captures an <c>ExecutionContext</c>.
/// </remarks>
public readonly struct STaskAwaiter<TResult> : ICriticalNotifyCompletion
{
    private readonly STaskCompletionSource<TResult> source;

    internal STaskAwaiter(STaskCompletionSource<TResult> source) => this.source = source;

    /// <summary>Whether the awaited operation has already completed.</summary>
    public bool IsCompleted => source.IsCompleted;

    /// <summary>Returns the operation's result, rethrowing any stored exception.</summary>
    public TResult GetResult()
    {
        return source.GetResult();
    }

    /// <summary>Schedules <paramref name="continuation"/>; same behavior as <see cref="UnsafeOnCompleted"/>.</summary>
    public void OnCompleted(Action continuation) => UnsafeOnCompleted(continuation);

    /// <summary>Schedules <paramref name="continuation"/> to run when the operation completes.</summary>
    public void UnsafeOnCompleted(Action continuation)
    {
        // The Action travels as the opaque context object, so the static lambda allocates nothing.
        source.OnCompleted(static state => ((Action)state)(), continuation);
    }
}
/// <summary>
/// Versioned handle to a pooled, single-thread completion core.
/// </summary>
/// <remarks>
/// Every member must be called on the thread that created the handle. The handle records the
/// core's version at creation; <see cref="GetResult"/> recycles the core and bumps its version,
/// after which any further use of this handle throws <see cref="InvalidOperationException"/>.
/// </remarks>
/// <typeparam name="TResult">Type of the completion value.</typeparam>
internal readonly struct STaskCompletionSource<TResult>
{
    private readonly Core core;
    private readonly ushort version;

    private STaskCompletionSource(Core core)
    {
        this.core = core;
        version = core.Version;
    }

    /// <summary>Whether the operation has completed.</summary>
    /// <exception cref="InvalidOperationException">Called from another thread, or the handle is stale.</exception>
    public bool IsCompleted => core.GetIsCompleted(version);

    /// <summary>
    /// Creates a handle backed by a core taken from the current thread's pool (allocating only
    /// when the pool is empty) and bound to the current thread.
    /// </summary>
    public static STaskCompletionSource<TResult> Create() => new(Core.Rent());

    /// <summary>Completes the operation successfully.</summary>
    /// <exception cref="InvalidOperationException">Wrong thread, stale handle, or already completed.</exception>
    public void SetResult(TResult result) => core.SetResult(version, result);

    /// <summary>Completes the operation with <paramref name="ex"/>, which <see cref="GetResult"/> rethrows.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ex"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Wrong thread, stale handle, or already completed.</exception>
    public void SetException(Exception ex) => core.SetException(version, ex);

    /// <summary>
    /// Registers the single continuation to run on completion; if already completed it runs
    /// synchronously. A second registration throws <see cref="InvalidOperationException"/>.
    /// </summary>
    public void OnCompleted(Action<object> continuation, object context) => core.OnCompleted(version, continuation, context);

    /// <summary>
    /// Returns the result or rethrows the stored exception (preserving its original stack trace),
    /// then recycles the core. Can be called only once per handle.
    /// </summary>
    /// <exception cref="InvalidOperationException">Wrong thread, stale handle, or not yet completed.</exception>
    public TResult GetResult() => core.GetResult(version);

    /// <summary>Mutable completion state, pooled per thread per <typeparamref name="TResult"/>.</summary>
    private sealed class Core
    {
        [ThreadStatic] private static Stack<Core> pool;

        private Thread thread;
        private ushort version;
        private bool isCompleted;
        private TResult result;
        private System.Runtime.ExceptionServices.ExceptionDispatchInfo exceptionInfo;
        private Action<object> continuation;
        private object continuationContext;

        public ushort Version => version;

        // Takes a core from this thread's pool (or allocates one) and binds it to the current thread.
        public static Core Rent()
        {
            pool ??= new();
            if (!pool.TryPop(out var core)) core = new Core();
            core.thread = Thread.CurrentThread;
            return core;
        }

        public bool GetIsCompleted(ushort version)
        {
            ThrowIfDifferentThread();
            ThrowIfOutdated(version);
            return isCompleted;
        }

        public void SetResult(ushort version, TResult result)
        {
            ThrowIfDifferentThread();
            ThrowIfOutdated(version);
            ThrowIfCompleted();
            this.result = result;
            SignalCompletion();
        }

        public void SetException(ushort version, Exception ex)
        {
            if (ex == null) throw new ArgumentNullException(nameof(ex));
            ThrowIfDifferentThread();
            ThrowIfOutdated(version);
            ThrowIfCompleted();
            // Capturing here lets GetResult rethrow without resetting the original stack trace.
            exceptionInfo = System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex);
            SignalCompletion();
        }

        public void OnCompleted(ushort version, Action<object> continuation, object context)
        {
            ThrowIfDifferentThread();
            ThrowIfOutdated(version);
            if (this.continuation != null) throw new InvalidOperationException();

            if (isCompleted)
            {
                continuation?.Invoke(context);
                return;
            }

            this.continuation = continuation;
            continuationContext = context;
        }

        public TResult GetResult(ushort version)
        {
            ThrowIfDifferentThread();
            ThrowIfOutdated(version);
            if (!isCompleted) throw new InvalidOperationException();

            // Snapshot the outcome first: Recycle() clears it and may hand the core to the next operation.
            var info = exceptionInfo;
            var value = result;
            Recycle();
            info?.Throw();
            return value;
        }

        // Marks the core completed and runs the pending continuation, if any. The continuation
        // slots are cleared before invoking so the continuation can call GetResult (which
        // recycles this core) without observing stale state.
        private void SignalCompletion()
        {
            isCompleted = true;
            var pending = continuation;
            if (pending == null) return;

            var context = continuationContext;
            continuation = null;
            continuationContext = null;
            pending(context);
        }

        // Clears per-operation state and bumps the version so outstanding handles become stale.
        // A core whose version would reach ushort.MaxValue is dropped rather than pooled, so
        // versions never wrap around and collide with a stale handle.
        private void Recycle()
        {
            thread = null;
            isCompleted = false;
            result = default;
            exceptionInfo = null;
            continuation = null;
            continuationContext = null;
            if (++version != ushort.MaxValue)
            {
                pool ??= new();
                pool.Push(this);
            }
        }

        private void ThrowIfDifferentThread()
        {
            if (thread != Thread.CurrentThread)
                throw new InvalidOperationException();
        }

        private void ThrowIfOutdated(ushort version)
        {
            if (this.version != version)
                throw new InvalidOperationException();
        }

        private void ThrowIfCompleted()
        {
            if (isCompleted)
                throw new InvalidOperationException();
        }
    }
}
/// <summary>
/// Async method builder for methods returning <see cref="STaskVoid"/>. The state machine is
/// started synchronously, and each suspension hands the awaiter a pooled continuation from
/// <c>StateMachineContinuation&lt;TStateMachine&gt;</c> that resumes a copy of the state machine.
/// </summary>
public struct STaskMethodBuilder
{
    private STaskCompletionSource<byte> source;

    /// <summary>Called by compiler-generated code; binds the builder to a fresh completion source.</summary>
    public static STaskMethodBuilder Create()
    {
        var builder = new STaskMethodBuilder();
        builder.source = STaskCompletionSource<byte>.Create();
        return builder;
    }

    /// <summary>The task returned to the async method's caller.</summary>
    public STaskVoid Task => new(source);

    /// <summary>Runs the state machine on the current thread until it first suspends or finishes.</summary>
    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine
        => stateMachine.MoveNext();

    /// <summary>
    /// Intentionally empty: resumption goes through <c>StateMachineContinuation</c>, which keeps
    /// its own copy of the state machine.
    /// </summary>
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
    }

    /// <summary>Completes the task successfully (the stored byte value is unused).</summary>
    public void SetResult() => source.SetResult(default(byte));

    /// <summary>Completes the task with <paramref name="exception"/>.</summary>
    public void SetException(Exception exception) => source.SetException(exception);

    /// <summary>Registers resumption of <paramref name="stateMachine"/> via <see cref="INotifyCompletion.OnCompleted"/>.</summary>
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        var resume = StateMachineContinuation<TStateMachine>.Get(ref stateMachine);
        awaiter.OnCompleted(resume);
    }

    /// <summary>Registers resumption of <paramref name="stateMachine"/> via <see cref="ICriticalNotifyCompletion.UnsafeOnCompleted"/>.</summary>
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        var resume = StateMachineContinuation<TStateMachine>.Get(ref stateMachine);
        awaiter.UnsafeOnCompleted(resume);
    }
}
/// <summary>
/// Async method builder for methods returning <c>STask&lt;T&gt;</c>. The state machine is
/// started synchronously, and each suspension hands the awaiter a pooled continuation from
/// <c>StateMachineContinuation&lt;TStateMachine&gt;</c> that resumes a copy of the state machine.
/// </summary>
/// <typeparam name="T">Result type of the async method.</typeparam>
public struct STaskMethodBuilder<T>
{
    private STaskCompletionSource<T> source;

    /// <summary>Called by compiler-generated code; binds the builder to a fresh completion source.</summary>
    public static STaskMethodBuilder<T> Create()
    {
        var builder = new STaskMethodBuilder<T>();
        builder.source = STaskCompletionSource<T>.Create();
        return builder;
    }

    /// <summary>The task returned to the async method's caller.</summary>
    public STask<T> Task => new(source);

    /// <summary>Runs the state machine on the current thread until it first suspends or finishes.</summary>
    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine
        => stateMachine.MoveNext();

    /// <summary>
    /// Intentionally empty: resumption goes through <c>StateMachineContinuation</c>, which keeps
    /// its own copy of the state machine.
    /// </summary>
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
    }

    /// <summary>Completes the task with <paramref name="result"/>.</summary>
    public void SetResult(T result) => source.SetResult(result);

    /// <summary>Completes the task with <paramref name="exception"/>.</summary>
    public void SetException(Exception exception) => source.SetException(exception);

    /// <summary>Registers resumption of <paramref name="stateMachine"/> via <see cref="INotifyCompletion.OnCompleted"/>.</summary>
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        var resume = StateMachineContinuation<TStateMachine>.Get(ref stateMachine);
        awaiter.OnCompleted(resume);
    }

    /// <summary>Registers resumption of <paramref name="stateMachine"/> via <see cref="ICriticalNotifyCompletion.UnsafeOnCompleted"/>.</summary>
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        var resume = StateMachineContinuation<TStateMachine>.Get(ref stateMachine);
        awaiter.UnsafeOnCompleted(resume);
    }
}
/// <summary>
/// Pooled holder for a copy of an async state machine. It exposes a cached <see cref="Action"/>
/// that resumes that copy.
/// </summary>
/// <remarks>
/// Each resumption runs <c>MoveNext</c> once, then clears the stored state machine and returns
/// the holder to the current thread's pool, even if <c>MoveNext</c> throws. In DEBUG builds,
/// resuming on a different thread from the one that called <see cref="Get"/> throws
/// <see cref="InvalidOperationException"/> instead of running the state machine.
/// </remarks>
/// <typeparam name="TStateMachine">Compiler-generated state machine type.</typeparam>
internal class StateMachineContinuation<TStateMachine> where TStateMachine : IAsyncStateMachine
{
    [ThreadStatic] private static Stack<StateMachineContinuation<TStateMachine>> pool;

    // Created once per holder and reused across pool round-trips, so resuming allocates nothing.
    private readonly Action continuation;

    // Deliberately not readonly: MoveNext must mutate this stored copy in place.
    private TStateMachine stateMachine;
#if DEBUG
    private System.Threading.Thread thread;
#endif

    private StateMachineContinuation() => continuation = Continue;

    /// <summary>
    /// Copies <paramref name="stateMachine"/> into a pooled holder and returns that holder's
    /// resume delegate.
    /// </summary>
    public static Action Get(ref TStateMachine stateMachine)
    {
        pool ??= new();
        var holder = pool.Count > 0 ? pool.Pop() : new StateMachineContinuation<TStateMachine>();
        holder.stateMachine = stateMachine;
#if DEBUG
        holder.thread = System.Threading.Thread.CurrentThread;
#endif
        return holder.continuation;
    }

    private void Continue()
    {
        try
        {
#if DEBUG
            if (thread != System.Threading.Thread.CurrentThread)
            {
                throw new InvalidOperationException();
            }
#endif
            stateMachine.MoveNext();
        }
        finally
        {
            Release();
        }
    }

    // Drops the state machine reference and hands this holder back to the current thread's pool.
    private void Release()
    {
        stateMachine = default;
#if DEBUG
        thread = null;
#endif
        pool ??= new();
        pool.Push(this);
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment