Wednesday, September 17, 2014

Linq to Expression Trees


If you are working with expression trees in .NET, chances are that you have come across the need to inspect the nodes of the expression tree.

The idea here is to flatten the expression tree into an IEnumerable<Expression> so that the nodes can be queried like any other sequence.

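using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq.Expressions;

/// <summary>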
/// Extends the <see cref="Expression"/> class.
/// </summary>
public static class ExpressionExtensions
{
    /// <summary>
    /// Flattens the <paramref name="expression"/> into an <see cref="IEnumerable{T}"/>.
    /// </summary>
    /// <param name="expression">The target <see cref="Expression"/>.</param>
    /// <returns>The <see cref="Expression"/> represented as a list of sub expressions.</returns>
    public static IEnumerable<Expression> AsEnumerable(this Expression expression)
    {
        var flattener = new ExpressionTreeFlattener();
        return flattener.Flatten(expression);
    }

    private class ExpressionTreeFlattener : ExpressionVisitor
    {
        private readonly ICollection<Expression> nodes = new Collection<Expression>();

        public IEnumerable<Expression> Flatten(Expression expression)
        {
            Visit(expression);
            return nodes;
        }

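        public override Expression Visit(Expression node)
        {
            // ExpressionVisitor also calls Visit for absent child nodes
            // (e.g. the missing instance of a static member access); skip those.
            if (node != null)
            {
                nodes.Add(node);
            }

            return base.Visit(node);
        }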
    }
}

Now we can take any expression tree, flatten it into a sequence of expressions and query it with plain LINQ to Objects.

Expression<Func<int>> expression = () => 42;
var result = expression.AsEnumerable().FirstOrDefault(e => e is ConstantExpression);
Console.WriteLine(result);

Output:

42
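If you want to try this outside a test project, here is a small self-contained console sketch along the same lines. FlattenDemo and its Flatten method are names made up for this example. It skips null nodes, because ExpressionVisitor also calls Visit for absent child nodes (such as the missing instance of a static property access like DateTime.Now), and those would otherwise end up in the list.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class FlattenDemo
{
    // Collects every node the visitor walks past, in visiting order.
    private sealed class Flattener : ExpressionVisitor
    {
        public readonly List<Expression> Nodes = new List<Expression>();

        public override Expression Visit(Expression node)
        {
            // ExpressionVisitor calls Visit(null) for absent children,
            // e.g. the instance of a static member access; skip those.
            if (node != null)
            {
                Nodes.Add(node);
            }
            return base.Visit(node);
        }
    }

    public static List<Expression> Flatten(Expression expression)
    {
        var flattener = new Flattener();
        flattener.Visit(expression);
        return flattener.Nodes;
    }

    public static void Main()
    {
        Expression<Func<int, int, int>> calc = (a, b) => a + b * 2;
        List<Expression> nodes = Flatten(calc);

        // Plain LINQ to Objects over the flattened tree.
        Console.WriteLine(nodes.OfType<BinaryExpression>().Count());    // 2 (Add and Multiply)
        Console.WriteLine(nodes.OfType<ConstantExpression>().Single()); // 2

        // DateTime.Now is static, so the MemberExpression has no instance node.
        Expression<Func<DateTime>> now = () => DateTime.Now;
        Console.WriteLine(Flatten(now).Any(n => n == null));             // False
    }
}
```

Once the tree is flattened, any LINQ operator applies, for example OfType<ParameterExpression>() to find every parameter reference.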

Happy Linqing!

Thursday, September 11, 2014

Async/Await and code coverage


First of all, I should say that I think good code coverage is a good thing. What "good coverage" means could probably be discussed for days, but I like mine in the high 90s. I simply don't like shipping code that has never been executed. Call me paranoid, but that is how I feel about it :)

I have been playing around with the async/await features in C# 5 for a while now, and I was finally able to put them to good use in a new project I am working on. Yes, that's right, we were actually given permission to use .NET 4.5 (damn you, Windows XP).

Okay, enough smalltalk (no pun intended); let's dig into some code.

using System.Threading.Tasks;

public interface IFoo
{
    Task<int> ExecuteAsync();
}

public class Foo : IFoo
{
    public async Task<int> ExecuteAsync()
    {
        await Task.Delay(1000);
        return 42;
    }
}     

Simple enough: we mimic a long-running method by awaiting a 1000 millisecond delay.
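Note that the task returned by ExecuteAsync is still running when the method returns, because execution stops at the first await on an unfinished task. A task from Task.FromResult, by contrast, has already completed before anyone awaits it, and that difference turns out to matter for coverage. A small console sketch of it (CompletionStateDemo is a hypothetical name; ExecuteAsync has the same body as Foo.ExecuteAsync):

```csharp
using System;
using System.Threading.Tasks;

public static class CompletionStateDemo
{
    // Same body as Foo.ExecuteAsync above.
    public static async Task<int> ExecuteAsync()
    {
        await Task.Delay(1000);
        return 42;
    }

    public static void Main()
    {
        Task<int> running = ExecuteAsync();       // returns at the first unfinished await
        Task<int> finished = Task.FromResult(42); // completed as soon as it is created

        Console.WriteLine(running.IsCompleted);   // False (the delay is still pending)
        Console.WriteLine(finished.IsCompleted);  // True
        Console.WriteLine(running.Result);        // 42, after roughly one second
    }
}
```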

Let's write a test that executes the method.

[Fact]
public async Task LongRunningTest()
{
    var foo = new Foo();
    var result = await foo.ExecuteAsync();
    Assert.Equal(42, result);
}

Code coverage for LongRunningTest is 100%, as expected.

Note: At this point we are really looking at the code coverage of the test method itself. Remember, though, that ExecuteAsync could just as easily be called deep down in some business logic, where it would make more sense to mock the IFoo interface.

Now we try the same thing using a mock:

[Fact]
public async Task MockTest()
{
    var fooMock = new Mock<IFoo>();
    fooMock.Setup(m => m.ExecuteAsync()).ReturnsAsync(42);
    var result = await fooMock.Object.ExecuteAsync();
    Assert.Equal(42, result);
}

Running MockTest with code coverage reports 85%, and the results show that some mysterious MoveNext method is not fully covered. So where does this method come from? We most certainly did not write it. As it turns out, the compiler did when it encountered the await keyword.

To get a better understanding of what is going on, we need to look at the code the compiler actually generates.

Looking at the LongRunningTest method in Reflector (or ILSpy) shows that, in addition to the method itself, there is a corresponding state machine that takes care of the gory details of asynchronous execution.

[CompilerGenerated]
private struct <LongRunningTest>d__0 : IAsyncStateMachine
{
    public int <>1__state;
    public Tests <>4__this;
    public AsyncTaskMethodBuilder <>t__builder;
    private object <>t__stack;
    private TaskAwaiter<int> <>u__$awaiter3;
    public Foo <foo>5__1;
    public int <result>5__2;

    private void MoveNext()
    {
        try
        {
            TaskAwaiter<int> awaiter;
            bool flag = true;
            switch (this.<>1__state)
            {
                case -3:
                    goto Label_00C3;

                case 0:
                    break;

                default:
                    this.<foo>5__1 = new Foo();
                    awaiter = this.<foo>5__1.ExecuteAsync().GetAwaiter();
                    if (awaiter.IsCompleted)
                    {
                        goto Label_0082;
                    }
                    this.<>1__state = 0;
                    this.<>u__$awaiter3 = awaiter;
                    this.<>t__builder.AwaitUnsafeOnCompleted<TaskAwaiter<int>, Tests.<LongRunningTest>d__0>(ref awaiter, ref this);
                    flag = false;
                    return;
            }
            awaiter = this.<>u__$awaiter3;
            this.<>u__$awaiter3 = new TaskAwaiter<int>();
            this.<>1__state = -1;
        Label_0082:
            int introduced6 = awaiter.GetResult();
            awaiter = new TaskAwaiter<int>();
            int num2 = introduced6;
            this.<result>5__2 = num2;
            Assert.Equal<int>(0x2a, this.<result>5__2);
        }
        catch (Exception exception)
        {
            this.<>1__state = -2;
            this.<>t__builder.SetException(exception);
            return;
        }
    Label_00C3:
        this.<>1__state = -2;
        this.<>t__builder.SetResult();
    }

    [DebuggerHidden]
    private void SetStateMachine(IAsyncStateMachine param0)
    {
        this.<>t__builder.SetStateMachine(param0);
    }
}

Without going through the code in detail (very few people can, Jon Skeet definitely being one of them), we can at least see that there is a MoveNext method, and that is exactly the method that was not fully covered in the MockTest test method. So why was all of it executed in LongRunningTest, but only part of it in MockTest?

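The reason is that the task returned by the mock had already completed by the time it was awaited. The awaiter reports IsCompleted as true, so MoveNext simply reads the result and carries on; it never registers a continuation, and the code that suspends the method and later resumes it through another call to MoveNext is never executed. In the Foo class, the simulated delay means the task is still running when it is awaited, which causes what is called a continuation: the state machine returns early, and its MoveNext method is invoked again once the delay has elapsed.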
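The two paths can be observed without a coverage tool. In this sketch (FastPathDemo and AwaitAndReturn are illustrative names), an async method that awaits an already completed task has finished before it even returns to its caller, while one that awaits the task of an unfinished TaskCompletionSource<int> returns early and only finishes once the continuation runs:

```csharp
using System;
using System.Threading.Tasks;

public static class FastPathDemo
{
    // The compiler turns this into a state machine with a MoveNext method.
    public static async Task<int> AwaitAndReturn(Task<int> inner)
    {
        int value = await inner;
        return value;
    }

    public static void Main()
    {
        // Completed task: IsCompleted is true, so MoveNext calls GetResult right
        // away and the method has finished before it returns to the caller.
        Task<int> fast = AwaitAndReturn(Task.FromResult(42));
        Console.WriteLine(fast.IsCompleted); // True

        // Unfinished task: MoveNext registers a continuation and returns early.
        var source = new TaskCompletionSource<int>();
        Task<int> slow = AwaitAndReturn(source.Task);
        Console.WriteLine(slow.IsCompleted); // False

        // Completing the source runs the continuation, which calls MoveNext again.
        source.SetResult(42);
        Console.WriteLine(slow.Result);      // 42
    }
}
```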

If we take a look at the ReturnsAsync method from the Moq library, we see the following:

public static IReturnsResult<TMock> ReturnsAsync<TMock, TResult>(this IReturns<TMock, Task<TResult>> mock, TResult value) where TMock : class
{
  TaskCompletionSource<TResult> completionSource = new TaskCompletionSource<TResult>();
  completionSource.SetResult(value);
  return mock.Returns(completionSource.Task);
}    

The ReturnsAsync method uses the TaskCompletionSource<TResult> class, and its SetResult method transitions the underlying task to the completed state before the mock ever returns it. We could get around this by creating an extension method that delays the task to force a continuation, but that would definitely increase the duration of the test run.
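To see that such a task is already completed before anyone awaits it, here is the same TaskCompletionSource pattern outside of Moq. CreateCompletedTask is a hypothetical helper that simply mirrors the body of the ReturnsAsync snippet above:

```csharp
using System;
using System.Threading.Tasks;

public static class CompletedTaskDemo
{
    // Mirrors the TaskCompletionSource pattern from ReturnsAsync, without Moq.
    public static Task<TResult> CreateCompletedTask<TResult>(TResult value)
    {
        var completionSource = new TaskCompletionSource<TResult>();
        completionSource.SetResult(value); // the task is now RanToCompletion
        return completionSource.Task;
    }

    public static void Main()
    {
        Task<int> task = CreateCompletedTask(42);
        Console.WriteLine(task.IsCompleted); // True, before anything awaits it
        Console.WriteLine(task.Status);      // RanToCompletion
    }
}
```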

What we need is some way to make a task look incomplete to await, even when it has actually completed.

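The await keyword expects something awaitable, so what does it take for a class to be awaitable? In short: an accessible GetAwaiter method (instance or extension) that returns an awaiter, and the awaiter must implement INotifyCompletion and provide an IsCompleted property and a GetResult method.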

The following example shows the simplest thing that can be awaited and that returns a value.

using System;
using System.Runtime.CompilerServices;

public class CustomAwaitable
{
    public CustomAwaiter GetAwaiter()
    {
        Console.WriteLine("GetAwaiter");
        return new CustomAwaiter();
    }
}

public class CustomAwaiter : INotifyCompletion
{
    public int GetResult()
    {
        Console.WriteLine("GetResult");
        return 42;
    }

    public bool IsCompleted
    {
        get
        {
            Console.WriteLine("IsCompleted (Get)");
            return false;
        }
    }

    public void OnCompleted(Action continuation)
    {
        Console.WriteLine("OnCompleted");
        continuation();
    }
}

There is no reason why this could not have been implemented in a single class, but since it resembles the relationship between the IEnumerable<T> and IEnumerator<T> interfaces, I found it easier to read when split into two classes.
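For comparison, here is a sketch of the single-class variant, where the awaitable simply returns itself from GetAwaiter. SelfAwaiter and SelfAwaiterDemo are made-up names, and the behavior mirrors CustomAwaiter minus the console logging:

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Hypothetical single-class variant: the awaitable is also its own awaiter.
public class SelfAwaiter : INotifyCompletion
{
    public SelfAwaiter GetAwaiter() => this;

    // Always "not completed", so await goes through OnCompleted.
    public bool IsCompleted => false;

    // Resume immediately, just like CustomAwaiter.
    public void OnCompleted(Action continuation) => continuation();

    public int GetResult() => 42;
}

public static class SelfAwaiterDemo
{
    public static async Task<int> UseAsync() => await new SelfAwaiter();

    public static void Main()
    {
        Console.WriteLine(UseAsync().Result); // 42
    }
}
```

Reusing one object for both roles is fine here because SelfAwaiter holds no state; an awaiter that tracks progress would usually want a fresh instance per await, which is one practical argument for the split.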

The interesting members here are the IsCompleted property and the OnCompleted method.

The IsCompleted property is pretty self-explanatory and simply indicates whether the value is already available via the GetResult method. We return false here because we want to force a continuation.

The OnCompleted method is not as self-explanatory, since its name suggests that it gets executed when something has completed. What it actually does is attach a continuation, in the form of a delegate, that is to be invoked once the await operation completes. It is never called unless the IsCompleted property returns false.
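To make the roles of these members concrete, here is a rough, hand-written approximation of what `int result = await awaitable;` boils down to, using a lambda instead of the state machine the compiler actually generates. LoggingAwaitable and LoggingAwaiter are hypothetical variants of the classes above that log to a list and take the IsCompleted value as a constructor argument, so the hand-written version can be compared with a real await:

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Like CustomAwaitable, but logging to a list with a configurable IsCompleted.
public class LoggingAwaitable
{
    public readonly List<string> Log = new List<string>();
    private readonly bool completed;

    public LoggingAwaitable(bool completed) { this.completed = completed; }

    public LoggingAwaiter GetAwaiter()
    {
        Log.Add("GetAwaiter");
        return new LoggingAwaiter(completed, Log);
    }
}

public class LoggingAwaiter : INotifyCompletion
{
    private readonly bool completed;
    private readonly List<string> log;

    public LoggingAwaiter(bool completed, List<string> log)
    {
        this.completed = completed;
        this.log = log;
    }

    public bool IsCompleted
    {
        get { log.Add("IsCompleted (Get)"); return completed; }
    }

    public void OnCompleted(Action continuation)
    {
        log.Add("OnCompleted");
        continuation(); // resume immediately, as CustomAwaiter does
    }

    public int GetResult()
    {
        log.Add("GetResult");
        return 42;
    }
}

public static class AwaitExpansionDemo
{
    // Hand-written approximation of "int result = await awaitable;".
    public static void AwaitByHand(LoggingAwaitable awaitable, Action<int> rest)
    {
        LoggingAwaiter awaiter = awaitable.GetAwaiter();
        if (awaiter.IsCompleted)
        {
            rest(awaiter.GetResult());                            // fast path
        }
        else
        {
            awaiter.OnCompleted(() => rest(awaiter.GetResult())); // continuation
        }
    }

    // The compiler-generated version, for comparison.
    public static async Task<int> AwaitForReal(LoggingAwaitable awaitable) => await awaitable;

    public static void Main()
    {
        foreach (bool completed in new[] { false, true })
        {
            var manual = new LoggingAwaitable(completed);
            AwaitByHand(manual, result => { });
            var compiled = new LoggingAwaitable(completed);
            AwaitForReal(compiled).Wait();
            Console.WriteLine(string.Join(", ", manual.Log));
            Console.WriteLine(string.Join(", ", compiled.Log));
        }
    }
}
```

With completed set to false, both logs read GetAwaiter, IsCompleted (Get), OnCompleted, GetResult; with true, OnCompleted drops out, which mirrors the test outputs shown in this post.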

Let's create a test for this and take a look at the output.

[Fact]
public async Task CustomAwaiterTest()
{
    var awaitable = new CustomAwaitable();
    int result = await awaitable;
    Assert.Equal(42, result);
}  

The output from this test shows which members are involved, and in what order.

GetAwaiter
IsCompleted (Get)
OnCompleted
GetResult

Now let's change the IsCompleted property to return true:

public bool IsCompleted
{
    get
    {
        Console.WriteLine("IsCompleted (Get)");
        return true;
    }
}

Test output

GetAwaiter
IsCompleted (Get)
GetResult

As we can see, the OnCompleted method is never executed, since there is no need to attach a continuation.

With a basic understanding of how this works, we can move on to solving the actual problem: we need an awaitable that reports itself as incomplete even when the underlying task has actually completed. The solution is a decorator that wraps the awaiter of an existing task and always returns false from the IsCompleted property to ensure a continuation.

using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

/// <summary>
/// Represents an awaitable whose awaiter has not yet completed.
/// </summary>
/// <typeparam name="TResult">The result type.</typeparam>
public class IncompleteAwaitable<TResult>
{
    private TaskAwaiter<TResult> taskAwaiter;

    /// <summary>
    /// Initializes a new instance of the <see cref="IncompleteAwaitable{T}"/> class.
    /// </summary>
    /// <param name="task">The <see cref="Task{TResult}"/> wrapped by this <see cref="IncompleteAwaitable{T}"/>.</param>
    public IncompleteAwaitable(Task<TResult> task)
    {
        taskAwaiter = task.GetAwaiter();
    }

    /// <summary>
    /// Gets an awaiter used to await this <see cref="IncompleteAwaitable{TResult}"/>.
    /// </summary>
    /// <returns>An <see cref="IncompleteAwaiter{TResult}"/> wrapping the task's awaiter.</returns>
    public IncompleteAwaiter<TResult> GetAwaiter()
    {
        return new IncompleteAwaiter<TResult>(taskAwaiter);
    }
}

/// <summary>
/// An awaiter that never reports completion, regardless of the
/// state of the underlying <see cref="TaskAwaiter{TResult}"/>.
/// </summary>
/// <typeparam name="TResult">The result type.</typeparam>
public class IncompleteAwaiter<TResult> : INotifyCompletion
{
    private readonly TaskAwaiter<TResult> taskAwaiter;

    /// <summary>
    /// Initializes a new instance of the <see cref="IncompleteAwaiter{TResult}"/> class.
    /// </summary>
    /// <param name="taskAwaiter">The underlying <see cref="TaskAwaiter{TResult}"/></param>
    public IncompleteAwaiter(TaskAwaiter<TResult> taskAwaiter)
    {
        this.taskAwaiter = taskAwaiter;
    }

    /// <summary>
    /// Attaches a continuation to the underlying <see cref="TaskAwaiter{TResult}"/>. 
    /// </summary>
    /// <param name="continuation">The continuation delegate to be attached.</param>
    public void OnCompleted(Action continuation)
    {          
        taskAwaiter.OnCompleted(continuation);
    }

    /// <summary>
    /// Gets a value indicating whether this awaiter is completed.
    /// </summary>
    /// <remarks>
    /// This property will always return false to ensure continuation.
    /// </remarks>
    public bool IsCompleted
    {
        get
        {
            return false;
        }
    }

    /// <summary>
    /// Gets the result from the underlying <see cref="TaskAwaiter{TResult}"/>.
    /// </summary>
    /// <returns>The completed <typeparamref name="TResult"/> value.</returns>
    public TResult GetResult()
    {
        return taskAwaiter.GetResult();
    }
}

IncompleteAwaitable<TResult> simply takes an existing Task<TResult>, gets the task's awaiter through its GetAwaiter method, and passes that awaiter on to the IncompleteAwaiter<TResult> class (the decorator).

With this, we can now force a continuation even for a task that has already completed.

[Fact]
public async Task IncompleteAwaitableTest()
{
    // Task.FromResult creates a task that has already completed.
    var task = Task.FromResult(42);
    var awaitable = new IncompleteAwaitable<int>(task);
    var result = await awaitable;
    Assert.Equal(42, result);
}

Looking at the code coverage for this test shows that all of the code has executed, including the whole MoveNext method of the state machine.
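To observe the forced continuation directly, here is a condensed, self-contained copy of the two classes above with one addition: a static OnCompletedCalls counter, which is hypothetical instrumentation for this sketch only. Awaiting an already completed task through the decorator still goes through OnCompleted exactly once:

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

// Condensed copy of IncompleteAwaitable<TResult> from above.
public class IncompleteAwaitable<TResult>
{
    private readonly TaskAwaiter<TResult> taskAwaiter;

    public IncompleteAwaitable(Task<TResult> task) { taskAwaiter = task.GetAwaiter(); }

    public IncompleteAwaiter<TResult> GetAwaiter() => new IncompleteAwaiter<TResult>(taskAwaiter);
}

// Condensed copy of IncompleteAwaiter<TResult>, plus a call counter.
public class IncompleteAwaiter<TResult> : INotifyCompletion
{
    public static int OnCompletedCalls; // hypothetical instrumentation, not in the original

    private readonly TaskAwaiter<TResult> taskAwaiter;

    public IncompleteAwaiter(TaskAwaiter<TResult> taskAwaiter) { this.taskAwaiter = taskAwaiter; }

    // Always report "not completed" so that await attaches a continuation.
    public bool IsCompleted => false;

    public void OnCompleted(Action continuation)
    {
        Interlocked.Increment(ref OnCompletedCalls);
        taskAwaiter.OnCompleted(continuation); // forwarded to the real TaskAwaiter
    }

    public TResult GetResult() => taskAwaiter.GetResult();
}

public static class IncompleteAwaitableDemo
{
    public static async Task<int> AwaitCompletedTaskAsync()
    {
        Task<int> task = Task.FromResult(42); // already completed
        return await new IncompleteAwaitable<int>(task);
    }

    public static void Main()
    {
        int result = AwaitCompletedTaskAsync().GetAwaiter().GetResult();
        Console.WriteLine(result);                                  // 42
        Console.WriteLine(IncompleteAwaiter<int>.OnCompletedCalls); // 1
    }
}
```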

To make this a little easier to use, we can create an extension method that creates an incomplete task from an existing task.

using System.Threading.Tasks;

/// <summary>
/// Extends the <see cref="Task{TResult}"/> class.
/// </summary>
public static class TaskExtensions
{        
    /// <summary>
    /// Creates an "incomplete" <see cref="Task{TResult}"/> from 
    /// the given <paramref name="task"/>. 
    /// </summary>
    /// <typeparam name="TResult">The result type returned from the task.</typeparam>
    /// <param name="task">The target <see cref="Task{TResult}"/>.</param>
    /// <returns>A new <see cref="Task{TResult}"/> that produces the same result as <paramref name="task"/>, but only through a forced continuation.</returns>
    public static Task<TResult> GetIncompleteTask<TResult>(this Task<TResult> task)
    {
        var incompleteAwaitable = new IncompleteAwaitable<TResult>(task);
        // Task.Run returns a task that tracks the async lambda, which resumes via a continuation.
        return Task.Run(async () => await incompleteAwaitable);
    }
}

We can now create an incomplete task from an existing task.

[Fact]
public async Task IncompleteTaskTest()
{
    var task = Task.FromResult(42);
    var result = await task.GetIncompleteTask();
    Assert.Equal(42, result);
}

And code coverage for this test is still 100%, as expected.

Using the Moq library, we can now set up a mock like this:

[Fact]
public async Task IncompleteMockedTask()
{
    var mock = new Mock<IFoo>();
    mock.Setup(m => m.ExecuteAsync()).Returns(() => Task.FromResult(42).GetIncompleteTask());
    var result = await mock.Object.ExecuteAsync();
    Assert.Equal(42, result);
}  

Please note that it is important to use the Returns overload that takes a value factory. Otherwise, the incomplete task is created only once, when the mock is set up, and will most likely have completed by the time the test awaits it, which again prevents the continuation.

We could simplify this even further by creating an extension method tailored specifically to Moq.

using System.Threading.Tasks;
using Moq.Language;
using Moq.Language.Flow;

/// <summary>
/// Extends the <see cref="IReturns{TMock,TResult}"/> interface.
/// </summary>
public static class MoqExtensions
{
    /// <summary>
    /// Specifies the return value from the method as an "incomplete" <see cref="Task{TResult}"/>.
    /// </summary>
    /// <typeparam name="TMock">The mocked type.</typeparam>
    /// <typeparam name="TResult">The mocked method result type.</typeparam>
    /// <param name="mock">The target <see cref="IReturns{TMock,TResult}"/> instance.</param>
    /// <param name="value">The value to be returned from the method.</param>
    /// <returns>An <see cref="IReturnsResult{TMock}"/> instance.</returns>
    public static IReturnsResult<TMock> ReturnsIncompleteAsync<TMock, TResult>(this IReturns<TMock, Task<TResult>> mock, TResult value) where TMock : class
    {
        // Use a value factory so that a fresh incomplete task is created for every call.
        return mock.Returns(() => Task.FromResult(value).GetIncompleteTask());
    }
} 

We can now mock an asynchronous method and still force a continuation, and hence get 100% coverage.

[Fact]
public async Task IncompleteMockedTaskUsingExtensionMethod()
{
    var mock = new Mock<IFoo>();
    mock.Setup(m => m.ExecuteAsync()).ReturnsIncompleteAsync(42);
    var result = await mock.Object.ExecuteAsync();
    Assert.Equal(42, result);
}

Happy testing!!