I have some problems with ReaderWriterLockSlim. I cannot understand how its magic works.
My code:
private async Task LoadIndex()
{
if (!File.Exists(FileName + ".index.txt"))
{
return;
}
_indexLock.EnterWriteLock();// <1>
_index.Clear();
using (TextReader index = File.OpenText(FileName + ".index.txt"))
{
string s;
while (null != (s = await index.ReadLineAsync()))
{
var ss = s.Split(':');
_index.Add(ss[0], Convert.ToInt64(ss[1]));
}
}
_indexLock.ExitWriteLock(); // <2>
}
When I enter the write lock at <1>, I can see in the debugger that _indexLock.IsWriteLockHeld is true, but when execution reaches <2>, _indexLock.IsWriteLockHeld is false
and _indexLock.ExitWriteLock throws a SynchronizationLockException with the message "The write lock is being released without being held". What am I doing wrong?
ReaderWriterLockSlim is a thread-affine lock type, so it usually cannot be used with async and await.
You should either use SemaphoreSlim with WaitAsync, or (if you really need a reader/writer lock), use my AsyncReaderWriterLock from AsyncEx or Stephen Toub's AsyncReaderWriterLock.
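As a rough illustration of the SemaphoreSlim route, here is a minimal sketch of the question's LoadIndex rewritten around a single-permit semaphore. The IndexStore class, its constructor, the Count property and the Program demo are assumptions made for the example; only the locking pattern matters.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

public sealed class IndexStore
{
    // One permit: SemaphoreSlim then behaves like an async-compatible mutex
    // and, unlike ReaderWriterLockSlim, has no thread affinity.
    private readonly SemaphoreSlim _indexLock = new SemaphoreSlim(1, 1);
    private readonly Dictionary<string, long> _index = new Dictionary<string, long>();
    private readonly string _fileName; // hypothetical stand-in for FileName in the question

    public IndexStore(string fileName) => _fileName = fileName;

    public int Count => _index.Count;

    public async Task LoadIndexAsync()
    {
        string path = _fileName + ".index.txt";
        if (!File.Exists(path))
        {
            return;
        }

        await _indexLock.WaitAsync();
        try
        {
            _index.Clear();
            using (TextReader reader = File.OpenText(path))
            {
                string line;
                while ((line = await reader.ReadLineAsync()) != null)
                {
                    string[] parts = line.Split(':');
                    _index.Add(parts[0], Convert.ToInt64(parts[1]));
                }
            }
        }
        finally
        {
            // Release may run on a different thread than WaitAsync did;
            // SemaphoreSlim does not care which thread releases it.
            _indexLock.Release();
        }
    }
}

public static class Program
{
    public static async Task Main()
    {
        string baseName = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N"));
        File.WriteAllLines(baseName + ".index.txt", new[] { "a:1", "b:2" });
        var store = new IndexStore(baseName);
        await store.LoadIndexAsync();
        Console.WriteLine(store.Count); // prints 2
        File.Delete(baseName + ".index.txt");
    }
}
```

This serializes readers as well as writers, which is often acceptable when the protected section is short.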
You can safely emulate a reader/writer locking mechanism using the reliable and lightweight SemaphoreSlim and keep the benefits of async/await. Create the SemaphoreSlim with a number of available locks equal to the number of routines that will lock your resource for reading simultaneously. Each reader requests one lock as usual. The writing routine must request all the available locks before doing its work.
That way, the writing routine always runs alone, while the reading routines share the resource only among themselves.
For example, suppose you have 2 reading routines and 1 writing routine.
SemaphoreSlim semaphore = new SemaphoreSlim(2);
async Task Reader1() // Task rather than void, so callers can await it and observe exceptions
{
await semaphore.WaitAsync();
try
{
// ... reading stuff ...
}
finally
{
semaphore.Release();
}
}
async Task Reader2() // Task rather than void, so callers can await it and observe exceptions
{
await semaphore.WaitAsync();
try
{
// ... reading other stuff ...
}
finally
{
semaphore.Release();
}
}
async Task ExclusiveWriter() // Task rather than void, so callers can await it and observe exceptions
{
// the exclusive writer must request all locks
// to make sure the readers don't have any of them
// (I wish we could specify the number of locks
// instead of spamming multiple calls!)
await semaphore.WaitAsync();
await semaphore.WaitAsync();
try
{
// ... writing stuff ...
}
finally
{
// release all locks here
semaphore.Release(2);
// (oh here we don't need multiple calls, how about that)
}
}
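Obviously this method only works if you know beforehand how many reading routines could be running at the same time. Admittedly, too many of them would make this code very ugly. It also assumes a single writer: two writers that each acquire one of the two locks would wait on each other forever.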
Some time ago I implemented an AsyncReaderWriterLock class for my project, based on two SemaphoreSlim instances. I hope it helps. It implements the same logic (multiple readers, single writer) while supporting the async/await pattern. Note that it does not support recursion and has no protection against incorrect usage:
var rwLock = new AsyncReaderWriterLock();
await rwLock.AcquireReaderLock();
try
{
// ... reading ...
}
finally
{
rwLock.ReleaseReaderLock();
}
await rwLock.AcquireWriterLock();
try
{
// ... writing ...
}
finally
{
rwLock.ReleaseWriterLock();
}
public sealed class AsyncReaderWriterLock : IDisposable
{
private readonly SemaphoreSlim _readSemaphore = new SemaphoreSlim(1, 1);
private readonly SemaphoreSlim _writeSemaphore = new SemaphoreSlim(1, 1);
private int _readerCount;
public async Task AcquireWriterLock(CancellationToken token = default)
{
await _writeSemaphore.WaitAsync(token).ConfigureAwait(false);
await SafeAcquireReadSemaphore(token).ConfigureAwait(false);
}
public void ReleaseWriterLock()
{
_readSemaphore.Release();
_writeSemaphore.Release();
}
public async Task AcquireReaderLock(CancellationToken token = default)
{
await _writeSemaphore.WaitAsync(token).ConfigureAwait(false);
if (Interlocked.Increment(ref _readerCount) == 1)
{
try
{
await SafeAcquireReadSemaphore(token).ConfigureAwait(false);
}
catch
{
Interlocked.Decrement(ref _readerCount);
throw;
}
}
_writeSemaphore.Release();
}
public void ReleaseReaderLock()
{
if (Interlocked.Decrement(ref _readerCount) == 0)
{
_readSemaphore.Release();
}
}
private async Task SafeAcquireReadSemaphore(CancellationToken token)
{
try
{
await _readSemaphore.WaitAsync(token).ConfigureAwait(false);
}
catch
{
_writeSemaphore.Release();
throw;
}
}
public void Dispose()
{
_writeSemaphore.Dispose();
_readSemaphore.Dispose();
}
}
https://learn.microsoft.com/en-us/dotnet/api/system.threading.readerwriterlockslim?view=net-5.0
From the documentation:
ReaderWriterLockSlim has managed thread affinity; that is, each Thread
object must make its own method calls to enter and exit lock modes. No
thread can change the mode of another thread.
So this is the expected behaviour. async/await does not guarantee that the continuation runs on the same thread, so you get the exception when you enter the write lock on one thread and try to exit it on another.
It is better to use one of the lock mechanisms from the other answers, such as SemaphoreSlim.
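To see the thread change for yourself, a minimal sketch for a console application (which has no SynchronizationContext) could look like the following. Which thread the continuation lands on depends on the thread pool, so no particular output is promised.

```csharp
using System;
using System.Threading.Tasks;

public static class Program
{
    public static async Task Main()
    {
        int before = Environment.CurrentManagedThreadId;

        // Task.Delay completes on a timer. With no SynchronizationContext,
        // the continuation runs on whichever thread pool thread picks it up.
        await Task.Delay(100);

        int after = Environment.CurrentManagedThreadId;
        Console.WriteLine($"before await: thread {before}, after await: thread {after}");
        Console.WriteLine(before == after
            ? "resumed on the same thread this time"
            : "resumed on a different thread");
    }
}
```

A thread-affine lock entered before the await and exited after it would therefore be released by a thread that never acquired it.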
Like Stephen Cleary says, ReaderWriterLockSlim is a thread-affine lock type, so it usually cannot be used with async and await.
You have to build a mechanism to avoid readers and writers accessing shared data at the same time. This algorithm should follow a few rules.
When requesting a readerlock:
Is there a writerlock active?
Is there anything already queued? (I think it's nice to execute it in order of requests)
When requesting a writerlock:
Is there a writerlock active? (because writerlocks shouldn't execute parallel)
Are there any reader locks active?
Is there anything already queued?
If the answer to any of these criteria is yes, the request should be queued and executed later. You could use TaskCompletionSource instances to resume the awaiting callers.
When any reader or writer is done, you should evaluate the queue and continue executing items where possible.
For example (nuget): AsyncReaderWriterLock
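The rules above can be sketched roughly as follows. This is not the NuGet package's implementation; QueuedAsyncReaderWriterLock and its member names are made up for illustration, and the sketch supports neither cancellation nor recursion.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public sealed class QueuedAsyncReaderWriterLock
{
    private readonly object _gate = new object();
    private readonly Queue<(bool IsWriter, TaskCompletionSource<bool> Tcs)> _waiters =
        new Queue<(bool IsWriter, TaskCompletionSource<bool> Tcs)>();
    private int _activeReaders;
    private bool _writerActive;

    public Task EnterReadAsync()
    {
        lock (_gate)
        {
            // A reader proceeds only if no writer is active and nothing is queued.
            if (!_writerActive && _waiters.Count == 0)
            {
                _activeReaders++;
                return Task.CompletedTask;
            }
            return Enqueue(isWriter: false);
        }
    }

    public Task EnterWriteAsync()
    {
        lock (_gate)
        {
            // A writer proceeds only if nobody holds the lock and nothing is queued.
            if (!_writerActive && _activeReaders == 0 && _waiters.Count == 0)
            {
                _writerActive = true;
                return Task.CompletedTask;
            }
            return Enqueue(isWriter: true);
        }
    }

    public void ExitRead()
    {
        lock (_gate) { _activeReaders--; Drain(); }
    }

    public void ExitWrite()
    {
        lock (_gate) { _writerActive = false; Drain(); }
    }

    private Task Enqueue(bool isWriter)
    {
        // RunContinuationsAsynchronously keeps awaiting code from running inside our lock.
        var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
        _waiters.Enqueue((isWriter, tcs));
        return tcs.Task;
    }

    // Admits queued requests in FIFO order for as long as the rules allow.
    private void Drain()
    {
        while (_waiters.Count > 0 && !_writerActive)
        {
            var next = _waiters.Peek();
            if (next.IsWriter)
            {
                if (_activeReaders > 0) return; // writer must wait for readers to leave
                _writerActive = true;
            }
            else
            {
                _activeReaders++;
            }
            _waiters.Dequeue();
            next.Tcs.SetResult(true);
        }
    }
}

public static class Program
{
    public static async Task Main()
    {
        var rw = new QueuedAsyncReaderWriterLock();
        await rw.EnterReadAsync();
        Task writer = rw.EnterWriteAsync();  // queued behind the active reader
        Task reader2 = rw.EnterReadAsync();  // queued behind the writer (request order)
        Console.WriteLine(writer.IsCompleted);  // False
        rw.ExitRead();
        await writer;
        Console.WriteLine(reader2.IsCompleted); // False
        rw.ExitWrite();
        await reader2;
        rw.ExitRead();
    }
}
```

Serving requests strictly in order keeps writers from starving, at the cost of making late readers wait behind a queued writer.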
I was inspired by @xtadex's answer. However, it seems to me that one instance of SemaphoreSlim is enough to implement the same behaviour:
internal sealed class ReadWriteSemaphore : IDisposable
{
private readonly SemaphoreSlim _semaphore;
private int _count;
public ReadWriteSemaphore()
{
_semaphore = new SemaphoreSlim(2, 2);
_count = 0;
}
public void Dispose() => _semaphore.Dispose();
public async Task<IDisposable> WaitReadAsync(CancellationToken cancellation = default)
{
await _semaphore.WaitAsync(cancellation);
if (Interlocked.Increment(ref _count) > 1)
_semaphore.Release();
return new Token(() =>
{
if (Interlocked.Decrement(ref _count) == 0)
_semaphore.Release();
});
}
public async Task<IDisposable> WaitWriteAsync(CancellationToken cancellation = default)
{
int count = 0;
try
{
while (await _semaphore.WaitAsync(Timeout.Infinite, cancellation) && ++count < 2)
continue;
}
catch (OperationCanceledException) when (count > 0)
{
_semaphore.Release(count);
throw;
}
return new Token(() => _semaphore.Release(count));
}
private sealed class Token : IDisposable
{
private readonly Action _action;
public Token(Action action)
{
_action = action;
}
public void Dispose() => _action.Invoke();
}
}
It lets you synchronize access to a resource with a using statement:
using (await _semaphore.WaitReadAsync()) // _semaphore is an instance of ReadWriteSemaphore
{
// read something
}
And for writing:
using (await _semaphore.WaitWriteAsync()) // _semaphore is an instance of ReadWriteSemaphore
{
// write something
}
I have a multi-threaded codebase that manipulates a stateful object that cannot be safely used concurrently.
However, it's hard to prevent concurrent usage, and easy to make mistakes, especially with all the various (useful!) async methods and possibilities that modern C# provides.
I can and do wrap the object and detect when simultaneous (and thus incorrect) access occurs, but that's a fairly fragile check, making it unattractive as a runtime check. When 2 or more async methods use the shared object, the check won't trigger unless they happen to do so exactly concurrently - so often that means that potentially faulty code will only rarely fault.
Worse, since the object is stateful, there exist state-transition patterns that are broken even without temporally overlapping access. Consider 2 threads both doing something like: get-value-from-object; perform-slow-computation-on-value; write-value-to-object. Any concurrency sanity checks on the object clearly won't trigger if two async contexts interleave their (read or write) accesses to the shared object such that each access happens in a short window and doesn't coincide with other accesses. Essentially, this is a form of transaction isolation violation, but in C# code rather than a DB.
I want to see such errors more reliably, i.e. have some kind of fail-fast pattern that ideally lets the programmer making the buggy change see the error right away, and not after a tricky debugging session in production.
.NET has a call-stack-like concept for async; this is clearly visible in the debugger, and there appears to be some level of API support for (or at least interaction with) it in the form of ExecutionContext, which enables things like AsyncLocal<> if I understand things correctly.
However, I have not found a way to detect when the conceptual async call stack turns into a tree; i.e. when the stateful object is accessed from 2 diverging branches of that async context. If I could detect such a tree-like state context, I could fail fast whenever an access to the stateful object is both preceded and succeeded by an access on another branch.
Is there some way to detect such async call context forking? Alternatively, how can I best fail-fast when mutable state is accessed from multiple async contexts, even when such access is not reliably temporally overlapping?
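For reference, the AsyncLocal<> behaviour the question alludes to can be sketched as below: a value set inside an async method flows into that method's continuations but does not leak back to the caller, so two branches started from the same parent each see their own value. The Branch field and ChildAsync are names invented for this illustration.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class Program
{
    private static readonly AsyncLocal<string> Branch = new AsyncLocal<string>();

    public static async Task<string> ChildAsync(string name)
    {
        Branch.Value = name;   // visible only inside this async branch
        await Task.Yield();    // force a real continuation
        return Branch.Value;   // still this branch's own value
    }

    public static async Task Main()
    {
        Branch.Value = "root";
        Task<string> a = ChildAsync("A");
        Task<string> b = ChildAsync("B");
        string[] seen = await Task.WhenAll(a, b);
        Console.WriteLine(string.Join(",", seen)); // prints A,B
        Console.WriteLine(Branch.Value);           // prints root
    }
}
```

This only shows how values flow; turning it into a reliable fork detector is the open part of the question.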
Is there some way to detect such async call context forking?
Yes
If your class has a member representing its state (e.g. private bool _isProcessing;), you can check that member to detect concurrent calls.
You need to make sure you read and write that member from a safe context, like from a lock(_someMutex) block.
Next Steps
Depending on what you want to do when you detect concurrent calls, you can implement your class in different ways.
If you want to handle concurrent calls by queuing them, you might want to just wrap the body of each of your methods in an async lock. You don't need an _isProcessing member with this approach.
private SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private async Task<IDisposable> LockAsync()
{
await _semaphore.WaitAsync();
return new ActionDisposable(() => _semaphore.Release());
}
public async Task SomeAsyncMethod1()
{
using (await LockAsync())
{
// Do some async stuff.
}
}
public async Task SomeAsyncMethod2()
{
using (await LockAsync())
{
// Do some other async stuff.
}
}
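ActionDisposable is not defined in the snippet above. A minimal sketch, assuming it simply runs a delegate once when disposed, might be:

```csharp
using System;
using System.Threading;

public sealed class ActionDisposable : IDisposable
{
    private Action _onDispose;

    public ActionDisposable(Action onDispose)
    {
        _onDispose = onDispose ?? throw new ArgumentNullException(nameof(onDispose));
    }

    public void Dispose()
    {
        // Interlocked.Exchange hands the delegate to the first caller only,
        // so a repeated Dispose cannot release the semaphore twice.
        Interlocked.Exchange(ref _onDispose, null)?.Invoke();
    }
}

public static class Program
{
    public static void Main()
    {
        int calls = 0;
        var disposable = new ActionDisposable(() => calls++);
        disposable.Dispose();
        disposable.Dispose();
        Console.WriteLine(calls); // prints 1
    }
}
```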
If you want to discard concurrent calls, this is where the _isProcessing member becomes useful.
private bool _isProcessing;
private object _mutex = new object();
public async Task SomeAsyncMethod1()
{
lock (_mutex)
{
if (_isProcessing)
{
// Concurrent call detected.
// Deal with it the way you want. (return null, throw an exception, whatever.)
throw new InvalidOperationException("Concurrent calls are not supported.");
}
_isProcessing = true;
}
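try
{
// Do some async stuff.
}
finally
{
// Reset under the same lock, and even if the work above throws.
lock (_mutex)
{
_isProcessing = false;
}
}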
}
public async Task SomeAsyncMethod2()
{
lock (_mutex)
{
if (_isProcessing)
{
// Concurrent call detected.
// Deal with it the way you want. (return null, throw an exception, whatever.)
throw new InvalidOperationException("Concurrent calls are not supported.");
}
_isProcessing = true;
}
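try
{
// Do some other async stuff.
}
finally
{
// Reset under the same lock, and even if the work above throws.
lock (_mutex)
{
_isProcessing = false;
}
}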
}
Maybe something like this can help? In your wrapper object you'd need to check AsyncLock.CanProceed, if it is false the caller didn't obtain a lock.
It would be possible to extend the code below to block while waiting for a lock.
using System;
using System.Threading;
using System.Diagnostics;
using System.Threading.Tasks;
public class Program {
static Random _rnd = new();
public static async Task Main() {
for (int i = 0; i < 9; i++) {
var t1 = TestTask("1");
var t2 = TestTask("2");
var t3 = SubTask("3");
await Task.WhenAll(t1, t2, t3);
}
}
static async Task TestTask(string id) {
try {
await Task.Delay(_rnd.Next(50,250));
using (var alock = await AsyncLock.RequestLock()) {
Console.WriteLine($"TASK {id} {alock is not null} {AsyncLock.CanProceed}");
await Task.Delay(_rnd.Next(50,250));
await SubTask(id);
}
} catch {
Console.WriteLine($"LOCK FAILED {id}");
}
}
static async Task SubTask(string id) {
try {
await Task.Delay(_rnd.Next(50,250));
if (AsyncLock.CanProceed) {
Console.WriteLine($"TASK {id} SUB CAN PROCEED");
return;
} else {
using (var alock = await AsyncLock.RequestLock()) {
Console.WriteLine($"TASK {id} SUB {alock is not null} {AsyncLock.CanProceed}");
await Task.Delay(50);
}
}
} catch {
Console.WriteLine($"LOCK FAILED {id} SUB");
}
}
}
public class AsyncLock: IDisposable {
private static AsyncLock _lockGlobal;
private static AsyncLocal<AsyncLock> _lockAsync = new();
private static object _mutex = new();
public bool IsDisposed { get; private set; }
public static Task<AsyncLock> RequestLock() => RequestLock(new CancellationTokenSource().Token);
public static Task<AsyncLock> RequestLock(CancellationToken cancellationToken) {
lock (_mutex) {
if (_lockGlobal is not null) {
Console.WriteLine("BLOCK");
var tcs = new CancellationTokenSource();
tcs.Cancel();
return Task.FromCanceled<AsyncLock>(tcs.Token);
}
Console.WriteLine("LOCK");
return Task.FromResult(_lockAsync.Value = _lockGlobal = new AsyncLock());
}
}
public static bool CanProceed => _lockAsync.Value is not null;
private AsyncLock() {}
public void Dispose() {
lock (_mutex) {
if (IsDisposed) return;
Console.WriteLine("DISPOSE");
Debug.Assert(_lockGlobal == this);
Debug.Assert(_lockAsync.Value == this);
IsDisposed = true;
_lockAsync.Value = _lockGlobal = null;
}
}
}
If I have a task running on a worker thread and when it finds something wrong, is it possible to pause and wait for the user to intervene before continuing?
For example, suppose I have something like this:
async void btnStartTask_Click(object sender, EventArgs e)
{
await Task.Run(() => LongRunningTask());
}
// CPU-bound
bool LongRunningTask()
{
// Establish some connection here.
// Do some work here.
List<Foo> incorrectValues = GetIncorrectValuesFromAbove();
if (incorrectValues.Count > 0)
{
// Here, I want to present the "incorrect values" to the user (on the UI thread)
// and let them select whether to modify a value, ignore it, or abort.
var confirmedValues = WaitForUserInput(incorrectValues);
}
// Continue processing.
}
Is it possible to substitute WaitForUserInput() with something that runs on the UI thread, waits for the user's intervention, and then acts accordingly? If so, how? I'm not looking for complete code or anything; if someone could point me in the right direction, I would be grateful.
What you're looking for is almost exactly Progress<T>, except that you want the thing that reports progress to get a task back, which it can await to inspect the result. Creating Progress<T> yourself isn't terribly hard, and you can reasonably easily adapt it so that it computes a result.
public interface IPrompt<TResult, TInput>
{
Task<TResult> Prompt(TInput input);
}
public class Prompt<TResult, TInput> : IPrompt<TResult, TInput>
{
private SynchronizationContext context;
private Func<TInput, Task<TResult>> prompt;
public Prompt(Func<TInput, Task<TResult>> prompt)
{
context = SynchronizationContext.Current ?? new SynchronizationContext();
this.prompt += prompt;
}
Task<TResult> IPrompt<TResult, TInput>.Prompt(TInput input)
{
var tcs = new TaskCompletionSource<TResult>();
context.Post(data => prompt((TInput)data)
.ContinueWith(task =>
{
if (task.IsCanceled)
tcs.TrySetCanceled();
else if (task.IsFaulted) // "else if": a canceled task must not fall through to task.Result
tcs.TrySetException(task.Exception.InnerExceptions);
else
tcs.TrySetResult(task.Result);
}), input);
return tcs.Task;
}
}
Now you simply need to have an asynchronous method that accepts the data from the long running process and returns a task with whatever the user interface's response is.
You can use TaskCompletionSource to generate a task that can be awaited within the LongRunningTask.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
namespace ConsoleApp5
{
class Program
{
private static event Action<string> Input;
public static async Task Main(string[] args)
{
var inputTask = InputTask();
var longRunningTask = Task.Run(() => LongRunningTask());
await Task.WhenAll(inputTask, longRunningTask);
}
private static async Task InputTask()
{
await Task.Yield();
while(true)
{
var input = await Console.In.ReadLineAsync();
Input?.Invoke(input);
}
}
static async Task<bool> LongRunningTask()
{
SomeExpensiveCall();
var incorrectValues = GetIncorrectValuesFromAbove();
if (incorrectValues.Count > 0)
{
var confirmedValues = await WaitForUserInput(incorrectValues).ConfigureAwait(false);
}
// Continue processing.
return true;
}
private static void SomeExpensiveCall()
{
}
private static Task<string> WaitForUserInput(IList<string> incorrectValues)
{
var taskCompletionSource = new TaskCompletionSource<string>();
Console.Write("Input Data: ");
try
{
void EventHandler(string input)
{
Input -= EventHandler;
taskCompletionSource.TrySetResult(input);
}
Input += EventHandler;
}
catch(Exception e)
{
taskCompletionSource.TrySetException(e);
}
return taskCompletionSource.Task;
}
private static IList<string> GetIncorrectValuesFromAbove()
{
return new List<string> { "Test" };
}
}
}
Of course in this example you could have just called await Console.In.ReadLineAsync() directly, but this code is to simulate an environment where you only have an event based API.
There are several ways to solve this problem, with the Control.Invoke being probably the most familiar. Here is a more TPL-ish approach. You start by declaring a UI related scheduler as a class field:
private TaskScheduler _uiScheduler;
Then initialize it:
public MyForm()
{
InitializeComponent();
_uiScheduler = TaskScheduler.FromCurrentSynchronizationContext();
}
Then you convert your synchronous LongRunning method to an asynchronous method. This means that it must return Task<bool> instead of bool. It must also have the async modifier, and by convention be named with the Async suffix:
async Task<bool> LongRunningAsync()
Finally you use the await operator in order to wait for the user's input, which will be a Task configured to run on the captured UI scheduler:
async Task<bool> LongRunningAsync()
{
// Establish some connection here.
// Do some work here.
List<Foo> incorrectValues = GetIncorrectValuesFromAbove();
if (incorrectValues.Count > 0)
{
// Here, I want to present the "incorrect values" to the user (on the UI thread)
// and let them select whether to modify a value, ignore it, or abort.
var confirmedValues = await Task.Factory.StartNew(() =>
{
return WaitForUserInput(incorrectValues);
}, default, TaskCreationOptions.None, _uiScheduler);
}
// Continue processing.
}
Starting the long-running task is the same as before. Task.Run understands async delegates, so you don't have to do anything special after making the method async.
var longRunningTask = Task.Run(() => LongRunningAsync());
This should be enough, provided that you just intend to show a dialog box to the user. Form.ShowDialog is a blocking method, so the WaitForUserInput method need not be asynchronous. If you had to allow the user to interact freely with the main form, the problem would be much more difficult to solve.
Another example using Invoke() and a ManualResetEvent. Let me know if you need help with the form code; setting up a constructor, using DialogResult, or creating a property to hold the "confirmedValues":
bool LongRunningTask()
{
// Establish some connection here.
// Do some work here.
List<Foo> incorrectValues = GetIncorrectValuesFromAbove();
List<Foo> confirmedValues = null; // "var" needs an initializer; List<Foo> is an assumed type
if (incorrectValues.Count > 0)
{
DialogResult result;
ManualResetEvent mre = new ManualResetEvent(false);
this.Invoke((MethodInvoker)delegate
{
// pass in incorrectValues to the form
// you'll have to build a constructor in it to accept them
frmSomeForm frm = new frmSomeForm(incorrectValues);
result = frm.ShowDialog();
if (result == DialogResult.OK)
{
confirmedValues = frm.confirmedValues; // get the confirmed values somehow
}
mre.Set(); // release the block below
});
mre.WaitOne(); // blocks until "mre" is set
}
// Continue processing.
}
I've got this pattern for preventing calls into an async method before its previous invocation has had a chance to complete.
My solution, which needs a flag and then a lock around the flag, feels pretty verbose. Is there a more natural way of achieving this?
public class MyClass
{
private object SyncIsFooRunning = new object();
private bool IsFooRunning { get; set;}
public async Task FooAsync()
{
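// Check the flag before entering the try block; otherwise the finally
// below would clear it even when this call returned early.
lock(SyncIsFooRunning)
{
if(IsFooRunning)
return;
IsFooRunning = true;
}
try
{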
// Use a semaphore to enforce maximum number of Tasks which are able to run concurrently.
var semaphoreSlim = new SemaphoreSlim(5);
var trackedTasks = new List<Task>();
for(int i = 0; i < 100; i++)
{
await semaphoreSlim.WaitAsync();
trackedTasks.Add(Task.Run(() =>
{
// DoTask();
semaphoreSlim.Release();
}));
}
// Using await makes try/catch/finally possible.
await Task.WhenAll(trackedTasks);
}
finally
{
lock(SyncIsFooRunning)
{
IsFooRunning = false;
}
}
}
}
As noted in the comments, you can use Interlocked.CompareExchange() if you prefer:
public class MyClass
{
private int _flag;
public async Task FooAsync()
{
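// Check before entering the try block; otherwise the finally below would
// reset the flag even when this call lost the race and returned early.
if (Interlocked.CompareExchange(ref _flag, 1, 0) == 1)
{
return;
}
try
{
// do stuff
}
finally
{
Interlocked.Exchange(ref _flag, 0);
}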
}
}
That said, I think it's overkill. Nothing wrong with using lock in this type of scenario, especially if you don't expect a lot of contention on the method. What I do think would be better is to wrap the method so that the caller can always await on the result, whether a new asynchronous operation was started or not:
public class MyClass
{
private readonly object _lock = new object();
private Task _task;
public Task FooAsync()
{
lock (_lock)
{
return _task != null ? _task : (_task = FooAsyncImpl());
}
}
public async Task FooAsyncImpl()
{
try
{
// do async stuff
}
finally
{
lock (_lock) _task = null;
}
}
}
Finally, in the comments, you say this:
Seems a bit odd that all the return types are still valid for Task?
Not clear to me what you mean by that. In your method, the only valid return types would be void and Task. If your return statement(s) returned an actual value, you'd have to use Task<T> where T is the type returned by the return statement(s).
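A short sketch of the distinction, with method names invented for the example:

```csharp
using System;
using System.Threading.Tasks;

public static class Program
{
    // No value is returned: the return type is Task, and a bare "return;" is allowed.
    public static async Task DoWorkAsync(bool skip)
    {
        if (skip)
        {
            return;
        }
        await Task.Delay(10);
    }

    // A value is returned: the return type must be Task<T>, here Task<int>.
    public static async Task<int> ComputeAsync()
    {
        await Task.Delay(10);
        return 42;
    }

    public static async Task Main()
    {
        await DoWorkAsync(skip: true);
        int value = await ComputeAsync();
        Console.WriteLine(value); // prints 42
    }
}
```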
I would like to have a custom thread pool satisfying the following requirements:
Real threads are preallocated according to the pool capacity. The actual work is free to use the standard .NET thread pool, if needed to spawn concurrent tasks.
The pool must be able to return the number of idle threads. The returned number may be less than the actual number of the idle threads, but it must not be greater. Of course, the more accurate the number the better.
Queuing work to the pool should return a corresponding Task, which should play nicely with the Task-based API.
NEW The max job capacity (or degree of parallelism) should be adjustable dynamically. Trying to reduce the capacity does not have to take effect immediately, but increasing it should do so immediately.
The rationale for the first item is depicted below:
The machine is not supposed to be running more than N work items concurrently, where N is relatively small - between 10 and 30.
The work is fetched from the database and if K items are fetched then we want to make sure that there are K idle threads to start the work right away. A situation where work is fetched from the database, but remains waiting for the next available thread is unacceptable.
The last item also explains the reason for having the idle thread count - I am going to fetch that many work items from the database. It also explains why the reported idle thread count must never be higher than the actual one - otherwise I might fetch more work than can be immediately started.
Anyway, here is my implementation along with a small program to test it (BJE stands for Background Job Engine):
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
public class BJEThreadPool
{
private sealed class InternalTaskScheduler : TaskScheduler
{
private int m_idleThreadCount;
private readonly BlockingCollection<Task> m_bus;
public InternalTaskScheduler(int threadCount, BlockingCollection<Task> bus)
{
m_idleThreadCount = threadCount;
m_bus = bus;
}
public void RunInline(Task task)
{
Interlocked.Decrement(ref m_idleThreadCount);
try
{
TryExecuteTask(task);
}
catch
{
// The action is responsible itself for the error handling, for the time being...
}
Interlocked.Increment(ref m_idleThreadCount);
}
public int IdleThreadCount
{
get { return m_idleThreadCount; }
}
#region Overrides of TaskScheduler
protected override void QueueTask(Task task)
{
m_bus.Add(task);
}
protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
{
return TryExecuteTask(task);
}
protected override IEnumerable<Task> GetScheduledTasks()
{
throw new NotSupportedException();
}
#endregion
public void DecrementIdleThreadCount()
{
Interlocked.Decrement(ref m_idleThreadCount);
}
}
private class ThreadContext
{
private readonly InternalTaskScheduler m_ts;
private readonly BlockingCollection<Task> m_bus;
private readonly CancellationTokenSource m_cts;
public readonly Thread Thread;
public ThreadContext(string name, InternalTaskScheduler ts, BlockingCollection<Task> bus, CancellationTokenSource cts)
{
m_ts = ts;
m_bus = bus;
m_cts = cts;
Thread = new Thread(Start)
{
IsBackground = true,
Name = name
};
Thread.Start();
}
private void Start()
{
try
{
foreach (var task in m_bus.GetConsumingEnumerable(m_cts.Token))
{
m_ts.RunInline(task);
}
}
catch (OperationCanceledException)
{
}
m_ts.DecrementIdleThreadCount();
}
}
private readonly InternalTaskScheduler m_ts;
private readonly CancellationTokenSource m_cts = new CancellationTokenSource();
private readonly BlockingCollection<Task> m_bus = new BlockingCollection<Task>();
private readonly List<ThreadContext> m_threadCtxs = new List<ThreadContext>();
public BJEThreadPool(int threadCount)
{
m_ts = new InternalTaskScheduler(threadCount, m_bus);
for (int i = 0; i < threadCount; ++i)
{
m_threadCtxs.Add(new ThreadContext("BJE Thread " + i, m_ts, m_bus, m_cts));
}
}
public void Terminate()
{
m_cts.Cancel();
foreach (var t in m_threadCtxs)
{
t.Thread.Join();
}
}
public Task Run(Action<CancellationToken> action)
{
return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
}
public Task Run(Action action)
{
return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
}
public int IdleThreadCount
{
get { return m_ts.IdleThreadCount; }
}
}
class Program
{
static void Main()
{
const int THREAD_COUNT = 32;
var pool = new BJEThreadPool(THREAD_COUNT);
var tcs = new TaskCompletionSource<bool>();
var tasks = new List<Task>();
var allRunning = new CountdownEvent(THREAD_COUNT);
for (int i = pool.IdleThreadCount; i > 0; --i)
{
var index = i;
tasks.Add(pool.Run(cancellationToken =>
{
Console.WriteLine("Started action " + index);
allRunning.Signal();
tcs.Task.Wait(cancellationToken);
Console.WriteLine(" Ended action " + index);
}));
}
Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
allRunning.Wait();
Debug.Assert(pool.IdleThreadCount == 0);
int expectedIdleThreadCount = THREAD_COUNT;
Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
switch (Console.ReadKey().KeyChar)
{
case 'c':
Console.WriteLine("Cancel All");
tcs.TrySetCanceled();
break;
case 'e':
Console.WriteLine("Error All");
tcs.TrySetException(new Exception("Failed"));
break;
case 'a':
Console.WriteLine("Abort All");
pool.Terminate();
expectedIdleThreadCount = 0;
break;
default:
Console.WriteLine("Done All");
tcs.TrySetResult(true);
break;
}
try
{
Task.WaitAll(tasks.ToArray());
}
catch (AggregateException exc)
{
Console.WriteLine(exc.Flatten().InnerException.Message);
}
Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
pool.Terminate();
Console.WriteLine("Press any key");
Console.ReadKey();
}
}
}
It is a very simple implementation and it appears to be working. However, there is a problem - the BJEThreadPool.Run method does not accept asynchronous methods. I.e. my implementation does not allow me to add the following overloads:
public Task Run(Func<CancellationToken, Task> action)
{
return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}
public Task Run(Func<Task> action)
{
return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}
The pattern I use in InternalTaskScheduler.RunInline does not work in this case.
So, my question is how to add the support for asynchronous work items? I am fine with changing the entire design as long as the requirements outlined at the beginning of the post are upheld.
EDIT
I would like to clarify the intended usage of the desired pool. Please observe the following code:
if (pool.IdleThreadCount == 0)
{
return;
}
foreach (var jobData in FetchFromDB(pool.IdleThreadCount))
{
pool.Run(CreateJobAction(jobData));
}
Notes:
The code is going to be run periodically, say every 1 minute.
The code is going to be run concurrently by multiple machines watching the same database.
FetchFromDB is going to use the technique described in Using SQL Server as a DB queue with multiple clients to atomically fetch and lock the work from the DB.
CreateJobAction is going to invoke the code denoted by jobData (the job code) and close the work upon the completion of that code. The job code is out of my control and it could be pretty much anything - heavy CPU bound code or light asynchronous IO bound code, badly written synchronous IO bound code or a mix of all. It could run for minutes and it could run for hours. Closing the work is my code and it would be asynchronous IO bound code. Because of this, the signature of the returned job action is that of an asynchronous method.
Item 2 underlines the importance of correctly identifying the number of idle threads. If there are 900 pending work items and 10 agent machines I cannot allow an agent to fetch 300 work items and queue them on the thread pool. Why? Because it is most unlikely that the agent will be able to run 300 work items concurrently. It will run some, sure enough, but others will be waiting in the thread pool work queue. Suppose it will run 100 and let 200 wait (even though 100 is probably far-fetched). This yields 3 fully loaded agents and 7 idle ones. But only 300 work items out of 900 are actually being processed concurrently!
My goal is to maximize the spread of the work amongst the available agents. Ideally, I should evaluate the load of an agent and the "heaviness" of the pending work, but it is a formidable task and is reserved for the future versions. Right now, I wish to assign each agent the max job capacity with the intention to provide the means to increase/decrease it dynamically without restarting the agents.
Next observation. The work can take quite a long time to run and it could be all synchronous code. As far as I understand it is undesirable to utilize thread pool threads for such kind of work.
EDIT 2
There is a statement that TaskScheduler is only for CPU bound work. But what if I do not know the nature of the work? I mean it is a general purpose Background Job Engine and it runs thousands of different kinds of jobs. I do not have the means to tell that one job is CPU bound, another is synchronous IO bound, and yet another is asynchronous IO bound. I wish I could, but I cannot.
EDIT 3
In the end, I do not use the SemaphoreSlim, but neither do I use the TaskScheduler - it finally trickled down my thick skull that the latter is inappropriate and plain wrong, plus it makes the code overly complex.
Still, I failed to see how SemaphoreSlim is the way. The proposed pattern:
public async Task Enqueue(Func<Task> taskGenerator)
{
await semaphore.WaitAsync();
try
{
await taskGenerator();
}
finally
{
semaphore.Release();
}
}
This pattern expects taskGenerator to be either asynchronous IO-bound code or, otherwise, code that starts a new thread. However, I have no means to determine whether the work to be executed is one or the other. Plus, as I have learned from "SemaphoreSlim.WaitAsync continuation code", if the semaphore is unlocked, the code following WaitAsync() is going to run on the same thread, which is not very good for me.
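For what it's worth, both concerns could in principle be sidestepped by starting the job through Task.Run inside the throttled section, so that the delegate always begins on a thread-pool thread whether it turns out to be synchronous or asynchronous. The sketch below only illustrates that idea; OffloadingThrottle is a hypothetical name, and long synchronous jobs would still occupy thread-pool threads, which is the trade-off mentioned earlier.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper for illustration; not part of the implementation below.
public class OffloadingThrottle
{
    private readonly SemaphoreSlim _semaphore;

    public OffloadingThrottle(int maxConcurrency)
    {
        _semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
    }

    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await _semaphore.WaitAsync().ConfigureAwait(false);
        try
        {
            // Task.Run(Func<Task>) invokes the delegate on a thread-pool thread
            // and unwraps the task it returns, so even a fully synchronous
            // delegate never runs on the caller's thread.
            await Task.Run(taskGenerator).ConfigureAwait(false);
        }
        finally
        {
            _semaphore.Release();
        }
    }
}
```

The caller would use it exactly like the proposed Enqueue, e.g. await throttle.Enqueue(() => SomeJobAsync()), where SomeJobAsync stands for any job delegate.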
Anyway, below is my implementation, in case anyone fancies. Unfortunately, I am yet to understand how to reduce the pool thread count dynamically, but this is a topic for another question.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
public interface IBJEThreadPool
{
void SetThreadCount(int threadCount);
void Terminate();
Task Run(Action action);
Task Run(Action<CancellationToken> action);
Task Run(Func<Task> action);
Task Run(Func<CancellationToken, Task> action);
int IdleThreadCount { get; }
}
public class BJEThreadPool : IBJEThreadPool
{
private interface IActionContext
{
Task Run(CancellationToken ct);
TaskCompletionSource<object> TaskCompletionSource { get; }
}
private class ActionContext : IActionContext
{
private readonly Action m_action;
public ActionContext(Action action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
m_action();
return null;
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class CancellableActionContext : IActionContext
{
private readonly Action<CancellationToken> m_action;
public CancellableActionContext(Action<CancellationToken> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
m_action(ct);
return null;
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class AsyncActionContext : IActionContext
{
private readonly Func<Task> m_action;
public AsyncActionContext(Func<Task> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
return m_action();
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class AsyncCancellableActionContext : IActionContext
{
private readonly Func<CancellationToken, Task> m_action;
public AsyncCancellableActionContext(Func<CancellationToken, Task> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
return m_action(ct);
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private readonly CancellationTokenSource m_ctsTerminateAll = new CancellationTokenSource();
private readonly BlockingCollection<IActionContext> m_bus = new BlockingCollection<IActionContext>();
private readonly LinkedList<Thread> m_threads = new LinkedList<Thread>();
private int m_idleThreadCount;
private static int s_threadCount;
public BJEThreadPool(int threadCount)
{
ReserveAdditionalThreads(threadCount);
}
private void ReserveAdditionalThreads(int n)
{
for (int i = 0; i < n; ++i)
{
var index = Interlocked.Increment(ref s_threadCount) - 1;
var t = new Thread(Start)
{
IsBackground = true,
Name = "BJE Thread " + index
};
Interlocked.Increment(ref m_idleThreadCount);
t.Start();
m_threads.AddLast(t);
}
}
private void Start()
{
try
{
foreach (var actionContext in m_bus.GetConsumingEnumerable(m_ctsTerminateAll.Token))
{
RunWork(actionContext).Wait();
}
}
catch (OperationCanceledException)
{
}
catch
{
// Should never happen - log the error
}
Interlocked.Decrement(ref m_idleThreadCount);
}
private async Task RunWork(IActionContext actionContext)
{
Interlocked.Decrement(ref m_idleThreadCount);
try
{
var task = actionContext.Run(m_ctsTerminateAll.Token);
if (task != null)
{
await task;
}
actionContext.TaskCompletionSource.SetResult(null);
}
catch (OperationCanceledException)
{
actionContext.TaskCompletionSource.TrySetCanceled();
}
catch (Exception exc)
{
actionContext.TaskCompletionSource.TrySetException(exc);
}
Interlocked.Increment(ref m_idleThreadCount);
}
private Task PostWork(IActionContext actionContext)
{
m_bus.Add(actionContext);
return actionContext.TaskCompletionSource.Task;
}
#region Implementation of IBJEThreadPool
public void SetThreadCount(int threadCount)
{
if (threadCount > m_threads.Count)
{
ReserveAdditionalThreads(threadCount - m_threads.Count);
}
else if (threadCount < m_threads.Count)
{
throw new NotSupportedException();
}
}
public void Terminate()
{
m_ctsTerminateAll.Cancel();
foreach (var t in m_threads)
{
t.Join();
}
}
public Task Run(Action action)
{
return PostWork(new ActionContext(action));
}
public Task Run(Action<CancellationToken> action)
{
return PostWork(new CancellableActionContext(action));
}
public Task Run(Func<Task> action)
{
return PostWork(new AsyncActionContext(action));
}
public Task Run(Func<CancellationToken, Task> action)
{
return PostWork(new AsyncCancellableActionContext(action));
}
public int IdleThreadCount
{
get { return m_idleThreadCount; }
}
#endregion
}
public static class Extensions
{
public static Task WithCancellation(this Task task, CancellationToken token)
{
return task.ContinueWith(t => t.GetAwaiter().GetResult(), token);
}
}
class Program
{
static void Main()
{
const int THREAD_COUNT = 16;
var pool = new BJEThreadPool(THREAD_COUNT);
var tcs = new TaskCompletionSource<bool>();
var tasks = new List<Task>();
var allRunning = new CountdownEvent(THREAD_COUNT);
for (int i = pool.IdleThreadCount; i > 0; --i)
{
var index = i;
tasks.Add(pool.Run(async ct =>
{
Console.WriteLine("Started action " + index);
allRunning.Signal();
await tcs.Task.WithCancellation(ct);
Console.WriteLine(" Ended action " + index);
}));
}
Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
allRunning.Wait();
Debug.Assert(pool.IdleThreadCount == 0);
int expectedIdleThreadCount = THREAD_COUNT;
Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
switch (Console.ReadKey().KeyChar)
{
case 'c':
Console.WriteLine("ancel All");
tcs.TrySetCanceled();
break;
case 'e':
Console.WriteLine("rror All");
tcs.TrySetException(new Exception("Failed"));
break;
case 'a':
Console.WriteLine("bort All");
pool.Terminate();
expectedIdleThreadCount = 0;
break;
default:
Console.WriteLine("Done All");
tcs.TrySetResult(true);
break;
}
try
{
Task.WaitAll(tasks.ToArray());
}
catch (AggregateException exc)
{
Console.WriteLine(exc.Flatten().InnerException.Message);
}
Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
pool.Terminate();
Console.WriteLine("Press any key");
Console.ReadKey();
}
}
}
Asynchronous "work items" are often based on async IO. Async IO does not use threads while it runs. Task schedulers are used to execute CPU work (tasks based on a delegate), so the concept of a TaskScheduler does not apply here. You cannot use a custom TaskScheduler to influence what async code does.
Make your work items throttle themselves:
static SemaphoreSlim sem = new SemaphoreSlim(maxDegreeOfParallelism); //shared object
async Task MyWorkerFunction()
{
await sem.WaitAsync();
try
{
MyWork();
}
finally
{
sem.Release();
}
}
As mentioned in another answer by usr you can't do this with a TaskScheduler as that is only for CPU bound work, not limiting the level of parallelization of all types of work, whether parallel or not. He also shows you how you can use a SemaphoreSlim to asynchronously limit the degrees of parallelism.
You can expand on this to generalize these concepts in a few ways. The one that seems like it would be the most beneficial to you would be to create a special type of queue that takes operations that return a Task and executes them in such a way that a given max degree of parallelization is achieved.
public class FixedParallelismQueue
{
private SemaphoreSlim semaphore;
public FixedParallelismQueue(int maxDegreesOfParallelism)
{
semaphore = new SemaphoreSlim(maxDegreesOfParallelism,
maxDegreesOfParallelism);
}
public async Task<T> Enqueue<T>(Func<Task<T>> taskGenerator)
{
await semaphore.WaitAsync();
try
{
return await taskGenerator();
}
finally
{
semaphore.Release();
}
}
public async Task Enqueue(Func<Task> taskGenerator)
{
await semaphore.WaitAsync();
try
{
await taskGenerator();
}
finally
{
semaphore.Release();
}
}
}
This allows you to create a queue for your application (you can even have several separate queues if you want) that has a fixed degree of parallelization. You can then provide operations returning a Task when they complete and the queue will schedule it when it can and return a Task representing when that unit of work has finished.
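As a rough usage sketch, the following pushes a batch of simulated jobs through such a queue and records how many were ever running at the same time. QueueDemo, MeasurePeakAsync and the Task.Delay stand-in for real work are assumptions made for illustration; the queue is repeated in compact form only so that the snippet compiles on its own.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Compact copy of the queue above, so this sketch is self-contained.
public class FixedParallelismQueue
{
    private readonly SemaphoreSlim _semaphore;

    public FixedParallelismQueue(int maxDegreesOfParallelism)
    {
        _semaphore = new SemaphoreSlim(maxDegreesOfParallelism, maxDegreesOfParallelism);
    }

    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await _semaphore.WaitAsync();
        try { await taskGenerator(); }
        finally { _semaphore.Release(); }
    }
}

// Hypothetical demo class, not part of the answer's code.
public static class QueueDemo
{
    // Runs jobCount simulated jobs and returns the highest number
    // of jobs observed running at the same time.
    public static async Task<int> MeasurePeakAsync(int maxParallelism, int jobCount)
    {
        var queue = new FixedParallelismQueue(maxParallelism);
        var gate = new object();
        int running = 0, peak = 0;

        Func<Task> job = async () =>
        {
            lock (gate) { running++; if (running > peak) peak = running; }
            await Task.Delay(50); // stand-in for real asynchronous work
            lock (gate) { running--; }
        };

        // Every job is enqueued right away; the semaphore decides when each one starts.
        await Task.WhenAll(Enumerable.Range(0, jobCount).Select(_ => queue.Enqueue(job)));
        return peak;
    }
}
```

Calling await QueueDemo.MeasurePeakAsync(3, 10) should never report a peak above 3, since at most three callers can hold the semaphore at once.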
I use C# iterators as a replacement for coroutines, and it has been working great. I want to switch to async/await as I think the syntax is cleaner and it gives me type safety.
In this (outdated) blog post, Jon Skeet shows a possible way to implement it.
I chose to go a slightly different way (by implementing my own SynchronizationContext and using Task.Yield). This worked fine.
Then I realized there would be a problem; currently a coroutine doesn't have to finish running. It can be stopped gracefully at any point where it yields. We might have code like this:
private IEnumerator Sleep(int milliseconds)
{
Stopwatch timer = Stopwatch.StartNew();
do
{
yield return null;
}
while (timer.ElapsedMilliseconds < milliseconds);
}
private IEnumerator CoroutineMain()
{
try
{
// Do something that runs over several frames
yield return Coroutine.Sleep(5000);
}
finally
{
Log("Coroutine finished, either after 5 seconds, or because it was stopped");
}
}
The coroutine works by keeping track of all enumerators in a stack. The C# compiler generates a Dispose function which can be called to ensure that the 'finally' block is correctly invoked in CoroutineMain, even if the enumeration isn't finished. This way we can stop a coroutine gracefully, and still ensure finally blocks are invoked, by calling Dispose on all the IEnumerator objects on the stack. This is basically manually unwinding.
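To make the Dispose behavior concrete, here is a minimal standalone sketch. The IteratorDisposeDemo class and its Log list are made up for illustration and are not part of our coroutine code. The iterator is advanced once, so it is suspended inside the try block, and is then disposed without ever finishing.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

// Hypothetical demo class for illustration only.
public static class IteratorDisposeDemo
{
    public static readonly List<string> Log = new List<string>();

    private static IEnumerator Routine()
    {
        try
        {
            Log.Add("step 1");
            yield return null;
            Log.Add("step 2"); // never reached in this demo
            yield return null;
        }
        finally
        {
            Log.Add("finally");
        }
    }

    public static void Run()
    {
        IEnumerator routine = Routine();
        routine.MoveNext();               // runs up to the first yield, inside the try block
        ((IDisposable)routine).Dispose(); // abandons the iterator; the finally block runs here
    }
}
```

After IteratorDisposeDemo.Run(), Log contains "step 1" followed by "finally", which is the manual unwinding described above.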
When I wrote my implementation with async/await I realized that we would lose this feature, unless I'm mistaken. I then looked up other coroutine solutions, and it doesn't look like Jon Skeet's version handles it in any way either.
The only way I can think of to handle this would be to have our own custom 'Yield' function, which would check if the coroutine was stopped, and then raise an exception that indicated this. This would propagate up, executing finally blocks, and then be caught somewhere near the root. I don't find this pretty though, as 3rd party code could potentially catch the exception.
Am I misunderstanding something, and is this possible to do in an easier way? Or do I need to go the exception way to do this?
EDIT: More information/code has been requested, so here's some. I can guarantee this is going to be running on only a single thread, so there's no threading involved here.
Our current coroutine implementation looks a bit like this (this is simplified, but it works in this simple case):
public sealed class Coroutine : IDisposable
{
private class RoutineState
{
public RoutineState(IEnumerator enumerator)
{
Enumerator = enumerator;
}
public IEnumerator Enumerator { get; private set; }
}
private readonly Stack<RoutineState> _enumStack = new Stack<RoutineState>();
public Coroutine(IEnumerator enumerator)
{
_enumStack.Push(new RoutineState(enumerator));
}
public bool IsDisposed { get; private set; }
public void Dispose()
{
if (IsDisposed)
return;
while (_enumStack.Count > 0)
{
DisposeEnumerator(_enumStack.Pop().Enumerator);
}
IsDisposed = true;
}
public bool Resume()
{
while (true)
{
RoutineState top = _enumStack.Peek();
bool movedNext;
try
{
movedNext = top.Enumerator.MoveNext();
}
catch (Exception ex)
{
// Handle exception thrown by coroutine
throw;
}
if (!movedNext)
{
// We finished this (sub-)routine, so remove it from the stack
_enumStack.Pop();
// Clean up..
DisposeEnumerator(top.Enumerator);
if (_enumStack.Count <= 0)
{
// This was the outer routine, so coroutine is finished.
return false;
}
// Go back and execute the parent.
continue;
}
// We executed a step in this coroutine. Check if a subroutine is supposed to run..
object value = top.Enumerator.Current;
IEnumerator newEnum = value as IEnumerator;
if (newEnum != null)
{
// Our current enumerator yielded a new enumerator, which is a subroutine.
// Push our new subroutine and run the first iteration immediately
RoutineState newState = new RoutineState(newEnum);
_enumStack.Push(newState);
continue;
}
// An actual result was yielded, so we've completed an iteration/step.
return true;
}
}
private static void DisposeEnumerator(IEnumerator enumerator)
{
IDisposable disposable = enumerator as IDisposable;
if (disposable != null)
disposable.Dispose();
}
}
Assume we have code like the following:
private IEnumerator MoveToPlayer()
{
try
{
while (!AtPlayer())
{
yield return Sleep(500); // Move towards player twice every second
CalculatePosition();
}
}
finally
{
Log("MoveTo Finally");
}
}
private IEnumerator OrbLogic()
{
try
{
yield return MoveToPlayer();
yield return MakeExplosion();
}
finally
{
Log("OrbLogic Finally");
}
}
This would be created by passing an instance of the OrbLogic enumerator to a Coroutine, and then running it. This allows us to tick the coroutine every frame. If the player kills the orb, the coroutine doesn't finish running; Dispose is simply called on the coroutine. If MoveToPlayer was logically inside its 'try' block at that point, then calling Dispose on the top IEnumerator will, semantically, make the finally block in MoveToPlayer execute. Afterwards, the finally block in OrbLogic will execute.
Note that this is a simple case; the real cases are much more complex.
I am struggling to implement similar behavior in the async/await version. The code for this version looks like this (error checking omitted):
public class Coroutine
{
private readonly CoroutineSynchronizationContext _syncContext = new CoroutineSynchronizationContext();
public Coroutine(Action action)
{
if (action == null)
throw new ArgumentNullException("action");
_syncContext.Next = new CoroutineSynchronizationContext.Continuation(state => action(), null);
}
public bool IsFinished { get { return !_syncContext.Next.HasValue; } }
public void Tick()
{
if (IsFinished)
throw new InvalidOperationException("Cannot resume Coroutine that has finished");
SynchronizationContext curContext = SynchronizationContext.Current;
try
{
SynchronizationContext.SetSynchronizationContext(_syncContext);
// Next is guaranteed to have value because of the IsFinished check
Debug.Assert(_syncContext.Next.HasValue);
// Invoke next continuation
var next = _syncContext.Next.Value;
_syncContext.Next = null;
next.Invoke();
}
finally
{
SynchronizationContext.SetSynchronizationContext(curContext);
}
}
}
public class CoroutineSynchronizationContext : SynchronizationContext
{
internal struct Continuation
{
public Continuation(SendOrPostCallback callback, object state)
{
Callback = callback;
State = state;
}
public SendOrPostCallback Callback;
public object State;
public void Invoke()
{
Callback(State);
}
}
internal Continuation? Next { get; set; }
public override void Post(SendOrPostCallback callback, object state)
{
if (callback == null)
throw new ArgumentNullException("callback");
if (Current != this)
throw new InvalidOperationException("Cannot Post to CoroutineSynchronizationContext from different thread!");
Next = new Continuation(callback, state);
}
public override void Send(SendOrPostCallback d, object state)
{
throw new NotSupportedException();
}
public override int Wait(IntPtr[] waitHandles, bool waitAll, int millisecondsTimeout)
{
throw new NotSupportedException();
}
public override SynchronizationContext CreateCopy()
{
throw new NotSupportedException();
}
}
I don't see how to implement similar behavior to the iterator version using this.
Apologies in advance for the lengthy code!
EDIT 2: The new method seems to be working. It allows me to do stuff like:
private static async Task Test()
{
// Second resume
await Sleep(1000);
// Unknown how many resumes
}
private static async Task Main()
{
// First resume
await Coroutine.Yield();
// Second resume
await Test();
}
Which provides a very nice way of building AI for games.
Updated, a follow-up blog post:
Asynchronous coroutines with C# 8.0 and IAsyncEnumerable.
I use C# iterators as a replacement for coroutines, and it has been
working great. I want to switch to async/await as I think the syntax
is cleaner and it gives me type safety...
IMO, it's a very interesting question, although it took me a while to fully understand it. Perhaps you didn't provide enough sample code to illustrate the concept. A complete app would help, so I'll try to fill this gap first. The following code illustrates the usage pattern as I understood it; please correct me if I'm wrong:
using System;
using System.Collections;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace ConsoleApplication
{
// https://stackoverflow.com/q/22852251/1768303
public class Program
{
class Resource : IDisposable
{
public void Dispose()
{
Console.WriteLine("Resource.Dispose");
}
~Resource()
{
Console.WriteLine("~Resource");
}
}
private IEnumerator Sleep(int milliseconds)
{
using (var resource = new Resource())
{
Stopwatch timer = Stopwatch.StartNew();
do
{
yield return null;
}
while (timer.ElapsedMilliseconds < milliseconds);
}
}
void EnumeratorTest()
{
var enumerator = Sleep(100);
enumerator.MoveNext();
Thread.Sleep(500);
//while (enumerator.MoveNext());
((IDisposable)enumerator).Dispose();
}
public static void Main(string[] args)
{
new Program().EnumeratorTest();
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
GC.WaitForPendingFinalizers();
Console.ReadLine();
}
}
}
Here, Resource.Dispose gets called because of ((IDisposable)enumerator).Dispose(). If we don't call enumerator.Dispose(), then we'll have to uncomment the commented-out while loop over MoveNext() and let the iterator finish gracefully, for proper unwinding.
Now, I think the best way to implement this with async/await is to use a custom awaiter:
using System;
using System.Collections;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace ConsoleApplication
{
// https://stackoverflow.com/q/22852251/1768303
public class Program
{
class Resource : IDisposable
{
public void Dispose()
{
Console.WriteLine("Resource.Dispose");
}
~Resource()
{
Console.WriteLine("~Resource");
}
}
async Task SleepAsync(int milliseconds, Awaiter awaiter)
{
using (var resource = new Resource())
{
Stopwatch timer = Stopwatch.StartNew();
do
{
await awaiter;
}
while (timer.ElapsedMilliseconds < milliseconds);
}
Console.WriteLine("Exit SleepAsync");
}
void AwaiterTest()
{
var awaiter = new Awaiter();
var task = SleepAsync(100, awaiter);
awaiter.MoveNext();
Thread.Sleep(500);
//while (awaiter.MoveNext()) ;
awaiter.Dispose();
task.Dispose();
}
public static void Main(string[] args)
{
new Program().AwaiterTest();
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
GC.WaitForPendingFinalizers();
Console.ReadLine();
}
// custom awaiter
public class Awaiter :
System.Runtime.CompilerServices.INotifyCompletion,
IDisposable
{
Action _continuation;
readonly CancellationTokenSource _cts = new CancellationTokenSource();
public Awaiter()
{
Console.WriteLine("Awaiter()");
}
~Awaiter()
{
Console.WriteLine("~Awaiter()");
}
public void Cancel()
{
_cts.Cancel();
}
// let the client observe cancellation
public CancellationToken Token { get { return _cts.Token; } }
// resume after await, called upon external event
public bool MoveNext()
{
if (_continuation == null)
return false;
var continuation = _continuation;
_continuation = null;
continuation();
return _continuation != null;
}
// custom Awaiter methods
public Awaiter GetAwaiter()
{
return this;
}
public bool IsCompleted
{
get { return false; }
}
public void GetResult()
{
this.Token.ThrowIfCancellationRequested();
}
// INotifyCompletion
public void OnCompleted(Action continuation)
{
_continuation = continuation;
}
// IDisposable
public void Dispose()
{
Console.WriteLine("Awaiter.Dispose()");
if (_continuation != null)
{
Cancel();
MoveNext();
}
}
}
}
}
When it's time to unwind, I request the cancellation inside Awaiter.Dispose and drive the state machine to the next step (if there's a pending continuation). This leads to observing the cancellation inside Awaiter.GetResult (which is called by the compiler-generated code). That throws OperationCanceledException (via Token.ThrowIfCancellationRequested) and further unwinds the using statement. So, the Resource gets properly disposed of. Finally, the task transitions to the canceled state (task.IsCanceled == true).
IMO, this is a simpler and more direct approach than installing a custom synchronization context on the current thread. It can be easily adapted for multithreading (some more details here).
This should indeed give you more freedom than with IEnumerator/yield. You could use try/catch inside your coroutine logic, and you can observe exceptions, cancellation and the result directly via the Task object.
Updated: AFAIK, there is no analogue of the iterator's generated IDisposable.Dispose when it comes to the async state machine. You really have to drive the state machine to an end when you want to cancel/unwind it. If you want to account for some negligent use of try/catch preventing the cancellation, I think the best you could do is to check if _continuation is non-null inside Awaiter.Cancel (after MoveNext) and throw a fatal exception out-of-band (using a helper async void method).
Updated, this has evolved to a blog post:
Asynchronous coroutines with C# 8.0 and IAsyncEnumerable.
It's 2020 and my other answer about await and coroutines is quite outdated by today's C# language standards. C# 8.0 has introduced support for asynchronous streams with new features like:
IAsyncEnumerable
IAsyncEnumerator
await foreach
IAsyncDisposable
await using
To familiarize yourself with the concept of asynchronous streams, I highly recommend reading "Iterating with Async Enumerables in C# 8", by Stephen Toub.
Together, these new features provide a great base for implementing asynchronous co-routines in C# in a much more natural way.
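As a small illustration of why these features matter for the unwinding problem: an async iterator, unlike a plain async method, can be abandoned while it is suspended at a yield return, and disposing its enumerator still runs the pending finally blocks. The sketch below is a minimal standalone demo; AsyncStreamDemo, Ticks and Log are hypothetical names used only here.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical demo class for illustration only.
public static class AsyncStreamDemo
{
    public static readonly List<string> Log = new List<string>();

    // An endless asynchronous sequence with a finally block.
    private static async IAsyncEnumerable<int> Ticks()
    {
        try
        {
            for (var i = 1; ; i++)
            {
                await Task.Yield(); // stands in for any asynchronous wait
                yield return i;
            }
        }
        finally
        {
            Log.Add("finally");
        }
    }

    public static async Task RunAsync()
    {
        await foreach (var tick in Ticks())
        {
            Log.Add("tick " + tick);
            if (tick == 2)
                break; // leaving the loop awaits DisposeAsync on the enumerator
        }
    }
}
```

After await AsyncStreamDemo.RunAsync(), Log holds "tick 1", "tick 2" and then "finally", even though the sequence itself never ends.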
Wikipedia provides a good explanation of what co-routines (aka coroutines) generally are. What I'd like to show here is how co-routines can be async, suspending their execution flow by using await and arbitrarily swapping the producer/consumer roles with each other, with C# 8.0.
The code fragment below should illustrate the concept. Here we have two co-routines, CoroutineA and CoroutineB which execute cooperatively and asynchronously, by yielding to each other as their pseudo-linear execution flow goes on.
namespace Tests
{
[TestClass]
public class CoroutineProxyTest
{
const string TRACE_CATEGORY = "coroutines";
/// <summary>
/// CoroutineA yields to CoroutineB
/// </summary>
private async IAsyncEnumerable<string> CoroutineA(
ICoroutineProxy<string> coroutineProxy,
[EnumeratorCancellation] CancellationToken token)
{
await using (var coroutine = await coroutineProxy.AsAsyncEnumerator(token))
{
const string name = "A";
var i = 0;
// yielding 1
Trace.WriteLine($"{name} about to yeild: {++i}", TRACE_CATEGORY);
yield return $"{i} from {name}";
// receiving
if (!await coroutine.MoveNextAsync())
{
yield break;
}
Trace.WriteLine($"{name} received: {coroutine.Current}", TRACE_CATEGORY);
// yielding 2
Trace.WriteLine($"{name} about to yeild: {++i}", TRACE_CATEGORY);
yield return $"{i} from {name}";
// receiving
if (!await coroutine.MoveNextAsync())
{
yield break;
}
Trace.WriteLine($"{name} received: {coroutine.Current}", TRACE_CATEGORY);
// yielding 3
Trace.WriteLine($"{name} about to yeild: {++i}", TRACE_CATEGORY);
yield return $"{i} from {name}";
}
}
/// <summary>
/// CoroutineB yields to CoroutineA
/// </summary>
private async IAsyncEnumerable<string> CoroutineB(
ICoroutineProxy<string> coroutineProxy,
[EnumeratorCancellation] CancellationToken token)
{
await using (var coroutine = await coroutineProxy.AsAsyncEnumerator(token))
{
const string name = "B";
var i = 0;
// receiving
if (!await coroutine.MoveNextAsync())
{
yield break;
}
Trace.WriteLine($"{name} received: {coroutine.Current}", TRACE_CATEGORY);
// yielding 1
Trace.WriteLine($"{name} about to yeild: {++i}", TRACE_CATEGORY);
yield return $"{i} from {name}";
// receiving
if (!await coroutine.MoveNextAsync())
{
yield break;
}
Trace.WriteLine($"{name} received: {coroutine.Current}", TRACE_CATEGORY);
// yielding 2
Trace.WriteLine($"{name} about to yeild: {++i}", TRACE_CATEGORY);
yield return $"{i} from {name}";
// receiving
if (!await coroutine.MoveNextAsync())
{
yield break;
}
Trace.WriteLine($"{name} received: {coroutine.Current}", TRACE_CATEGORY);
}
}
/// <summary>
/// Testing CoroutineA and CoroutineB cooperative execution
/// </summary>
[TestMethod]
public async Task Test_Coroutine_Execution_Flow()
{
// Here we execute two coroutines, CoroutineA and CoroutineB,
// which asynchronously yield to each other
//TODO: test cancellation scenarios
var token = CancellationToken.None;
using (var apartment = new Tests.ThreadPoolApartment())
{
await apartment.Run(async () =>
{
var proxyA = new CoroutineProxy<string>();
var proxyB = new CoroutineProxy<string>();
var listener = new Tests.CategoryTraceListener(TRACE_CATEGORY);
Trace.Listeners.Add(listener);
try
{
// start both coroutines
await Task.WhenAll(
proxyA.Run(token => CoroutineA(proxyB, token), token),
proxyB.Run(token => CoroutineB(proxyA, token), token))
.WithAggregatedExceptions();
}
finally
{
Trace.Listeners.Remove(listener);
}
var traces = listener.ToArray();
Assert.AreEqual(traces[0], "A about to yeild: 1");
Assert.AreEqual(traces[1], "B received: 1 from A");
Assert.AreEqual(traces[2], "B about to yeild: 1");
Assert.AreEqual(traces[3], "A received: 1 from B");
Assert.AreEqual(traces[4], "A about to yeild: 2");
Assert.AreEqual(traces[5], "B received: 2 from A");
Assert.AreEqual(traces[6], "B about to yeild: 2");
Assert.AreEqual(traces[7], "A received: 2 from B");
Assert.AreEqual(traces[8], "A about to yeild: 3");
Assert.AreEqual(traces[9], "B received: 3 from A");
});
}
}
}
}
The test's output looks like this:
coroutines: A about to yeild: 1
coroutines: B received: 1 from A
coroutines: B about to yeild: 1
coroutines: A received: 1 from B
coroutines: A about to yeild: 2
coroutines: B received: 2 from A
coroutines: B about to yeild: 2
coroutines: A received: 2 from B
coroutines: A about to yeild: 3
coroutines: B received: 3 from A
I currently use asynchronous co-routines in some of my automated UI testing scenarios. E.g., I might have an asynchronous test workflow logic that runs on a UI thread (that'd be CoroutineA) and a complementary workflow that runs on a ThreadPool thread as a part of a [TestMethod] method (that'd be CoroutineB).
Then I could do something like await WaitForUserInputAsync(); yield return true; to synchronize at certain key points of the cooperative execution flow of CoroutineA and CoroutineB.
Without yield return I'd have to use some form of asynchronous synchronization primitives, like Stephen Toub's AsyncManualResetEvent. I personally feel using co-routines is a more natural way of doing such kind of synchronization.
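For comparison, a primitive of that kind could be sketched roughly as follows. This is a simplified illustration in the spirit of an async manual-reset event, not Stephen Toub's actual code; the class name SimpleAsyncManualResetEvent is hypothetical.

```csharp
using System.Threading;
using System.Threading.Tasks;

// Hypothetical, simplified async manual-reset event for illustration only.
public sealed class SimpleAsyncManualResetEvent
{
    private TaskCompletionSource<bool> _tcs = NewSource();

    private static TaskCompletionSource<bool> NewSource()
    {
        // Continuations run asynchronously so that Set() never executes waiter code inline.
        return new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
    }

    // Completes once the event is set; completes immediately if it is already set.
    public Task WaitAsync()
    {
        return Volatile.Read(ref _tcs).Task;
    }

    public void Set()
    {
        Volatile.Read(ref _tcs).TrySetResult(true);
    }

    // Replaces a completed source with a fresh one; does nothing if the event is not set.
    public void Reset()
    {
        while (true)
        {
            var current = Volatile.Read(ref _tcs);
            if (!current.Task.IsCompleted)
                return;
            if (Interlocked.CompareExchange(ref _tcs, NewSource(), current) == current)
                return;
        }
    }
}
```

With something like this, each synchronization point needs its own event plus explicit Set/Reset calls on both sides, which is the bookkeeping that yield return avoids.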
The code for CoroutineProxy (which drives the execution of co-routines) is still a work-in-progress. It currently uses TPL Dataflow's BufferBlock as a proxy queue to coordinate the asynchronous execution, and I am not sure yet if it is an optimal way of doing that. Currently, this is what it looks like:
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
#nullable enable
namespace Tests
{
public interface ICoroutineProxy<T>
{
public Task<IAsyncEnumerable<T>> AsAsyncEnumerable(CancellationToken token = default);
}
public static class CoroutineProxyExt
{
public async static Task<IAsyncEnumerator<T>> AsAsyncEnumerator<T>(
this ICoroutineProxy<T> @this,
CancellationToken token = default)
{
return (await @this.AsAsyncEnumerable(token)).GetAsyncEnumerator(token);
}
}
public class CoroutineProxy<T> : ICoroutineProxy<T>
{
readonly TaskCompletionSource<IAsyncEnumerable<T>> _proxyTcs =
new TaskCompletionSource<IAsyncEnumerable<T>>(TaskCreationOptions.RunContinuationsAsynchronously);
public CoroutineProxy()
{
}
private async IAsyncEnumerable<T> CreateProxyAsyncEnumerable(
ISourceBlock<T> bufferBlock,
[EnumeratorCancellation] CancellationToken token)
{
var completionTask = bufferBlock.Completion;
while (true)
{
var itemTask = bufferBlock.ReceiveAsync(token);
var any = await Task.WhenAny(itemTask, completionTask);
if (any == completionTask)
{
// observe completion exceptions if any
await completionTask;
yield break;
}
yield return await itemTask;
}
}
async Task<IAsyncEnumerable<T>> ICoroutineProxy<T>.AsAsyncEnumerable(CancellationToken token)
{
using (token.Register(() => _proxyTcs.TrySetCanceled(), useSynchronizationContext: true))
{
return await _proxyTcs.Task;
}
}
public async Task Run(Func<CancellationToken, IAsyncEnumerable<T>> routine, CancellationToken token)
{
token.ThrowIfCancellationRequested();
var bufferBlock = new BufferBlock<T>();
var proxy = CreateProxyAsyncEnumerable(bufferBlock, token);
_proxyTcs.SetResult(proxy); // throw if already set
try
{
//TODO: do we need to use routine(token).WithCancellation(token) ?
await foreach (var item in routine(token))
{
await bufferBlock.SendAsync(item, token);
}
bufferBlock.Complete();
}
catch (Exception ex)
{
((IDataflowBlock)bufferBlock).Fault(ex);
throw;
}
}
}
}