Unit Testing with a Mock Entity Framework DbContext and Fake DbSets

I'm a big fan of the Repository pattern, and use it hand-in-hand with dependency injection and the Query Object pattern. Lately I've been developing with Entity Framework, using dependency injection to pass DbContext instances into my repository classes. But when I tried to unit test one of those repository classes, I couldn't figure out how to mock the DbSets that my DbContext exposes. I could create a mock of the DbContext itself, but I had trouble setting it up to return fake collections of entities. Part of the problem is that DbSet has no public constructor (instances are created through a factory method such as DbContext.Set<TEntity>()), so a test can't simply create one. After much research and scouring of the web, here are the steps I learned to accomplish this.

Step One: Create a Fake DbSet Class That Implements IDbSet
The DbContext exposes entities through DbSet<TEntity> properties, and DbSet<TEntity> implements the IDbSet<TEntity> interface. So we can write our own IDbSet<TEntity> implementation for unit tests that, unlike DbSet<TEntity>, can be instantiated directly.

I found many examples online that were essentially the class shown below:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
 
namespace MyUnitTest.DatabaseHelpers
{
    public class FakeDbSet<T> : IDbSet<T> where T : class
    {
        readonly HashSet<T> _data;
        readonly IQueryable _query;
 
        public FakeDbSet()
        {
            _data = new HashSet<T>();
            _query = _data.AsQueryable();
        }
 
        public T Add(T entity)
        {
            _data.Add(entity);
            return entity;
        }
 
        public T Attach(T entity)
        {
            _data.Add(entity);
            return entity;
        }
 
        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T
        {
            throw new NotImplementedException();
        }
 
        public T Create()
        {
            return Activator.CreateInstance<T>();
        }
 
        public virtual T Find(params object[] keyValues)
        {
            throw new NotImplementedException(
               "Derive from FakeDbSet and override Find");
        }
 
        public System.Collections.ObjectModel.ObservableCollection<T> Local
        {
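            // Returns a fresh snapshot on every call; EF's DbSet.Local is a
            // live view kept in sync with the context.
            get
            {
                return new System.Collections.ObjectModel.ObservableCollection<T>(_data);
            }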
        }
 
        public T Remove(T entity)
        {
            _data.Remove(entity);
            return entity;
        }
 
        public IEnumerator<T> GetEnumerator()
        {
            return _data.GetEnumerator();
        }
 
        IEnumerator IEnumerable.GetEnumerator()
        {
            return _data.GetEnumerator();
        }
 
        public Type ElementType
        {
            get { return _query.ElementType; }
        }
 
        public Expression Expression
        {
            get { return _query.Expression; }
        }
 
        public IQueryProvider Provider
        {
            get { return _query.Provider; }
        }
    }
} 
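To see why delegating ElementType, Expression and Provider to _data.AsQueryable() is enough for LINQ queries to work against the fake, here is a minimal, EF-free sketch. InMemoryQueryableSet<T> is a hypothetical name: it implements IQueryable<T> the same way FakeDbSet<T> does but leaves out IDbSet<T>, so it compiles without Entity Framework. It also shows one assumed alternative to "derive from FakeDbSet and override Find": passing a key selector into the constructor.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical, EF-free miniature of FakeDbSet<T>: same IQueryable<T> plumbing,
// but it does not implement IDbSet<T>.
public class InMemoryQueryableSet<T> : IQueryable<T> where T : class
{
    private readonly HashSet<T> _data = new HashSet<T>();
    private readonly IQueryable<T> _query;
    private readonly Func<T, object> _keySelector;

    // The key selector is an assumption of this sketch: the test supplies the
    // primary-key accessor instead of deriving a class to override Find.
    public InMemoryQueryableSet(Func<T, object> keySelector)
    {
        _keySelector = keySelector ?? throw new ArgumentNullException(nameof(keySelector));
        // AsQueryable() wraps the live HashSet rather than copying it,
        // so entities added later are still seen by queries.
        _query = _data.AsQueryable();
    }

    public T Add(T entity)
    {
        _data.Add(entity);
        return entity;
    }

    // Single-column keys only; object.Equals copes with boxed values and nulls.
    public T Find(params object[] keyValues)
    {
        if (keyValues == null || keyValues.Length != 1)
            throw new ArgumentException("This sketch supports single-column keys only.", nameof(keyValues));
        return _data.FirstOrDefault(e => Equals(_keySelector(e), keyValues[0]));
    }

    // Delegating to the LINQ to Objects provider makes Queryable operators
    // (Where, OrderBy, Count, ...) run in memory against _data.
    public Type ElementType => _query.ElementType;
    public Expression Expression => _query.Expression;
    public IQueryProvider Provider => _query.Provider;

    public IEnumerator<T> GetEnumerator() => _data.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```

Because _query wraps the HashSet itself rather than a copy, an entity added after construction shows up in later queries such as set.Where(...).Count(). The trade-off with the key selector is that each test has to state what the key is, instead of relying on a naming convention or a derived class.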
 
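Step Two: Change the DbSet Properties of the DbContext to IDbSets
Because DbSet<TEntity> implements IDbSet<TEntity>, we can safely change the DbSet properties of our DbContext to IDbSet. Entity Framework still initializes these properties with real DbSet instances at runtime, but a test can now assign any IDbSet implementation, such as FakeDbSet<T>, in their place.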

Very Simple Example:
Before -
    public partial class ProductContext : DbContext
    {
        public ProductContext()
            : base("name=ProductContext")
        {
        }
   
        protected override void OnModelCreating(DbModelBuilder modelBuilder)
        {
            throw new UnintentionalCodeFirstException();
        }
   
        public DbSet<ProductCodes> ProductCodes { get; set; }
    }

After -
    public partial class ProductContext : DbContext
    {
        public ProductContext()
            : base("name=ProductContext")
        {
        }
   
        protected override void OnModelCreating(DbModelBuilder modelBuilder)
        {
            throw new UnintentionalCodeFirstException();
        }
   
        public IDbSet<ProductCodes> ProductCodes { get; set; }
    }

Important: You will also need to modify your DbContext's T4 template (the .tt file) so that your changes aren't lost when the DbContext is regenerated.

Look for the following:

    public string DbSet(EntitySet entitySet)
    {
        return string.Format(
            CultureInfo.InvariantCulture,
            "{0} DbSet<{1}> {2} {{ get; set; }}",
            Accessibility.ForReadOnlyProperty(entitySet),
            _typeMapper.GetTypeName(entitySet.ElementType),
            _code.Escape(entitySet));
    }

And change it to:

    public string DbSet(EntitySet entitySet)
    {
        return string.Format(
            CultureInfo.InvariantCulture,
            "{0} IDbSet<{1}> {2} {{ get; set; }}",
            Accessibility.ForReadOnlyProperty(entitySet),
            _typeMapper.GetTypeName(entitySet.ElementType),
            _code.Escape(entitySet));
    }
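The only change in the template is the type name inside the format string. To check what the helper emits, here is a small stand-alone C# method that runs the same string.Format call. The class and method names, and the argument values used below, are hypothetical stand-ins for what Accessibility.ForReadOnlyProperty, _typeMapper.GetTypeName and _code.Escape would return for a ProductCodes entity set.

```csharp
using System.Globalization;

public static class DbSetTemplateDemo
{
    // Mirrors the string.Format call in the modified T4 helper.
    // The parameters are hypothetical stand-ins for the template's helpers.
    public static string FormatDbSetProperty(string accessibility, string elementType, string propertyName)
    {
        // "{{" and "}}" are escaped braces: they come out as literal "{" and "}".
        return string.Format(
            CultureInfo.InvariantCulture,
            "{0} IDbSet<{1}> {2} {{ get; set; }}",
            accessibility,
            elementType,
            propertyName);
    }
}
```

For example, FormatDbSetProperty("public", "ProductCodes", "ProductCodes") returns "public IDbSet<ProductCodes> ProductCodes { get; set; }", which matches the "After" property from Step Two.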

Step Three: Instantiate a FakeDbSet, Add Some Objects To It, And Assign to the Mock DbContext
Here's another very simple example. In this example, I have a very rudimentary unit test that uses the technique explained in this post. The unit test doesn't do too much: it creates a mock DbContext (using the Moq mocking framework), assigns a FakeDbSet to it, and passes the DbContext into a repository class which uses the DbContext to retrieve the data it needs to return. Then it confirms it got back what it was expecting. A super simple example, but hopefully useful to illustrate the techniques presented here.

using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MyProject.Products;
using Moq;
using MyProject.Models.Products;
using System.Data.Entity;
using System.Linq;
using System.Collections.Generic;
using System.Collections;
using System.Linq.Expressions;
using MyUnitTestProject.EFHelperClasses;

namespace MyUnitTestProject.Repositories
{
    [TestClass]
    public class ProductRepositoryTest
    {
        [TestMethod]
        public void GetProductCodesReturnSuccessfully()
        {
            var contextMock = new Mock<MyContext>();
            IDbSet<ProductCodes> productCodes = new FakeDbSet<ProductCodes>();
            productCodes.Add(new ProductCodes() { productCodeVal = "A", productCodeDescription = "A Description" });
            productCodes.Add(new ProductCodes() { productCodeVal = "B", productCodeDescription = "B Description" });

            contextMock.Object.ProductCodes = productCodes;

            using (ProductRepository repo = new ProductRepository(contextMock.Object))
            {
                var results = repo.GetProductCodes();
                Assert.IsNotNull(results);
                Assert.AreEqual(2, results.Count);
            }
        }

    }
}
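The post doesn't show ProductRepository itself, so here is a hypothetical, EF-free sketch of what such a repository might look like. To keep it compilable without Entity Framework or Moq, it swaps the mocked DbContext for a hand-written interface (IProductCodeSource, a hypothetical name), similar in spirit to the IDatabase interface one commenter extracts from their context. The ProductCodes entity shape is taken from the test above; everything else is an assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Entity shape taken from the test above; the real class is generated by EF.
public class ProductCodes
{
    public string productCodeVal { get; set; }
    public string productCodeDescription { get; set; }
}

// Hypothetical seam: only the part of the context the repository reads.
public interface IProductCodeSource
{
    IQueryable<ProductCodes> ProductCodes { get; }
}

// Hypothetical repository body; the post does not show ProductRepository.
public class ProductRepository : IDisposable
{
    private readonly IProductCodeSource _source;

    public ProductRepository(IProductCodeSource source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public List<ProductCodes> GetProductCodes()
    {
        // Written against IQueryable<T>: LINQ to Entities in production,
        // LINQ to Objects when the source is an in-memory list.
        return _source.ProductCodes.OrderBy(p => p.productCodeVal).ToList();
    }

    // Nothing to release here; a real repository might dispose its context.
    public void Dispose() { }
}

// Plays the role of the mocked context plus FakeDbSet<ProductCodes>.
public class InMemoryProductCodeSource : IProductCodeSource
{
    private readonly List<ProductCodes> _codes = new List<ProductCodes>();

    public IQueryable<ProductCodes> ProductCodes => _codes.AsQueryable();

    public void Add(ProductCodes code) => _codes.Add(code);
}
```

One caveat applies to this sketch and to FakeDbSet<T> alike: in a test the query runs as LINQ to Objects, not LINQ to Entities. A query that passes here can still fail at runtime if Entity Framework can't translate it to SQL, and string comparisons follow .NET rules rather than the database's collation.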

Comments

Unknown said…
Alternatively, you can download my project on Codeplex. I have a working example that you can copy.
https://entityinterfacegenerator.codeplex.com/

It generates the interface files that you need for IoC and unit testing purposes.
Anonymous said…
On this line "IDbSet productCodes = new FakeDbSet();" I get a compiler error that says "Using the generic type 'NeoUnitTests.FakeDbSet' requires 1 type arguments"
Anonymous said…
Thanks for this post, I've spent the last couple hours trying to solve the problem of mocking my DbContext.
Here is my complete test in case anyone is curious (I'm using FakeItEasy for mocks and an abstract factory to get my DbContext in my ViewModel)

// Arrange
var contextFactory = A.Fake();
var db = A.Fake();
var models = new FakeDbSet();
var phxModel = A.Fake();
phxModel.Name = "Phoenix";

models.Add(phxModel);

db.Models = models;
A.CallTo(() => contextFactory.GetContext()).Returns(db);

var vm = new MainViewModel(contextFactory);

// Act
vm.Loaded();

// Assert
Assert.That(vm.Models[0].Name == "Phoenix");
Anonymous said…
The HTML parser ate my generic type arguments; here is how the code should look. IDatabase is an interface I extracted from my Entity Framework context class.

// Arrange
var contextFactory = A.Fake<IContextFactory>();
var db = A.Fake<IDatabase>();
var models = new FakeDbSet<Model>();
var phxModel = A.Fake<Model>();
phxModel.Name = "Phoenix";

models.Add(phxModel);

db.Models = models;
A.CallTo(() => contextFactory.GetContext()).Returns(db);

var vm = new MainViewModel(contextFactory);

// Act
vm.Loaded();

// Assert
Assert.That(vm.Models[0].Name == "Phoenix");
cmshefler said…
Here's an implementation of FakeDbSet.Find() that I am using. It assumes all your entities have a single primary key property and that they are named [entity type name] + "ID".

public virtual T Find(params object[] keyValues)
{
    foreach (var obj in _data)
    {
        var thisObjType = obj.GetType();
        var thisObjTypeName = thisObjType.Name;
        var thisObjIDFieldName = thisObjTypeName + "ID";

        if (thisObjType.GetProperty(thisObjIDFieldName).GetValue(obj).Equals(keyValues[0]))
        {
            return obj;
        }
    }
    return null;
}
cmshefler said…
Addendum to my last comment -- since we already have an ElementType property, we can just get that instead of asking the object to GetType().

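public virtual T Find(params object[] keyValues)
{
    // Convention: the key property is named [entity type name] + "ID".
    var keyProperty = ElementType.GetProperty(ElementType.Name + "ID");
    if (keyProperty == null)
    {
        throw new InvalidOperationException(
            "No " + ElementType.Name + "ID property found on " + ElementType.Name);
    }

    foreach (var obj in _data)
    {
        // Static object.Equals also copes with a null key value.
        if (Equals(keyProperty.GetValue(obj), keyValues[0]))
        {
            return obj;
        }
    }
    return null;
}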
Anonymous said…
I got the compiler error: "Constraints are not allowed on non-generic declarations" because of the following line:
public TDerivedEntity Create() where TDerivedEntity : class, T

Using a type parameter, "Create<TDerivedEntity>()" instead of "Create()", fixed the issue:
public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T

Apparently they were stripped by Blogspot, so you need to use HTML entity references to display angle brackets:
&lt; will yield <
&gt; will yield >
