using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

namespace Tests
{
    /// <summary>
    /// A <see cref="TaskScheduler"/> that runs every task synchronously on the thread
    /// that schedules it, instead of handing it to another thread. Nothing is ever
    /// held in a queue, which makes task-based code deterministic in tests.
    /// </summary>
    public class CurrentThreadTaskScheduler : TaskScheduler
    {
        /// <summary>
        /// Executes <paramref name="task"/> immediately on the calling thread rather
        /// than queuing it for later execution.
        /// </summary>
        /// <param name="task">The task being scheduled.</param>
        protected override void QueueTask(Task task)
        {
            // The result is intentionally ignored: QueueTask has no way to report it.
            TryExecuteTask(task);
        }

        /// <summary>
        /// Executes <paramref name="task"/> inline on the calling thread.
        /// </summary>
        /// <param name="task">The task to execute.</param>
        /// <param name="taskWasPreviouslyQueued">
        /// Not used: tasks are run directly from <see cref="QueueTask"/>, so there is
        /// no queue to remove the task from.
        /// </param>
        /// <returns>The value returned by <c>TryExecuteTask</c> for the task.</returns>
        protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
        {
            return TryExecuteTask(task);
        }

        /// <summary>
        /// Returns the tasks waiting to be executed, which is always an empty sequence
        /// because this scheduler never stores tasks.
        /// </summary>
        /// <returns>An empty sequence.</returns>
        protected override IEnumerable<Task> GetScheduledTasks()
        {
            return Enumerable.Empty<Task>();
        }

        /// <summary>
        /// Replaces the default task scheduler with this instance by writing to the
        /// private static <c>s_defaultTaskScheduler</c> field of
        /// <see cref="TaskScheduler"/> using reflection.
        /// </summary>
        /// <remarks>
        /// If the field cannot be found, the call does nothing.
        /// NOTE(review): the field is a non-public runtime detail; whether the write is
        /// permitted may depend on the target runtime - confirm on the runtimes in use.
        /// </remarks>
        public void Start()
        {
            // Static | NonPublic is all a lookup of a private static field needs;
            // BindingFlags.SetField only has meaning for Type.InvokeMember.
            const BindingFlags privateStatic = BindingFlags.Static | BindingFlags.NonPublic;

            FieldInfo defaultSchedulerField =
                typeof(TaskScheduler).GetField("s_defaultTaskScheduler", privateStatic);

            if (defaultSchedulerField != null)
            {
                defaultSchedulerField.SetValue(null, this);
            }
        }
    }
}