I adapted my parallel consumer implementation from the code in this question:
class ParallelConsumer<T> : IDisposable
{
private readonly int _maxParallel;
private readonly Action<T> _action;
private readonly TaskFactory _factory = new TaskFactory();
private CancellationTokenSource _tokenSource;
private readonly BlockingCollection<T> _entries = new BlockingCollection<T>();
private Task _task;
public ParallelConsumer(int maxParallel, Action<T> action)
{
_maxParallel = maxParallel;
_action = action;
}
public void Start()
{
try
{
_tokenSource = new CancellationTokenSource();
_task = _factory.StartNew(
() =>
{
Parallel.ForEach(
_entries.GetConsumingEnumerable(),
new ParallelOptions { MaxDegreeOfParallelism = _maxParallel, CancellationToken = _tokenSource.Token },
(item, loopState) =>
{
Log("Taking" + item);
if (!_tokenSource.IsCancellationRequested)
{
_action(item);
Log("Finished" + item);
}
else
{
Log("Not Taking" + item);
_entries.CompleteAdding();
loopState.Stop();
}
});
},
_tokenSource.Token);
}
catch (OperationCanceledException oce)
{
System.Diagnostics.Debug.WriteLine(oce);
}
}
private void Log(string message)
{
Console.WriteLine(message);
}
public void Stop()
{
Dispose();
}
public void Enqueue(T entry)
{
Log("Enqueuing" + entry);
_entries.Add(entry);
}
public void Dispose()
{
if (_task == null)
{
return;
}
_tokenSource.Cancel();
while (!_task.IsCanceled)
{
}
_task.Dispose();
_tokenSource.Dispose();
_task = null;
}
}
And here is the test code:
class Program
{
static void Main(string[] args)
{
TestRepeatedEnqueue(100, 1);
}
private static void TestRepeatedEnqueue(int itemCount, int parallelCount)
{
bool[] flags = new bool[itemCount];
var consumer = new ParallelConsumer<int>(parallelCount,
(i) =>
{
flags[i] = true;
}
);
consumer.Start();
for (int i = 0; i < itemCount; i++)
{
consumer.Enqueue(i);
}
Thread.Sleep(1000);
Debug.Assert(flags.All(b => b == true));
}
}
The test always fails: it gets stuck at around the 93rd item of the 100 tested. Any idea which part of my code causes this issue, and how to fix it?
You cannot use Parallel.ForEach() directly with BlockingCollection.GetConsumingEnumerable(), as you have discovered.
For an explanation, see this blog post:
https://devblogs.microsoft.com/pfxteam/parallelextensionsextras-tour-4-blockingcollectionextensions/
Excerpt from the blog:
BlockingCollection’s GetConsumingEnumerable implementation is using BlockingCollection’s internal synchronization which already supports multiple consumers concurrently, but ForEach doesn’t know that, and its enumerable-partitioning logic also needs to take a lock while accessing the enumerable.
As such, there’s more synchronization here than is actually necessary, resulting in a potentially non-negligible performance hit.
[Also] the partitioning algorithm employed by default by both Parallel.ForEach and PLINQ use chunking in order to minimize synchronization costs: rather than taking the lock once per element, it'll take the lock, grab a group of elements (a chunk), and then release the lock.
While this design can help with overall throughput, for scenarios that are focused more on low latency, that chunking can be prohibitive.
That blog also provides the source code for a method called GetConsumingPartitioner() which you can use to solve the problem.
public static class BlockingCollectionExtensions
{
public static Partitioner<T> GetConsumingPartitioner<T>(this BlockingCollection<T> collection)
{
return new BlockingCollectionPartitioner<T>(collection);
}
public class BlockingCollectionPartitioner<T> : Partitioner<T>
{
private BlockingCollection<T> _collection;
internal BlockingCollectionPartitioner(BlockingCollection<T> collection)
{
if (collection == null)
throw new ArgumentNullException("collection");
_collection = collection;
}
public override bool SupportsDynamicPartitions
{
get { return true; }
}
public override IList<IEnumerator<T>> GetPartitions(int partitionCount)
{
if (partitionCount < 1)
throw new ArgumentOutOfRangeException("partitionCount");
var dynamicPartitioner = GetDynamicPartitions();
return Enumerable.Range(0, partitionCount).Select(_ => dynamicPartitioner.GetEnumerator()).ToArray();
}
public override IEnumerable<T> GetDynamicPartitions()
{
return _collection.GetConsumingEnumerable();
}
}
}
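A minimal sketch of how such a partitioner might be plugged into the consumer loop. The names ConsumingPartitioner and PartitionerDemo are hypothetical; the partitioner is a compact restatement of the class above, repeated only so that the sample compiles on its own.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Compact restatement of the partitioner above (hypothetical name).
public sealed class ConsumingPartitioner<T> : Partitioner<T>
{
    private readonly BlockingCollection<T> _collection;

    public ConsumingPartitioner(BlockingCollection<T> collection) =>
        _collection = collection ?? throw new ArgumentNullException(nameof(collection));

    public override bool SupportsDynamicPartitions => true;

    public override IList<IEnumerator<T>> GetPartitions(int partitionCount) =>
        Enumerable.Range(0, partitionCount)
                  .Select(_ => GetDynamicPartitions().GetEnumerator())
                  .ToArray();

    // Every worker pulls single items straight from the collection: no chunking.
    public override IEnumerable<T> GetDynamicPartitions() =>
        _collection.GetConsumingEnumerable();
}

public static class PartitionerDemo
{
    // Enqueues itemCount integers and returns the flags set by the consumer.
    public static bool[] Run(int itemCount, int maxParallel)
    {
        var flags = new bool[itemCount];
        using var entries = new BlockingCollection<int>();

        var consumer = Task.Run(() =>
            Parallel.ForEach(
                new ConsumingPartitioner<int>(entries),
                new ParallelOptions { MaxDegreeOfParallelism = maxParallel },
                i => flags[i] = true));

        for (int i = 0; i < itemCount; i++) entries.Add(i);

        // Signal the end of input so that the loop can finish.
        entries.CompleteAdding();
        consumer.Wait();
        return flags;
    }
}
```

In the ParallelConsumer<T> from the question, the equivalent change would be to pass _entries.GetConsumingPartitioner() to Parallel.ForEach instead of _entries.GetConsumingEnumerable().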
The failure is caused by the following behavior, as explained here:
The partitioning algorithm employed by default by both
Parallel.ForEach and PLINQ use chunking in order to minimize
synchronization costs: rather than taking the lock once per element,
it'll take the lock, grab a group of elements (a chunk), and then
release the lock.
To get it to work, you can add a method to your ParallelConsumer<T> class that signals that adding is complete:
public void StopAdding()
{
_entries.CompleteAdding();
}
Then call this method after your for loop:
consumer.Start();
for (int i = 0; i < itemCount; i++)
{
consumer.Enqueue(i);
}
consumer.StopAdding();
Otherwise, Parallel.ForEach() keeps waiting for enough items to fill its next chunk before it starts processing them, so the last few items are never handed to your action.
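As a small illustration of what CompleteAdding() does to a consuming enumeration, here is a sketch (CompleteAddingDemo is a hypothetical name). Without the CompleteAdding() call, the foreach would block forever once the collection is empty.

```csharp
using System.Collections.Concurrent;
using System.Collections.Generic;

public static class CompleteAddingDemo
{
    // Adds itemCount integers, marks the collection as complete, then drains it.
    public static List<int> Drain(int itemCount)
    {
        using var entries = new BlockingCollection<int>();
        for (int i = 0; i < itemCount; i++) entries.Add(i);

        // After this call, the consuming enumeration ends once the collection is empty.
        entries.CompleteAdding();

        var seen = new List<int>();
        foreach (int item in entries.GetConsumingEnumerable()) seen.Add(item);
        return seen;
    }
}
```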
Suppose I am provided with an event producer API consisting of Start(), Pause(), and Resume() methods, and an ItemAvailable event. The producer itself is external code, and I have no control over its threading. A few items may still come through after Pause() is called (the producer is actually remote, so items may already be in flight over the network).
Suppose also that I am writing consumer code, where consumption may be slower than production.
Critical requirements are
The consumer event handler must not block the producer thread, and
All events must be processed (no data can be dropped).
I introduce a buffer into the consumer to smooth out some burstiness. But in the case of extended burstiness, I want to call Producer.Pause(), and then Resume() at an appropriate time, to avoid running out of memory at the consumer side.
I have a solution making use of Interlocked to increment and decrement a counter, which is compared to a threshold to decide whether it is time to Pause or Resume.
Question: Is there a better solution than the Interlocked counter (int current in the code below), in terms of efficiency (and elegance)?
Updated MVP (no longer bounces off the limiter):
namespace Experiments
{
internal class Program
{
// simple external producer API for demo purposes
private class Producer
{
public void Pause(int i) { _blocker.Reset(); Console.WriteLine($"paused at {i}"); }
public void Resume(int i) { _blocker.Set(); Console.WriteLine($"resumed at {i}"); }
public async Task Start()
{
await Task.Run
(
() =>
{
for (int i = 0; i < 10000; i++)
{
_blocker.Wait();
ItemAvailable?.Invoke(this, i);
}
}
);
}
public event EventHandler<int> ItemAvailable;
private ManualResetEventSlim _blocker = new(true);
}
private static async Task Main(string[] args)
{
var p = new Producer();
var buffer = Channel.CreateUnbounded<int>(new UnboundedChannelOptions { SingleWriter = true });
int threshold = 1000;
int resumeAt = 10;
int current = 0;
int paused = 0;
p.ItemAvailable += (_, i) =>
{
// CompareExchange(ref location, value, comparand): flip paused from 0 to 1 exactly once
if (Interlocked.Increment(ref current) >= threshold
&& Interlocked.CompareExchange(ref paused, 1, 0) == 0
) p.Pause(i);
buffer.Writer.TryWrite(i);
};
var processor = Task.Run
(
async () =>
{
await foreach (int i in buffer.Reader.ReadAllAsync())
{
Console.WriteLine($"processing {i}");
await Task.Delay(10);
if
(
Interlocked.Decrement(ref current) < resumeAt
// flip paused from 1 back to 0 exactly once
&& Interlocked.CompareExchange(ref paused, 0, 1) == 1
) p.Resume(i);
}
}
);
p.Start();
await processor;
}
}
}
If you are aiming for elegance, you could consider baking the pressure-awareness functionality into a custom Channel<T>. Below is a PressureAwareUnboundedChannel<T> class that derives from Channel<T>. It offers all the functionality of the base class, plus it emits notifications when the channel comes under pressure and when the pressure is relieved. The notifications are pushed through an IProgress<bool> instance, which reports true when the item count rises above a high threshold, and false when it drops to a low threshold or below.
public sealed class PressureAwareUnboundedChannel<T> : Channel<T>
{
private readonly Channel<T> _channel;
private readonly int _highPressureThreshold;
private readonly int _lowPressureThreshold;
private readonly IProgress<bool> _pressureProgress;
private int _pressureState = 0; // 0: no pressure, 1: under pressure
public PressureAwareUnboundedChannel(int lowPressureThreshold,
int highPressureThreshold, IProgress<bool> pressureProgress)
{
if (lowPressureThreshold < 0)
throw new ArgumentOutOfRangeException(nameof(lowPressureThreshold));
if (highPressureThreshold < lowPressureThreshold)
throw new ArgumentOutOfRangeException(nameof(highPressureThreshold));
if (pressureProgress == null)
throw new ArgumentNullException(nameof(pressureProgress));
_highPressureThreshold = highPressureThreshold;
_lowPressureThreshold = lowPressureThreshold;
_pressureProgress = pressureProgress;
_channel = Channel.CreateBounded<T>(Int32.MaxValue);
this.Writer = new ChannelWriter(this);
this.Reader = new ChannelReader(this);
}
private class ChannelWriter : ChannelWriter<T>
{
private readonly PressureAwareUnboundedChannel<T> _parent;
public ChannelWriter(PressureAwareUnboundedChannel<T> parent)
=> _parent = parent;
public override bool TryComplete(Exception error = null)
=> _parent._channel.Writer.TryComplete(error);
public override bool TryWrite(T item)
{
bool success = _parent._channel.Writer.TryWrite(item);
if (success) _parent.SignalWriteOrRead();
return success;
}
public override ValueTask<bool> WaitToWriteAsync(
CancellationToken cancellationToken = default)
=> _parent._channel.Writer.WaitToWriteAsync(cancellationToken);
}
private class ChannelReader : ChannelReader<T>
{
private readonly PressureAwareUnboundedChannel<T> _parent;
public ChannelReader(PressureAwareUnboundedChannel<T> parent)
=> _parent = parent;
public override Task Completion => _parent._channel.Reader.Completion;
public override bool CanCount => _parent._channel.Reader.CanCount;
public override int Count => _parent._channel.Reader.Count;
public override bool TryRead(out T item)
{
bool success = _parent._channel.Reader.TryRead(out item);
if (success) _parent.SignalWriteOrRead();
return success;
}
public override ValueTask<bool> WaitToReadAsync(
CancellationToken cancellationToken = default)
=> _parent._channel.Reader.WaitToReadAsync(cancellationToken);
}
private void SignalWriteOrRead()
{
var currentCount = _channel.Reader.Count;
bool underPressure;
if (currentCount > _highPressureThreshold)
underPressure = true;
else if (currentCount <= _lowPressureThreshold)
underPressure = false;
else
return;
int newState = underPressure ? 1 : 0;
int oldState = underPressure ? 0 : 1;
if (Interlocked.CompareExchange(
ref _pressureState, newState, oldState) != oldState) return;
_pressureProgress.Report(underPressure);
}
}
The encapsulated Channel<T> is actually a bounded channel, having capacity equal to the maximum Int32 value, because only bounded channels implement the Reader.Count property.¹
Usage example:
var progress = new Progress<bool>(underPressure =>
{
if (underPressure) Producer.Pause(); else Producer.Resume();
});
var channel = new PressureAwareUnboundedChannel<Item>(500, 1000, progress);
In this example the Producer will be paused when more than 1000 items are stored in the channel, and it will be resumed when the number of items drops to 500 or fewer.
The Progress<bool> action is invoked on the context that was captured at the time of the Progress<bool>'s creation. So if you create it on the UI thread of a GUI application, the action will be invoked on the UI thread; otherwise it will be invoked on the ThreadPool. In the latter case there will be no protection against overlapping invocations of the Action<bool>. If the Producer class is not thread-safe, you'll have to add synchronization inside the handler. Example:
var progress = new Progress<bool>(underPressure =>
{
lock (Producer) if (underPressure) Producer.Pause(); else Producer.Resume();
});
¹ Actually unbounded channels also support the Count property, unless they are configured with the SingleReader option.
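The two-threshold logic of SignalWriteOrRead can be looked at in isolation. The sketch below is not part of the channel class; PressureGate and Observe are hypothetical names, and the callback is a plain Action<bool> instead of an IProgress<bool>, so it runs synchronously on the calling thread.

```csharp
using System;
using System.Threading;

public sealed class PressureGate
{
    private readonly int _low;
    private readonly int _high;
    private readonly Action<bool> _onChange;
    private int _state; // 0: no pressure, 1: under pressure

    public PressureGate(int low, int high, Action<bool> onChange)
    {
        if (low < 0) throw new ArgumentOutOfRangeException(nameof(low));
        if (high < low) throw new ArgumentOutOfRangeException(nameof(high));
        _low = low;
        _high = high;
        _onChange = onChange ?? throw new ArgumentNullException(nameof(onChange));
    }

    // Feed the current item count; the callback fires only on a state transition.
    public void Observe(int count)
    {
        bool underPressure;
        if (count > _high) underPressure = true;
        else if (count <= _low) underPressure = false;
        else return; // between the thresholds: keep the current state

        int newState = underPressure ? 1 : 0;
        int oldState = 1 - newState;

        // Only the caller that actually flips the state reports the change.
        if (Interlocked.CompareExchange(ref _state, newState, oldState) == oldState)
            _onChange(underPressure);
    }
}
```

Because the state changes only above the high threshold or at or below the low one, counts that hover between the two thresholds never trigger a Pause/Resume pair.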
This is relatively straightforward if you realize there are three "steps" in this problem.
The first step ToChannel(Producer) receives messages from the producer.
The next step, PauseAt, signals pause() if there are too many pending items in its output channel.
The third step, ResumeAt signals resume() if its input channel has a count less than a threshold.
It's easy to combine all three steps using typical Channel patterns.
producer.ToChannel(token)
.PauseAt(1000,()=>producer.PauseAsync(),token)
.ResumeAt(10,()=>producer.ResumeAsync(),token)
....
Or a single, generic TrafficJam method:
static ChannelReader<T> TrafficJam<T>(this ChannelReader<T> source,
int pauseAt,int resumeAt,
Func<Task> pause,Func<Task> resume,
CancellationToken token=default)
{
return source
.PauseAt(pauseAt,pause,token)
.ResumeAt(resumeAt,resume,token);
}
ToChannel
The first step is relatively straightforward: an unbounded Channel source built from the producer's events.
static ChannelReader<int> ToChannel(this Producer producer,
CancellationToken token=default)
{
Channel<int> channel=Channel.CreateUnbounded<int>();
var writer=channel.Writer;
producer.ItemAvailable += OnItem;
return channel;
void OnItem(object sender, int item)
{
//TryWrite always succeeds on an unbounded channel that hasn't been completed
writer.TryWrite(item);
if(token.IsCancellationRequested)
{
producer.ItemAvailable-=OnItem;
writer.Complete();
}
}
}
The only unusual part is using a local function to allow disabling the event handler and completing the output channel when cancellation is requested.
That's enough to queue all the incoming items. ToChannel doesn't bother with starting, pausing etc, that's not its job.
PauseAt
The next function, PauseAt, uses a BoundedChannel to implement the threshold. It forwards incoming messages if it can. If the channel can't accept any more messages, it calls the pause callback and awaits until it can resume forwarding:
static ChannelReader<T> PauseAt<T>(this ChannelReader<T> source,
int threshold, Func<Task> pause,
CancellationToken token=default)
{
Channel<T> channel=Channel.CreateBounded<T>(threshold);
var writer=channel.Writer;
_ = Task.Run(async ()=>
{
await foreach(var msg in source.ReadAllAsync(token))
{
//TryWrite returns false when the bounded channel is full
if(!writer.TryWrite(msg))
{
await pause();
//Wait until we can post again
await writer.WriteAsync(msg,token);
}
}
},token)
.ContinueWith(t=>writer.TryComplete(t.Exception));
return channel;
}
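PauseAt depends on detecting that the bounded channel is full. One way to detect that, sketched below, is ChannelWriter<T>.TryWrite, which returns false on a full bounded channel in the default Wait mode instead of blocking. BoundedTryWriteDemo is a hypothetical name used only for this illustration.

```csharp
using System.Threading.Channels;

public static class BoundedTryWriteDemo
{
    // Fills a channel of capacity 2 and reports what each TryWrite returned.
    public static (bool First, bool Second, bool WhenFull, bool AfterRead) Run()
    {
        var channel = Channel.CreateBounded<int>(2);

        bool first = channel.Writer.TryWrite(1);
        bool second = channel.Writer.TryWrite(2);

        // The channel is full now, so this write is rejected rather than queued.
        bool whenFull = channel.Writer.TryWrite(3);

        // Reading one item frees a slot, and the write is accepted again.
        channel.Reader.TryRead(out _);
        bool afterRead = channel.Writer.TryWrite(3);

        return (first, second, whenFull, afterRead);
    }
}
```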
ResumeAt
The final step, ResumeAt, calls resume() if its input was previously above the threshold and now has fewer items.
If the input isn't bounded, it just forwards all messages.
static ChannelReader<T> ResumeAt<T>(this ChannelReader<T> source,
int resumeAt, Func<Task> resume,
CancellationToken token=default)
{
Channel<T> channel=Channel.CreateUnbounded<T>();
var writer=channel.Writer;
_ = Task.Run(async ()=>{
bool above=false;
await foreach(var msg in source.ReadAllAsync(token))
{
await writer.WriteAsync(msg);
//Do nothing if the source isn't bounded
if(source.CanCount)
{
if(above && source.Count<=resumeAt)
{
await resume();
above=false;
}
above=source.Count>resumeAt;
}
}
},token)
.ContinueWith(t=>writer.TryComplete(t.Exception));
return channel;
}
Since only a single worker task reads from the source, we can keep track of whether the previous count was above or below the threshold in a plain local variable, without any synchronization.
Combining Pause and Resume
Since Pause and Resume work with just channels, they can be combined into a single method:
static ChannelReader<T> TrafficJam<T>(this ChannelReader<T> source,
int pauseAt,int resumeAt,
Func<Task> pause,Func<Task> resume,
CancellationToken token=default)
{
return source.PauseAt(pauseAt,pause,token)
.ResumeAt(resumeAt,resume,token);
}
EDIT: I've updated my examples to use the https://github.com/StephenCleary/AsyncEx library. Still waiting for usable hints.
There are resources, which are identified by strings (for example files, URLs, etc.). I'm looking for a locking mechanism over the resources. I've found 2 different solutions, but each has its problems:
The first is using the ConcurrentDictionary class with AsyncLock:
using Nito.AsyncEx;
using System.Collections.Concurrent;
internal static class Locking {
private static ConcurrentDictionary<string, AsyncLock> mutexes
= new ConcurrentDictionary<string, AsyncLock>();
internal static AsyncLock GetMutex(string resourceLocator) {
return mutexes.GetOrAdd(
resourceLocator,
key => new AsyncLock()
);
}
}
Async usage:
using (await Locking.GetMutex("resource_string").LockAsync()) {
...
}
Synchronous usage:
using (Locking.GetMutex("resource_string").Lock()) {
...
}
This works safely, but the problem is that the dictionary grows larger and larger, and I don't see a thread-safe way to remove items from the dictionary when no one is waiting on a lock. (I also want to avoid global locks.)
My second solution hashes the string to a number between 0 and N - 1, and locks on these:
using Nito.AsyncEx;
using System.Collections.Concurrent;
internal static class Locking {
private const UInt32 BUCKET_COUNT = 4096;
private static ConcurrentDictionary<UInt32, AsyncLock> mutexes
= new ConcurrentDictionary<UInt32, AsyncLock>();
private static UInt32 HashStringToInt(string text) {
return ((UInt32)text.GetHashCode()) % BUCKET_COUNT;
}
internal static AsyncLock GetMutex(string resourceLocator) {
return mutexes.GetOrAdd(
HashStringToInt(resourceLocator),
key => new AsyncLock()
);
}
}
As one can see, the second solution only decreases the probability of collisions, but doesn't avoid them. My biggest fear is that it can cause deadlocks: The main strategy to avoid deadlocks is to always lock items in a specific order. But with this approach, different items can map to the same buckets in different order, like: (A->X, B->Y), (C->Y, D->X). So with this solution one cannot lock on more than one resource safely.
Is there a better solution? (I also welcome critiques of the above 2 solutions.)
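For the lock-ordering concern in the second solution, one possible approach is to order by bucket rather than by resource: map every key to its bucket, remove duplicate buckets, and acquire the bucket locks in ascending index order. The sketch below assumes that all resources needed by an operation are locked in a single call; StripedLocks, BucketsFor and LockAllAsync are hypothetical names, and SemaphoreSlim is swapped in for AsyncLock to keep the sample free of external packages.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public sealed class StripedLocks
{
    private readonly SemaphoreSlim[] _buckets;

    public StripedLocks(int bucketCount)
    {
        _buckets = Enumerable.Range(0, bucketCount)
                             .Select(_ => new SemaphoreSlim(1, 1))
                             .ToArray();
    }

    // Distinct bucket indices for the given keys, in ascending order.
    public int[] BucketsFor(IEnumerable<string> keys) =>
        keys.Select(k => (int)(unchecked((uint)k.GetHashCode()) % (uint)_buckets.Length))
            .Distinct()
            .OrderBy(i => i)
            .ToArray();

    // Acquires every bucket needed by the keys, always in the same global order.
    public async Task<IDisposable> LockAllAsync(params string[] keys)
    {
        var held = new List<SemaphoreSlim>();
        foreach (int index in BucketsFor(keys))
        {
            await _buckets[index].WaitAsync().ConfigureAwait(false);
            held.Add(_buckets[index]);
        }
        return new Releaser(held);
    }

    private sealed class Releaser : IDisposable
    {
        private List<SemaphoreSlim> _held;

        public Releaser(List<SemaphoreSlim> held) => _held = held;

        public void Dispose()
        {
            // Release at most once, in reverse acquisition order.
            var held = Interlocked.Exchange(ref _held, null);
            if (held == null) return;
            for (int i = held.Count - 1; i >= 0; i--) held[i].Release();
        }
    }
}
```

Two keys that collide on the same bucket are locked once, so a caller cannot deadlock against itself. Taking a second lock while already holding one from an earlier call is still unsafe with this scheme.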
You could probably improve upon the first solution by removing a lock from the dictionary when it stops being in use. The removed locks could then be added to a small pool, so that the next time you need a lock you just grab one from the pool instead of creating a new one.
Update: Here is an implementation of this idea. It is based on SemaphoreSlims instead of Stephen Cleary's AsyncLocks, because a custom disposable is required in order to remove unused semaphores from the dictionary.
public class MultiLock<TKey>
{
private object Locker { get; } = new object();
private Dictionary<TKey, LockItem> Dictionary { get; }
private Queue<LockItem> Pool { get; }
private int PoolSize { get; }
public MultiLock(int poolSize = 10)
{
Dictionary = new Dictionary<TKey, LockItem>();
Pool = new Queue<LockItem>(poolSize);
PoolSize = poolSize;
}
public WaitResult Wait(TKey key,
int millisecondsTimeout = Timeout.Infinite,
CancellationToken cancellationToken = default)
{
var lockItem = GetLockItem(key);
bool acquired;
try
{
acquired = lockItem.Semaphore.Wait(millisecondsTimeout,
cancellationToken);
}
catch
{
ReleaseLockItem(lockItem, key);
throw;
}
return new WaitResult(this, lockItem, key, acquired);
}
public async Task<WaitResult> WaitAsync(TKey key,
int millisecondsTimeout = Timeout.Infinite,
CancellationToken cancellationToken = default)
{
var lockItem = GetLockItem(key);
bool acquired;
try
{
acquired = await lockItem.Semaphore.WaitAsync(millisecondsTimeout,
cancellationToken).ConfigureAwait(false);
}
catch
{
ReleaseLockItem(lockItem, key);
throw;
}
return new WaitResult(this, lockItem, key, acquired);
}
private LockItem GetLockItem(TKey key)
{
LockItem lockItem;
lock (Locker)
{
if (!Dictionary.TryGetValue(key, out lockItem))
{
if (Pool.Count > 0)
{
lockItem = Pool.Dequeue();
}
else
{
lockItem = new LockItem();
}
Dictionary.Add(key, lockItem);
}
lockItem.UsedCount += 1;
}
return lockItem;
}
private void ReleaseLockItem(LockItem lockItem, TKey key)
{
lock (Locker)
{
lockItem.UsedCount -= 1;
if (lockItem.UsedCount == 0)
{
if (Dictionary.TryGetValue(key, out var stored))
{
if (stored == lockItem) // Sanity check
{
Dictionary.Remove(key);
if (Pool.Count < PoolSize)
{
Pool.Enqueue(lockItem);
}
}
}
}
}
}
internal class LockItem
{
public SemaphoreSlim Semaphore { get; } = new SemaphoreSlim(1);
public int UsedCount { get; set; }
}
public struct WaitResult : IDisposable
{
private MultiLock<TKey> MultiLock { get; }
private LockItem LockItem { get; }
private TKey Key { get; }
public bool LockAcquired { get; }
internal WaitResult(MultiLock<TKey> multiLock, LockItem lockItem, TKey key,
bool acquired)
{
MultiLock = multiLock;
LockItem = lockItem;
Key = key;
LockAcquired = acquired;
}
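void IDisposable.Dispose()
{
MultiLock.ReleaseLockItem(LockItem, Key);
// Release the semaphore only if it was actually acquired (the wait may have timed out)
if (LockAcquired) LockItem.Semaphore.Release();
}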
}
}
Usage example:
var multiLock = new MultiLock<string>();
using (await multiLock.WaitAsync("SomeKey"))
{
//...
}
The default pool size for unused semaphores is 10. The optimal value is probably the number of concurrent workers that use the MultiLock instance.
I did a performance test on my PC, and 10 workers were able to acquire the lock asynchronously 500,000 times per second in total (20 different string identifiers were used).
I've got this pattern for preventing calling into an async method before it has had a chance to complete previously.
My solution, which needs a flag and then a lock around that flag, feels pretty verbose. Is there a more natural way of achieving this?
public class MyClass
{
private object SyncIsFooRunning = new object();
private bool IsFooRunning { get; set;}
public async Task FooAsync()
{
try
{
lock(SyncIsFooRunning)
{
if(IsFooRunning)
return;
IsFooRunning = true;
}
// Use a semaphore to enforce maximum number of Tasks which are able to run concurrently.
var semaphoreSlim = new SemaphoreSlim(5);
var trackedTasks = new List<Task>();
for(int i = 0; i < 100; i++)
{
await semaphoreSlim.WaitAsync();
trackedTasks.Add(Task.Run(() =>
{
// DoTask();
semaphoreSlim.Release();
}));
}
// Using await makes try/catch/finally possible.
await Task.WhenAll(trackedTasks);
}
finally
{
lock(SyncIsFooRunning)
{
IsFooRunning = false;
}
}
}
}
As noted in the comments, you can use Interlocked.CompareExchange() if you prefer:
public class MyClass
{
private int _flag;
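public async Task FooAsync()
{
// Check the flag before entering the try block, so that a rejected call
// does not reset the flag in the finally block
if (Interlocked.CompareExchange(ref _flag, 1, 0) == 1)
{
return;
}
try
{
// do stuff
}
finally
{
Interlocked.Exchange(ref _flag, 0);
}
}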
}
That said, I think it's overkill. There is nothing wrong with using lock in this type of scenario, especially if you don't expect a lot of contention on the method. What I do think would be better is to wrap the method so that the caller can always await the result, whether a new asynchronous operation was started or not:
public class MyClass
{
private readonly object _lock = new object();
private Task _task;
public Task FooAsync()
{
lock (_lock)
{
return _task != null ? _task : (_task = FooAsyncImpl());
}
}
// Private, so that callers can only start the operation through FooAsync()
private async Task FooAsyncImpl()
{
try
{
// do async stuff
}
finally
{
lock (_lock) _task = null;
}
}
}
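Here is a self-contained variation of the shared-task wrapper that can be run as is. It is a sketch under a few assumptions: SharedOperation is a hypothetical name, the work is passed in as a Func<Task>, and instead of clearing the field in a finally block it checks Task.IsCompleted, which also covers the case where the work completes synchronously.

```csharp
using System;
using System.Threading.Tasks;

public sealed class SharedOperation
{
    private readonly object _lock = new object();
    private readonly Func<Task> _work;
    private Task _task;

    public SharedOperation(Func<Task> work) =>
        _work = work ?? throw new ArgumentNullException(nameof(work));

    // Returns the running operation if there is one, otherwise starts a new one.
    public Task RunAsync()
    {
        lock (_lock)
        {
            if (_task == null || _task.IsCompleted)
                _task = _work();
            return _task;
        }
    }
}
```

Callers that arrive while the operation is in flight receive the same Task, so they can await its completion (and observe its exception, if any) instead of silently getting an early return.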
Finally, in the comments, you say this:
Seems a bit odd that all the return types are still valid for Task?
Not clear to me what you mean by that. In your method, the only valid return types would be void and Task. If your return statement(s) returned an actual value, you'd have to use Task<T> where T is the type returned by the return statement(s).
I would like to have a custom thread pool satisfying the following requirements:
Real threads are preallocated according to the pool capacity. The actual work is free to use the standard .NET thread pool, if needed to spawn concurrent tasks.
The pool must be able to return the number of idle threads. The returned number may be less than the actual number of the idle threads, but it must not be greater. Of course, the more accurate the number the better.
Queuing work to the pool should return a corresponding Task, which should play nicely with the Task-based API.
NEW The max job capacity (or degree of parallelism) should be adjustable dynamically. Trying to reduce the capacity does not have to take effect immediately, but increasing it should do so immediately.
The rationale for the first item is depicted below:
The machine is not supposed to be running more than N work items concurrently, where N is relatively small - between 10 and 30.
The work is fetched from the database and if K items are fetched then we want to make sure that there are K idle threads to start the work right away. A situation where work is fetched from the database, but remains waiting for the next available thread is unacceptable.
The last item also explains the reason for having the idle thread count - I am going to fetch that many work items from the database. It also explains why the reported idle thread count must never be higher than the actual one - otherwise I might fetch more work than can be immediately started.
Anyway, here is my implementation along with a small program to test it (BJE stands for Background Job Engine):
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
public class BJEThreadPool
{
private sealed class InternalTaskScheduler : TaskScheduler
{
private int m_idleThreadCount;
private readonly BlockingCollection<Task> m_bus;
public InternalTaskScheduler(int threadCount, BlockingCollection<Task> bus)
{
m_idleThreadCount = threadCount;
m_bus = bus;
}
public void RunInline(Task task)
{
Interlocked.Decrement(ref m_idleThreadCount);
try
{
TryExecuteTask(task);
}
catch
{
// The action is responsible itself for the error handling, for the time being...
}
Interlocked.Increment(ref m_idleThreadCount);
}
public int IdleThreadCount
{
get { return m_idleThreadCount; }
}
#region Overrides of TaskScheduler
protected override void QueueTask(Task task)
{
m_bus.Add(task);
}
protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
{
return TryExecuteTask(task);
}
protected override IEnumerable<Task> GetScheduledTasks()
{
throw new NotSupportedException();
}
#endregion
public void DecrementIdleThreadCount()
{
Interlocked.Decrement(ref m_idleThreadCount);
}
}
private class ThreadContext
{
private readonly InternalTaskScheduler m_ts;
private readonly BlockingCollection<Task> m_bus;
private readonly CancellationTokenSource m_cts;
public readonly Thread Thread;
public ThreadContext(string name, InternalTaskScheduler ts, BlockingCollection<Task> bus, CancellationTokenSource cts)
{
m_ts = ts;
m_bus = bus;
m_cts = cts;
Thread = new Thread(Start)
{
IsBackground = true,
Name = name
};
Thread.Start();
}
private void Start()
{
try
{
foreach (var task in m_bus.GetConsumingEnumerable(m_cts.Token))
{
m_ts.RunInline(task);
}
}
catch (OperationCanceledException)
{
}
m_ts.DecrementIdleThreadCount();
}
}
private readonly InternalTaskScheduler m_ts;
private readonly CancellationTokenSource m_cts = new CancellationTokenSource();
private readonly BlockingCollection<Task> m_bus = new BlockingCollection<Task>();
private readonly List<ThreadContext> m_threadCtxs = new List<ThreadContext>();
public BJEThreadPool(int threadCount)
{
m_ts = new InternalTaskScheduler(threadCount, m_bus);
for (int i = 0; i < threadCount; ++i)
{
m_threadCtxs.Add(new ThreadContext("BJE Thread " + i, m_ts, m_bus, m_cts));
}
}
public void Terminate()
{
m_cts.Cancel();
foreach (var t in m_threadCtxs)
{
t.Thread.Join();
}
}
public Task Run(Action<CancellationToken> action)
{
return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
}
public Task Run(Action action)
{
return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
}
public int IdleThreadCount
{
get { return m_ts.IdleThreadCount; }
}
}
class Program
{
static void Main()
{
const int THREAD_COUNT = 32;
var pool = new BJEThreadPool(THREAD_COUNT);
var tcs = new TaskCompletionSource<bool>();
var tasks = new List<Task>();
var allRunning = new CountdownEvent(THREAD_COUNT);
for (int i = pool.IdleThreadCount; i > 0; --i)
{
var index = i;
tasks.Add(pool.Run(cancellationToken =>
{
Console.WriteLine("Started action " + index);
allRunning.Signal();
tcs.Task.Wait(cancellationToken);
Console.WriteLine(" Ended action " + index);
}));
}
Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
allRunning.Wait();
Debug.Assert(pool.IdleThreadCount == 0);
int expectedIdleThreadCount = THREAD_COUNT;
Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
switch (Console.ReadKey().KeyChar)
{
case 'c':
Console.WriteLine("Cancel All");
tcs.TrySetCanceled();
break;
case 'e':
Console.WriteLine("Error All");
tcs.TrySetException(new Exception("Failed"));
break;
case 'a':
Console.WriteLine("Abort All");
pool.Terminate();
expectedIdleThreadCount = 0;
break;
default:
Console.WriteLine("Done All");
tcs.TrySetResult(true);
break;
}
try
{
Task.WaitAll(tasks.ToArray());
}
catch (AggregateException exc)
{
Console.WriteLine(exc.Flatten().InnerException.Message);
}
Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
pool.Terminate();
Console.WriteLine("Press any key");
Console.ReadKey();
}
}
}
It is a very simple implementation and it appears to be working. However, there is a problem - the BJEThreadPool.Run method does not accept asynchronous methods. I.e. my implementation does not allow me to add the following overloads:
public Task Run(Func<CancellationToken, Task> action)
{
return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}
public Task Run(Func<Task> action)
{
return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}
The pattern I use in InternalTaskScheduler.RunInline does not work in this case.
So, my question is how to add the support for asynchronous work items? I am fine with changing the entire design as long as the requirements outlined at the beginning of the post are upheld.
EDIT
I would like to clarify the intended usage of the desired pool. Please observe the following code:
if (pool.IdleThreadCount == 0)
{
return;
}
foreach (var jobData in FetchFromDB(pool.IdleThreadCount))
{
pool.Run(CreateJobAction(jobData));
}
Notes:
The code is going to be run periodically, say every 1 minute.
The code is going to be run concurrently by multiple machines watching the same database.
FetchFromDB is going to use the technique described in Using SQL Server as a DB queue with multiple clients to atomically fetch and lock the work from the DB.
CreateJobAction is going to invoke the code denoted by jobData (the job code) and close the work upon the completion of that code. The job code is out of my control and it could be pretty much anything - heavy CPU bound code or light asynchronous IO bound code, badly written synchronous IO bound code or a mix of all. It could run for minutes and it could run for hours. Closing the work is my code and it will be asynchronous IO bound code. Because of this, the signature of the returned job action is that of an asynchronous method.
Item 2 underlines the importance of correctly identifying the number of idle threads. If there are 900 pending work items and 10 agent machines I cannot allow an agent to fetch 300 work items and queue them on the thread pool. Why? Because it is most unlikely that the agent will be able to run 300 work items concurrently. It will run some, sure enough, but others will be waiting in the thread pool work queue. Suppose it will run 100 and let 200 wait (even though 100 is probably far-fetched). This yields 3 fully loaded agents and 7 idle ones. But only 300 work items out of 900 are actually being processed concurrently!!!
My goal is to maximize the spread of the work amongst the available agents. Ideally, I should evaluate the load of an agent and the "heaviness" of the pending work, but it is a formidable task and is reserved for the future versions. Right now, I wish to assign each agent the max job capacity with the intention to provide the means to increase/decrease it dynamically without restarting the agents.
Next observation. The work can take quite a long time to run and it could be all synchronous code. As far as I understand, it is undesirable to use thread pool threads for this kind of work.
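For reference, one common way to keep long blocking work off the pool workers is the TaskCreationOptions.LongRunning hint. The sketch below is only an illustration; the helper name RunBlockingJob is made up, and the hint is a request to the scheduler rather than a guarantee.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

static class LongRunningDemo
{
    // Runs a blocking job with the LongRunning hint and reports whether
    // it ended up on a thread pool thread after all.
    public static Task<bool> RunBlockingJob(Action job)
    {
        return Task.Factory.StartNew(() =>
        {
            bool onPool = Thread.CurrentThread.IsThreadPoolThread;
            job();
            return onPool;
        },
        CancellationToken.None,
        TaskCreationOptions.LongRunning, // hint: this work may block for a long time
        TaskScheduler.Default);
    }

    static void Main()
    {
        bool ranOnPool = RunBlockingJob(() => Thread.Sleep(200)).Result;
        Console.WriteLine("Ran on a pool thread: " + ranOnPool);
    }
}
```

Note that the hint only covers the synchronous part of the delegate. If the job is asynchronous, whatever runs after its first await is scheduled like any other continuation.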
EDIT 2
There is a statement that TaskScheduler is only for CPU-bound work. But what if I do not know the nature of the work? I mean, it is a general-purpose Background Job Engine and it runs thousands of different kinds of jobs. I have no means to tell that one job is CPU bound, that another one is synchronous IO bound and that yet another one is asynchronous IO bound. I wish I could, but I cannot.
EDIT 3
In the end, I do not use SemaphoreSlim, but neither do I use TaskScheduler - it finally trickled down my thick skull that it is inappropriate and plain wrong, plus it makes the code overly complex.
Still, I failed to see how SemaphoreSlim is the way. The proposed pattern:
public async Task Enqueue(Func<Task> taskGenerator)
{
await semaphore.WaitAsync();
try
{
await taskGenerator();
}
finally
{
semaphore.Release();
}
}
This pattern expects taskGenerator either to be asynchronous IO-bound code or to open a new thread otherwise. However, I have no means to determine whether the work to be executed is one or the other. Plus, as I have learned from "SemaphoreSlim.WaitAsync continuation code", if the semaphore is unlocked, the code following WaitAsync() is going to run on the same thread, which is not very good for me.
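For what it is worth, the "same thread" concern can be sidestepped by pushing the generator through Task.Run, roughly as sketched below. The class name OffloadingThrottle is hypothetical, and the sketch deliberately uses thread pool threads for the synchronous part of the work, which may be exactly what a long-running job engine wants to avoid.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public class OffloadingThrottle
{
    private readonly SemaphoreSlim _semaphore;

    public OffloadingThrottle(int maxParallel)
    {
        _semaphore = new SemaphoreSlim(maxParallel, maxParallel);
    }

    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await _semaphore.WaitAsync();
        try
        {
            // Task.Run(Func<Task>) invokes the delegate on a pool thread and
            // unwraps the task it returns, so the caller's thread is never
            // blocked by the synchronous part of the work.
            await Task.Run(taskGenerator);
        }
        finally
        {
            _semaphore.Release();
        }
    }
}

static class ThrottleDemo
{
    static void Main()
    {
        var throttle = new OffloadingThrottle(2);
        int callerThread = Environment.CurrentManagedThreadId;
        int workThread = -1;

        throttle.Enqueue(() =>
        {
            workThread = Environment.CurrentManagedThreadId;
            Thread.Sleep(100); // stands in for synchronous work of unknown nature
            return Task.CompletedTask;
        }).Wait();

        Console.WriteLine("Caller thread: " + callerThread + ", work thread: " + workThread);
    }
}
```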
Anyway, below is my implementation, in case anyone fancies it. Unfortunately, I have yet to understand how to reduce the pool thread count dynamically, but this is a topic for another question.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
public interface IBJEThreadPool
{
void SetThreadCount(int threadCount);
void Terminate();
Task Run(Action action);
Task Run(Action<CancellationToken> action);
Task Run(Func<Task> action);
Task Run(Func<CancellationToken, Task> action);
int IdleThreadCount { get; }
}
public class BJEThreadPool : IBJEThreadPool
{
private interface IActionContext
{
Task Run(CancellationToken ct);
TaskCompletionSource<object> TaskCompletionSource { get; }
}
private class ActionContext : IActionContext
{
private readonly Action m_action;
public ActionContext(Action action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
m_action();
return null;
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class CancellableActionContext : IActionContext
{
private readonly Action<CancellationToken> m_action;
public CancellableActionContext(Action<CancellationToken> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
m_action(ct);
return null;
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class AsyncActionContext : IActionContext
{
private readonly Func<Task> m_action;
public AsyncActionContext(Func<Task> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
return m_action();
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private class AsyncCancellableActionContext : IActionContext
{
private readonly Func<CancellationToken, Task> m_action;
public AsyncCancellableActionContext(Func<CancellationToken, Task> action)
{
m_action = action;
TaskCompletionSource = new TaskCompletionSource<object>();
}
#region Implementation of IActionContext
public Task Run(CancellationToken ct)
{
return m_action(ct);
}
public TaskCompletionSource<object> TaskCompletionSource { get; private set; }
#endregion
}
private readonly CancellationTokenSource m_ctsTerminateAll = new CancellationTokenSource();
private readonly BlockingCollection<IActionContext> m_bus = new BlockingCollection<IActionContext>();
private readonly LinkedList<Thread> m_threads = new LinkedList<Thread>();
private int m_idleThreadCount;
private static int s_threadCount;
public BJEThreadPool(int threadCount)
{
ReserveAdditionalThreads(threadCount);
}
private void ReserveAdditionalThreads(int n)
{
for (int i = 0; i < n; ++i)
{
var index = Interlocked.Increment(ref s_threadCount) - 1;
var t = new Thread(Start)
{
IsBackground = true,
Name = "BJE Thread " + index
};
Interlocked.Increment(ref m_idleThreadCount);
t.Start();
m_threads.AddLast(t);
}
}
private void Start()
{
try
{
foreach (var actionContext in m_bus.GetConsumingEnumerable(m_ctsTerminateAll.Token))
{
RunWork(actionContext).Wait();
}
}
catch (OperationCanceledException)
{
}
catch
{
// Should never happen - log the error
}
Interlocked.Decrement(ref m_idleThreadCount);
}
private async Task RunWork(IActionContext actionContext)
{
Interlocked.Decrement(ref m_idleThreadCount);
try
{
var task = actionContext.Run(m_ctsTerminateAll.Token);
if (task != null)
{
await task;
}
actionContext.TaskCompletionSource.SetResult(null);
}
catch (OperationCanceledException)
{
actionContext.TaskCompletionSource.TrySetCanceled();
}
catch (Exception exc)
{
actionContext.TaskCompletionSource.TrySetException(exc);
}
Interlocked.Increment(ref m_idleThreadCount);
}
private Task PostWork(IActionContext actionContext)
{
m_bus.Add(actionContext);
return actionContext.TaskCompletionSource.Task;
}
#region Implementation of IBJEThreadPool
public void SetThreadCount(int threadCount)
{
if (threadCount > m_threads.Count)
{
ReserveAdditionalThreads(threadCount - m_threads.Count);
}
else if (threadCount < m_threads.Count)
{
throw new NotSupportedException();
}
}
public void Terminate()
{
m_ctsTerminateAll.Cancel();
foreach (var t in m_threads)
{
t.Join();
}
}
public Task Run(Action action)
{
return PostWork(new ActionContext(action));
}
public Task Run(Action<CancellationToken> action)
{
return PostWork(new CancellableActionContext(action));
}
public Task Run(Func<Task> action)
{
return PostWork(new AsyncActionContext(action));
}
public Task Run(Func<CancellationToken, Task> action)
{
return PostWork(new AsyncCancellableActionContext(action));
}
public int IdleThreadCount
{
get { return m_idleThreadCount; }
}
#endregion
}
public static class Extensions
{
public static Task WithCancellation(this Task task, CancellationToken token)
{
return task.ContinueWith(t => t.GetAwaiter().GetResult(), token);
}
}
class Program
{
static void Main()
{
const int THREAD_COUNT = 16;
var pool = new BJEThreadPool(THREAD_COUNT);
var tcs = new TaskCompletionSource<bool>();
var tasks = new List<Task>();
var allRunning = new CountdownEvent(THREAD_COUNT);
for (int i = pool.IdleThreadCount; i > 0; --i)
{
var index = i;
tasks.Add(pool.Run(async ct =>
{
Console.WriteLine("Started action " + index);
allRunning.Signal();
await tcs.Task.WithCancellation(ct);
Console.WriteLine(" Ended action " + index);
}));
}
Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
allRunning.Wait();
Debug.Assert(pool.IdleThreadCount == 0);
int expectedIdleThreadCount = THREAD_COUNT;
Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
switch (Console.ReadKey().KeyChar)
{
case 'c':
Console.WriteLine("ancel All");
tcs.TrySetCanceled();
break;
case 'e':
Console.WriteLine("rror All");
tcs.TrySetException(new Exception("Failed"));
break;
case 'a':
Console.WriteLine("bort All");
pool.Terminate();
expectedIdleThreadCount = 0;
break;
default:
Console.WriteLine("Done All");
tcs.TrySetResult(true);
break;
}
try
{
Task.WaitAll(tasks.ToArray());
}
catch (AggregateException exc)
{
Console.WriteLine(exc.Flatten().InnerException.Message);
}
Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);
pool.Terminate();
Console.WriteLine("Press any key");
Console.ReadKey();
}
}
}
Asynchronous "work items" are often based on async IO. Async IO does not use threads while it runs. Task schedulers are used to execute CPU work (tasks based on a delegate), so the TaskScheduler concept does not apply here. You cannot use a custom TaskScheduler to influence what async code does.
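A quick way to convince yourself that pending asynchronous operations do not occupy threads is sketched below. It uses Task.Delay (timer based) as a stand-in for real async IO, which is an assumption of the sketch; the helper name RunManyAsync is made up.

```csharp
using System;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;

static class AsyncWaitDemo
{
    // Starts `count` asynchronous waits at once and returns how long
    // it took for all of them to complete.
    public static async Task<TimeSpan> RunManyAsync(int count, int delayMs)
    {
        var sw = Stopwatch.StartNew();
        await Task.WhenAll(Enumerable.Range(0, count).Select(_ => Task.Delay(delayMs)));
        return sw.Elapsed;
    }

    static void Main()
    {
        TimeSpan elapsed = RunManyAsync(1000, 500).GetAwaiter().GetResult();
        Console.WriteLine("1000 waits of 500 ms took " + elapsed.TotalMilliseconds + " ms in total");
    }
}
```

If each wait needed its own blocked thread, the pool would have to grow to a thousand threads to overlap them all. Here nothing is blocked while the operations are pending, so no scheduler has anything to schedule during that time.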
Make your work items throttle themselves:
static SemaphoreSlim sem = new SemaphoreSlim(maxDegreeOfParallelism); //shared object
async Task MyWorkerFunction()
{
await sem.WaitAsync();
try
{
MyWork();
}
finally
{
sem.Release();
}
}
As mentioned in another answer by usr, you can't do this with a TaskScheduler, as that is only for CPU-bound work, not for limiting the degree of parallelism of every type of work. He also shows how you can use a SemaphoreSlim to asynchronously limit the degree of parallelism.
You can expand on this to generalize these concepts in a few ways. The one that seems like it would be the most beneficial to you would be to create a special type of queue that takes operations that return a Task and executes them in such a way that a given max degree of parallelization is achieved.
public class FixedParallelismQueue
{
private SemaphoreSlim semaphore;
public FixedParallelismQueue(int maxDegreesOfParallelism)
{
semaphore = new SemaphoreSlim(maxDegreesOfParallelism,
maxDegreesOfParallelism);
}
public async Task<T> Enqueue<T>(Func<Task<T>> taskGenerator)
{
await semaphore.WaitAsync();
try
{
return await taskGenerator();
}
finally
{
semaphore.Release();
}
}
public async Task Enqueue(Func<Task> taskGenerator)
{
await semaphore.WaitAsync();
try
{
await taskGenerator();
}
finally
{
semaphore.Release();
}
}
}
This allows you to create a queue for your application (you can even have several separate queues if you want) that has a fixed degree of parallelism. You can then hand it operations that return a Task; the queue will start each one when it can and return a Task that completes when that unit of work has finished.
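A hypothetical usage could look like the sketch below. It repeats a condensed copy of the queue so that it compiles on its own; the PeakConcurrency helper and the Task.Delay stand-in for real work are assumptions made for the illustration.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Condensed copy of the queue above, non-generic overload only.
public class FixedParallelismQueue
{
    private readonly SemaphoreSlim semaphore;

    public FixedParallelismQueue(int max) { semaphore = new SemaphoreSlim(max, max); }

    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await semaphore.WaitAsync();
        try { await taskGenerator(); }
        finally { semaphore.Release(); }
    }
}

static class QueueDemo
{
    // Pushes `jobs` delays through the queue and returns the highest
    // number of them that were observed in flight at the same time.
    public static async Task<int> PeakConcurrency(int jobs, int maxParallel)
    {
        var queue = new FixedParallelismQueue(maxParallel);
        int inFlight = 0, peak = 0;
        var gate = new object();

        Func<Task> job = async () =>
        {
            lock (gate) { inFlight++; peak = Math.Max(peak, inFlight); }
            await Task.Delay(50); // stand-in for real asynchronous work
            lock (gate) { inFlight--; }
        };

        await Task.WhenAll(Enumerable.Range(0, jobs).Select(_ => queue.Enqueue(job)));
        return peak;
    }

    static void Main()
    {
        // Ten jobs, at most three in flight at any moment.
        Console.WriteLine("Peak concurrency: " + PeakConcurrency(10, 3).GetAwaiter().GetResult());
    }
}
```

All ten operations are enqueued immediately, but the semaphore never lets more than three of them run at once.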
I have the following code:
static void Main(string[] args)
{
TaskExecuter.Execute();
}
class Task
{
int _delay;
private Task(int delay) { _delay = delay; }
public void Execute() { Thread.Sleep(_delay); }
public static IEnumerable GetAllTasks()
{
Random r = new Random(4711);
for (int i = 0; i < 10; i++)
yield return new Task(r.Next(100, 5000));
}
}
static class TaskExecuter
{
public static void Execute()
{
foreach (Task task in Task.GetAllTasks())
{
task.Execute();
}
}
}
I need to change the loop in the Execute method to run in parallel on multiple threads. I tried the following, but it isn't working since GetAllTasks returns IEnumerable and not a list:
Parallel.ForEach(Task.GetAllTasks(), task =>
{
//Execute();
});
Parallel.ForEach works with IEnumerable<T>, so adjust your GetAllTasks to return IEnumerable<Task>.
Also, .NET already has a widely used Task class, so I would avoid giving your own class the same name to prevent confusion.
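If changing the signature is not an option, LINQ's Cast<T>() can adapt the non-generic sequence at the call site. A minimal sketch, with a hypothetical Work class standing in for the class from the question:

```csharp
using System;
using System.Collections;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

class Work
{
    public readonly int Id;
    public Work(int id) { Id = id; }
    public void Execute() { Thread.Sleep(10); }

    // Non-generic on purpose, mirroring the signature from the question.
    public static IEnumerable GetAll()
    {
        for (int i = 0; i < 10; i++)
            yield return new Work(i);
    }
}

static class CastDemo
{
    // Executes every item in parallel and returns how many were executed.
    public static int ExecuteAll()
    {
        int executed = 0;
        // Cast<Work>() turns the non-generic IEnumerable into IEnumerable<Work>,
        // which is what Parallel.ForEach needs to infer its type argument.
        Parallel.ForEach(Work.GetAll().Cast<Work>(), w =>
        {
            w.Execute();
            Interlocked.Increment(ref executed);
        });
        return executed;
    }

    static void Main()
    {
        Console.WriteLine(ExecuteAll()); // 10
    }
}
```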
Parallel.ForEach takes an IEnumerable<TSource>, so your code should be fine. However, you need to perform the Execute call on the task instance that is passed as parameter to your lambda statement.
Parallel.ForEach(Task.GetAllTasks(), task =>
{
task.Execute();
});
This can also be expressed as a one-line lambda expression:
Parallel.ForEach(Task.GetAllTasks(), task => task.Execute());
There is also another subtle issue in your code that you should pay attention to. Parallel.ForEach may pull the elements of your sequence from several worker threads. Your iterator calls an instance method of the Random class, which is not thread-safe, so this could lead to race issues if those calls ever overlap. The easiest way to rule that out is to pre-populate your sequence as a list:
Parallel.ForEach(Task.GetAllTasks().ToList(), task => task.Execute());
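If the sequence is too large to copy into a list first, another option is to guard the shared Random with a lock. This is only a sketch; the SafeDelays class and its member names are made up for the illustration.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

static class SafeDelays
{
    private static readonly Random s_random = new Random(4711);
    private static readonly object s_gate = new object();

    // Serializes every access to the shared Random instance.
    public static int NextDelay()
    {
        lock (s_gate) { return s_random.Next(100, 5000); }
    }

    // Lazily yields `count` delays; safe no matter which thread pulls the next one.
    public static IEnumerable<int> GetDelays(int count)
    {
        for (int i = 0; i < count; i++)
            yield return NextDelay();
    }

    static void Main()
    {
        var seen = new ConcurrentBag<int>();
        Parallel.ForEach(GetDelays(10), d => seen.Add(d));
        Console.WriteLine(seen.Count);                          // 10
        Console.WriteLine(seen.All(d => d >= 100 && d < 5000)); // True
    }
}
```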
This worked for me in LINQPad. I just renamed your Task class to Work and changed GetAllTasks to return an IEnumerable<Work>:
class Work
{
int _delay;
private Work(int delay) { _delay = delay; }
public void Execute() { Thread.Sleep(_delay); }
public static IEnumerable<Work> GetAllTasks()
{
Random r = new Random(4711);
for (int i = 0; i < 10; i++)
yield return new Work(r.Next(100, 5000));
}
}
static class TaskExecuter
{
public static void Execute()
{
foreach (Work task in Work.GetAllTasks())
{
task.Execute();
}
}
}
void Main()
{
System.Threading.Tasks.Parallel.ForEach(Work.GetAllTasks(), new Action<Work>(task =>
{
task.Execute();
}));
}