I want to be able to use a Type instead of T with DbSet.
For example, DbSet(myType) instead of DbSet<T>.
EF Core 3.1 requires you to use DbSet<T>, and all of the query functionality is encapsulated in extension methods.
Here's the solution I went with for EF Core 3.1.
Since I've already been through the pain, I'm sharing it for the next poor soul who has to do this.
Usage
var dbSet = new DbSetFacade(context, myType);
dynamic results = await dbSet.ToListAsync();
I'll share just ToListAsync and CountAsync for now to keep it brief!
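using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
public class DbSetFacade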
{
private static ReadOnlyDictionary<EfQueryExtensionType, MethodInfo> _extensionMethodInfos = GetExtensionMethodInfos();
private readonly object _dbSet;
private readonly ReadOnlyDictionary<EfQueryExtensionType, MethodInfo> _genericInstanceMethodInfos;
public DbSetFacade(DbContext context, Type entityType)
{
if (context == null) throw new ArgumentNullException(nameof(context));
if (entityType == null) throw new ArgumentNullException(nameof(entityType));
_dbSet = GetDbSetInstance(context, entityType);
_genericInstanceMethodInfos = new ReadOnlyDictionary<EfQueryExtensionType, MethodInfo>(new Dictionary<EfQueryExtensionType, MethodInfo>
{
[EfQueryExtensionType.AsNoTracking] = _extensionMethodInfos[EfQueryExtensionType.AsNoTracking].MakeGenericMethod(entityType),
[EfQueryExtensionType.ToListAsync] = _extensionMethodInfos[EfQueryExtensionType.ToListAsync].MakeGenericMethod(entityType),
[EfQueryExtensionType.CountAsyncWithoutPredicate] = _extensionMethodInfos[EfQueryExtensionType.CountAsyncWithoutPredicate].MakeGenericMethod(entityType),
});
}
enum EfQueryExtensionType
{
AsNoTracking,
ToListAsync,
//FirstOrDefaultAsyncWithPredicate,
//FirstOrDefaultAsyncWithoutPredicate,
//ContainsAsync,
//CountAsyncWithPredicate,
CountAsyncWithoutPredicate
}
public Task<int> CountAsync(CancellationToken token = default) => _genericInstanceMethodInfos[EfQueryExtensionType.CountAsyncWithoutPredicate].Invoke(null, new object[] { _dbSet, token }) as Task<int>;
public async Task<dynamic> ToListAsync(CancellationToken token = default)
{
var noTrackingSet = _genericInstanceMethodInfos[EfQueryExtensionType.AsNoTracking].Invoke(null, new object[] { _dbSet });
var toList = _genericInstanceMethodInfos[EfQueryExtensionType.ToListAsync].Invoke(null, new object[] { noTrackingSet, token }) as Task;
await toList.ConfigureAwait(false);
return (object)((dynamic)toList).Result;
}
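private static object GetDbSetInstance(DbContext context, Type type)
{
// Pick the parameterless Set<TEntity>() explicitly; later EF Core versions also have Set<TEntity>(string name).
return typeof(DbContext).GetMethods()
.Single(p => p.Name == "Set" && p.ContainsGenericParameters && p.GetParameters().Length == 0)
.MakeGenericMethod(type)
.Invoke(context, null);
}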
static ReadOnlyDictionary<EfQueryExtensionType, MethodInfo> GetExtensionMethodInfos()
{
var extensionMethods = typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo()
.GetMethods()
.Where(x => x.ContainsGenericParameters && x.GetParameters()[0].ParameterType.IsGenericType);
return new ReadOnlyDictionary<EfQueryExtensionType, MethodInfo>(new Dictionary<EfQueryExtensionType, MethodInfo>
{
[EfQueryExtensionType.AsNoTracking] = extensionMethods.FirstOrDefault(x => x.Name == "AsNoTracking"),
[EfQueryExtensionType.ToListAsync] = extensionMethods.FirstOrDefault(x => x.Name == "ToListAsync" && x.GetParameters().Count() == 2),
[EfQueryExtensionType.CountAsyncWithoutPredicate] = extensionMethods.FirstOrDefault(x => x.Name == "CountAsync" && x.GetParameters().Count() == 2)
});
}
}
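The core trick in the facade, closing an open generic extension method over a Type that is only known at run time and invoking it with a null target, does not depend on EF. Here is a minimal sketch with plain LINQ, assuming Queryable.Count stands in for the EF extension methods. RuntimeGenericInvoker is a hypothetical name and not part of the class above.

```csharp
using System;
using System.Linq;
using System.Reflection;

// Hypothetical helper, for illustration only.
public static class RuntimeGenericInvoker
{
    // Open generic definition of Queryable.Count<TSource>(IQueryable<TSource>),
    // looked up once and cached, similar to _extensionMethodInfos in the facade.
    private static readonly MethodInfo CountDefinition = typeof(Queryable)
        .GetMethods(BindingFlags.Public | BindingFlags.Static)
        .Single(m => m.Name == nameof(Queryable.Count)
                     && m.IsGenericMethodDefinition
                     && m.GetParameters().Length == 1);

    // Counts the elements of a query whose element type is only known at run time.
    public static int Count(IQueryable source, Type elementType)
    {
        if (source == null) throw new ArgumentNullException(nameof(source));
        if (elementType == null) throw new ArgumentNullException(nameof(elementType));

        // Close the generic method over the runtime type, then invoke it as a
        // static method (null target). The extension method's "this" argument
        // is passed as the first element of the argument array.
        MethodInfo closed = CountDefinition.MakeGenericMethod(elementType);
        return (int)closed.Invoke(null, new object[] { source })!;
    }
}
```

For example, RuntimeGenericInvoker.Count(new[] { "a", "b" }.AsQueryable(), typeof(string)) returns 2. Passing an element type that does not match the query makes Invoke throw an ArgumentException, which is the same failure mode the facade would have.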
Related
I am wondering if I can pass a table name as a parameter in EF Core.
This is a static example:
var post = await factory.Posts.SingleOrDefaultAsync(x => x.Id == id);
Here I define the table statically as "factory.Posts",
but I would like to pass it dynamically as a parameter.
This is an example (not working):
public async Task Accept(int id,Type type)
{
if (type == typeof(PostEntity))
{
using var factory = _contextFactory.CreateDbContext();
var post = await factory.type.SingleOrDefaultAsync(x => x.Id == id);
post!.IsVerified = true;
await factory.SaveChangesAsync();
}
}
You can do it statically with the generic Set method:
public async Task Accept(int id,Type type)
{
if (type == typeof(PostEntity))
{
using var factory = _contextFactory.CreateDbContext();
var post = await factory.Set<PostEntity>().SingleOrDefaultAsync(x => x.Id == id);
post!.IsVerified = true;
await factory.SaveChangesAsync();
}
}
Or dynamically, by using the Find method:
public virtual object Find([NotNull] Type entityType, [CanBeNull] params object[] keyValues)
So in this case:
public async Task Accept(int id, Type type)
{
using var factory = _contextFactory.CreateDbContext();
dynamic post = factory.Find(type, id);
post!.IsVerified = true;
await factory.SaveChangesAsync();
}
The best solution is:
Create a base class BaseEntity (containing IsVerified and Id)
Configure it in the context
Use it to retrieve your item
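As a rough sketch of the base-class idea, assuming hypothetical BaseEntity and PostEntity classes and a hypothetical Accept extension method: the generic constraint gives the compiler access to Id and IsVerified, so neither reflection nor dynamic is needed. With EF Core the source would be something like factory.Set<TEntity>(), followed by SaveChangesAsync(). A plain IQueryable<TEntity> is used here so the example runs without a database.

```csharp
using System;
using System.Linq;

// Hypothetical base class: every verifiable entity shares Id and IsVerified.
public abstract class BaseEntity
{
    public int Id { get; set; }
    public bool IsVerified { get; set; }
}

// Hypothetical entity deriving from the base class.
public class PostEntity : BaseEntity
{
    public string Title { get; set; } = "";
}

public static class VerificationExtensions
{
    // Marks the entity with the given id as verified.
    // Returns false when no entity with that id exists.
    public static bool Accept<TEntity>(this IQueryable<TEntity> source, int id)
        where TEntity : BaseEntity
    {
        if (source == null) throw new ArgumentNullException(nameof(source));

        var entity = source.SingleOrDefault(x => x.Id == id);
        if (entity == null) return false;

        entity.IsVerified = true;
        return true;
    }
}
```

The caller then picks the table through the type argument, for instance posts.AsQueryable().Accept(2) for an in-memory collection of PostEntity.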
I am trying to get runtime method infos from a static class. The class has four static methods that all have the same name and the same parameter name; the only difference is their parameter types. One of the four methods takes a string parameter, so its method info is easy to get. However, the lookups for the others do not work. I found several pieces of advice, but none of them worked.
All of the test code is here.
class Program {
static void Main(string[] args) {
//ok
var stringMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(string) });
//not working
var dictMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(Dictionary<,>) });
//not working
var genericMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(object) });
//not working
var listMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(List<>) });
//not working
var res = typeof(TestClass)
.GetRuntimeMethods()
.Where(x => x.Name.Equals("TestMethod"))
.Select(m => new { Method = m, Parameters = m.GetParameters() })
.FirstOrDefault(p =>
p.Parameters.Length == 1
&& p.Parameters[0].ParameterType.IsGenericType
&& p.Parameters[0].ParameterType.GetGenericTypeDefinition() == typeof(ICollection<>)
);
}
}
public static class TestClass {
public static bool TestMethod(string item) {
return true;
}
public static bool TestMethod<TKey, TValue>(Dictionary<TKey, TValue> item) {
return true;
}
public static bool TestMethod<T>(T item) {
return true;
}
public static bool TestMethod<T>(List<T> item) {
return true;
}
}
If you are using .NET Core 2.1 or later, you can use Type.MakeGenericMethodParameter to refer to a generic parameter of a method. You can use that to build a generic type argument that works with GetMethod (this is not available for GetRuntimeMethod).
var stringMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(string) });
Type[] dictionaryTypeParameters = { typeof(Dictionary<,>).MakeGenericType(Type.MakeGenericMethodParameter(0), Type.MakeGenericMethodParameter(1)) };
MethodInfo dictMethodInfo = typeof(TestClass).GetMethod("TestMethod", 2, dictionaryTypeParameters);
MethodInfo listMethodInfo = typeof(TestClass).GetMethod("TestMethod", 1, new[] { typeof(List<>).MakeGenericType(Type.MakeGenericMethodParameter(0)) });
MethodInfo genericMethodInfo = typeof(TestClass).GetMethod("TestMethod", 1, new[] { Type.MakeGenericMethodParameter(0) });
Some interesting reading on the topic here.
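A MethodInfo found this way is still an open generic definition, so it must be closed with MakeGenericMethod before it can be invoked. Here is a small self-contained sketch, assuming hypothetical Overloads and OverloadResolver classes in place of TestClass:

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical overload set, mirroring the shape of TestClass.
public static class Overloads
{
    public static string Describe(string item) => "string";
    public static string Describe<T>(T item) => "T=" + typeof(T).Name;
    public static string Describe<T>(List<T> item) => "List of " + typeof(T).Name;
}

public static class OverloadResolver
{
    // Finds Describe<T>(List<T>) by describing its signature with a
    // placeholder for the method's first generic parameter.
    public static MethodInfo FindListOverload()
    {
        Type listOfT = typeof(List<>).MakeGenericType(Type.MakeGenericMethodParameter(0));
        return typeof(Overloads).GetMethod(
            nameof(Overloads.Describe),
            1,                  // number of generic type parameters on the method
            new[] { listOfT })!;
    }

    // Closes the open definition over elementType and invokes it.
    public static string DescribeList(Type elementType, object list)
    {
        MethodInfo closed = FindListOverload().MakeGenericMethod(elementType);
        return (string)closed.Invoke(null, new[] { list })!;
    }
}
```

For example, OverloadResolver.DescribeList(typeof(int), new List<int>()) returns "List of Int32".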
Let's say we would like to use the approach from the question (filtering GetRuntimeMethods() with FirstOrDefault) to get the MethodInfo for each TestMethod overload. Note that they all have exactly one parameter, so p.Parameters.Length == 1 is useless:
Defined as bool TestMethod(string item). We can use
.FirstOrDefault(p => !p.Method.IsGenericMethod)
Defined as bool TestMethod<TKey, TValue>(Dictionary<TKey, TValue> item)
.FirstOrDefault(p =>
p.Method.IsGenericMethod &&
p.Method.GetGenericArguments().Length == 2)
Defined as bool TestMethod<T>(T item)
.FirstOrDefault(p =>
p.Method.IsGenericMethod &&
p.Parameters[0].ParameterType == p.Method.GetGenericArguments()[0]
)
Defined as bool TestMethod<T>(List<T> item). The IsGenericType check is needed because GetGenericTypeDefinition() throws for the plain T parameter of the previous overload:
.FirstOrDefault(p =>
p.Method.IsGenericMethod &&
p.Parameters[0].ParameterType.IsGenericType &&
p.Parameters[0].ParameterType.GetGenericTypeDefinition() == typeof(List<>)
)
For generic methods, you have to query the MethodInfo objects to find the appropriate method.
You can do it as below:
var dictMethodInfo = typeof(TestClass).GetMethods().Single(m => m.Name == "TestMethod" && m.IsGenericMethod &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.IsGenericType &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(Dictionary<,>));
In your case, getting the MethodInfo for TestMethod<T> is a bit tricky, but the following should work:
var genericMethodInfo = typeof(TestClass).GetMethods().Single(m => m.Name == "TestMethod" && m.IsGenericMethod &&
!m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.IsGenericType);
Final Code -
var stringMethodInfo = typeof(TestClass).GetRuntimeMethod("TestMethod", new[] { typeof(string) });
var dictMethodInfo = typeof(TestClass).GetMethods().Single(m => m.Name == "TestMethod" && m.IsGenericMethod &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.IsGenericType &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(Dictionary<,>));
var genericMethodInfo = typeof(TestClass).GetMethods().Single(m => m.Name == "TestMethod" && m.IsGenericMethod &&
!m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.IsGenericType);
var listMethodInfo = typeof(TestClass).GetMethods().Single(m => m.Name == "TestMethod" && m.IsGenericMethod &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.IsGenericType &&
m.GetGenericMethodDefinition().GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(List<>));
I wanted to log changes to certain entities (marked with an attribute), so I created a descendant of AbstractSessionInterceptor to get access to entity changes. I also want to know who made the changes, so I need access to the current user. Through IWorkContextAccessor I create an IWorkContextScope, get the WorkContext, and try to get the user. When an existing entity is edited I am able to access the current user, but when a new entity is created with the content manager I get a timeout exception. When I instead get the WorkContext via IWorkContextAccessor.GetContext(), I get an infinite loop (the interceptor is called again and again). Any ideas and suggestions would be appreciated.
Thanks.
Source:
public class AccountInterceptor : AbstractSessionInterceptor
{
private IProtocolLogger _logger;
private readonly Type _mainAttrType = typeof(ProtocolAttribute);
private readonly Type _fieldAttrType = typeof(ProtocolFieldAttribute);
private readonly IWorkContextAccessor _contextAccessor;
ISessionFactoryHolder _sessionFactoryHolder;
public AccountInterceptor(IWorkContextAccessor contextAccessor, ISessionFactoryHolder sessionFactoryHolder)
{
_contextAccessor = contextAccessor;
_sessionFactoryHolder = sessionFactoryHolder;
}
public override bool OnFlushDirty(object entity, object id, object[] currentState, object[] previousState, string[] propertyNames, NHibernate.Type.IType[] types)
{
var t = entity.GetType();
var attributes = t.GetCustomAttributes(_mainAttrType, true);
if (attributes.Length != 0)
{
IWorkContextScope scope = _contextAccessor.CreateWorkContextScope();
WorkContext context = scope.WorkContext;
if (context != null)
{
var attr = (ProtocolAttribute)attributes.FirstOrDefault();
var currentDic = currentState.Select((s, i) => new { S = s, Index = i }).ToDictionary(x => x.Index, x => x.S);
var prvDic = previousState.Select((s, i) => new { S = s, Index = i }).ToDictionary(x => x.Index, x => x.S);
var diff = compare(currentDic, prvDic);
if (!attr.LogAllData)
{
List<string> properties = new List<string>();
foreach (var propety in t.GetProperties())
{
var propertyAttributes = propety.GetCustomAttributes(_fieldAttrType, true);
if (propertyAttributes.Length != 0)
properties.Add(propety.Name);
}
if (properties.Count != 0)
{
var necesseryProps = propertyNames.Select((s, i) => new { S = s, Index = i }).Where(p => properties.Contains(p.S)).ToDictionary(x => x.Index, x => x.S);
TupleList<int, object, object> ToRemove = new TupleList<int, object, object>();
foreach (var tuple in diff)
{
if (!necesseryProps.Keys.Contains(tuple.Item1))
{
ToRemove.Add(tuple);
}
}
ToRemove.ForEach(d => diff.Remove(d));
}
}
if (diff.Count != 0)
{
_logger = ProtocolLogger.GetInstance();
var sessionFactory = _sessionFactoryHolder.GetSessionFactory();
var session = sessionFactory.OpenSession();
var user = GetCurrentUser(session, context.HttpContext);
string propertiesFormat = GetPropertiesStringFormat(diff, propertyNames);
object[] param = new object[] { DateTime.Now, entity, propertyNames };
string entityId = string.Empty;
try
{
if (entity is IAuditable)
{
entityId = ((IAuditable)entity).Id.ToString();
}
}
catch (Exception)
{
entityId = entity.ToString();
}
foreach (var pair in diff)
{
ProtocolPropertyInfo info = new ProtocolPropertyInfo(propertyNames[pair.Item1], Convert.ToString(pair.Item2), Convert.ToString(pair.Item3));
_logger.Log(user, entity, entityId, session, context, Operation.Write, info);
}
session.Flush();
session.Close();
}
}
}
return base.OnFlushDirty(entity, id, currentState, previousState, propertyNames, types);
}
private object GetCurrentUser(ISession session, HttpContextBase httpContext)
{
if (httpContext == null || !httpContext.Request.IsAuthenticated || !(httpContext.User.Identity is FormsIdentity))
{
return null;
}
var formsIdentity = (FormsIdentity)httpContext.User.Identity;
var userData = formsIdentity.Ticket.UserData ?? "";
// the cookie user data is {userId};{tenant}
var userDataSegments = userData.Split(';');
if (userDataSegments.Length != 2)
{
return null;
}
var userDataId = userDataSegments[0];
var userDataTenant = userDataSegments[1];
int userId;
if (!int.TryParse(userDataId, out userId))
{
return null;
}
Type regType = Assembly.Load("Orchard.Users").GetTypes().First(t => t.Name == "UserPartRecord");
var user = session.Get(regType, userId);
return user;
}
private string GetPropertiesStringFormat(TupleList<int, object, object> diffDic, string[] propertyNames)
{
StringBuilder result = new StringBuilder();
foreach (var pair in diffDic)
{
result.AppendFormat("Property name {0}, New value {1}, Old value {2}", propertyNames[pair.Item1], pair.Item2, pair.Item3);
}
return result.ToString();
}
private TupleList<int, object, object> compare(Dictionary<int, object> dic1, Dictionary<int, object> dic2)
{
var diff = new TupleList<int, object, object>();
foreach (KeyValuePair<int, object> pair in dic1)
{
if (!Equals(pair.Value, dic2[pair.Key]))
{
diff.Add(pair.Key, pair.Value, dic2[pair.Key]);
}
}
return diff;
}
}
Never start a new work context inside your interceptor: an infinite loop is guaranteed. In fact, you don't have to. Each interceptor is already instantiated per work context, so you can inject dependencies via the constructor, as usual.
To access the current user you can either:
inject IOrchardServices and use its .WorkContext.CurrentUser property, or
use contextAccessor.GetContext() to get the context and then call CurrentUser on it.
Also, be careful when performing database operations from inside the interceptor, as those will most likely result in infinite loops and stack overflow exceptions.
I am trying to apply the new mocking support in EF6 to my existing code.
I have a class that extends DbSet. One of its methods calls the base class (DbSet) Create method. Here is some sample code (not the complete solution or real names):
public class DerivedDbSet<TEntity> : DbSet<TEntity>, IKeyValueDbSet<TEntity>, IOrderedQueryable<TEntity> where TEntity : class
{
public virtual bool Add(string value1, string value2) {
var entity = Create(); // There is no direct implementation of the Create method it is calling the base method
// Do something with the values
this.Add(entity);
return true;
}
}
I am mocking using the Test Doubles sample (here is the piece of code):
var data = new List<DummyEntity> {
new DummyEntity { Value1 = "First", Value2 = "001" },
new DummyEntity { Value1 = "Second", Value2 = "002" }
}.AsQueryable();
var mock = new Mock<DerivedDbSet<DummyEntity>>();
mock.CallBase = true;
mock.As<IQueryable<DummyEntity>>().Setup(m => m.Provider).Returns(data.Provider);
mock.As<IQueryable<DummyEntity>>().Setup(m => m.Expression).Returns(data.Expression);
mock.As<IQueryable<DummyEntity>>().Setup(m => m.ElementType).Returns(data.ElementType);
mock.As<IQueryable<DummyEntity>>().Setup(m => m.GetEnumerator()).Returns(data.GetEnumerator());
I've set the CallBase property to true to try to force the call to the base class...
But I keep receiving the following error:
System.NotImplementedException: The member 'Create' has not been implemented on type 'DerivedDbSet`1Proxy' which inherits from 'DbSet`1'. Test doubles for 'DbSet`1' must provide implementations of methods and properties that are used.
I want the call to Create to fall back to the default implementation in DbSet.
Can someone help me with that?
After some struggle with the internal functions and async references involved in mocking a DbSet, I came up with a helper class that solved most of my problems and might serve as a base for someone else's implementation.
Here is the code:
public static class MockHelper
{
internal class TestDbAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider {
private readonly IQueryProvider _inner;
internal TestDbAsyncQueryProvider(IQueryProvider inner) { _inner = inner; }
public IQueryable CreateQuery(Expression expression) { return new TestDbAsyncEnumerable<TEntity>(expression); }
public IQueryable<TElement> CreateQuery<TElement>(Expression expression) { return new TestDbAsyncEnumerable<TElement>(expression); }
public object Execute(Expression expression) { return _inner.Execute(expression); }
public TResult Execute<TResult>(Expression expression) { return _inner.Execute<TResult>(expression); }
public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken) { return Task.FromResult(Execute(expression)); }
public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken) { return Task.FromResult(Execute<TResult>(expression)); }
}
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T> {
public TestDbAsyncEnumerable(IEnumerable<T> enumerable) : base(enumerable) { }
public TestDbAsyncEnumerable(Expression expression) : base(expression) { }
public IDbAsyncEnumerator<T> GetAsyncEnumerator() { return new TestDbAsyncEnumerator<T>(this.AsEnumerable().GetEnumerator()); }
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() { return GetAsyncEnumerator(); }
public IQueryProvider Provider { get { return new TestDbAsyncQueryProvider<T>(this); } }
}
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T> {
private readonly IEnumerator<T> _inner;
public TestDbAsyncEnumerator(IEnumerator<T> inner) { _inner = inner; }
public void Dispose() { _inner.Dispose(); }
public Task<bool> MoveNextAsync(CancellationToken cancellationToken) { return Task.FromResult(_inner.MoveNext()); }
public T Current { get { return _inner.Current; } }
object IDbAsyncEnumerator.Current { get { return Current; } }
}
public static Mock<TDbSet> CreateDbSet<TDbSet, TEntity>(IList<TEntity> data, Func<object[], TEntity> find = null)
where TDbSet : class, IDbSet<TEntity>
where TEntity : class, new() {
var source = data.AsQueryable();
var mock = new Mock<TDbSet> { CallBase = true };
mock.As<IQueryable<TEntity>>().Setup(m => m.Expression).Returns(source.Expression);
mock.As<IQueryable<TEntity>>().Setup(m => m.ElementType).Returns(source.ElementType);
mock.As<IQueryable<TEntity>>().Setup(m => m.GetEnumerator()).Returns(source.GetEnumerator());
mock.As<IQueryable<TEntity>>().Setup(m => m.Provider).Returns(new TestDbAsyncQueryProvider<TEntity>(source.Provider));
mock.As<IDbAsyncEnumerable<TEntity>>().Setup(m => m.GetAsyncEnumerator()).Returns(new TestDbAsyncEnumerator<TEntity>(data.GetEnumerator()));
mock.As<IDbSet<TEntity>>().Setup(m => m.Create()).Returns(new TEntity());
mock.As<IDbSet<TEntity>>().Setup(m => m.Add(It.IsAny<TEntity>())).Returns<TEntity>(i => { data.Add(i); return i; });
mock.As<IDbSet<TEntity>>().Setup(m => m.Remove(It.IsAny<TEntity>())).Returns<TEntity>(i => { data.Remove(i); return i; });
if (find != null) mock.As<IDbSet<TEntity>>().Setup(m => m.Find(It.IsAny<object[]>())).Returns(find);
return mock;
}
public static Mock<DbSet<TEntity>> CreateDbSet<TEntity>(IList<TEntity> data, Func<object[], TEntity> find = null)
where TEntity : class, new() {
return CreateDbSet<DbSet<TEntity>, TEntity>(data, find);
}
}
And here is a use sample (based on the names I've given before):
var data = new List<DummyEntity> {
new DummyEntity { Value1 = "First", Value2 = "001" },
new DummyEntity { Value1 = "Second", Value2 = "002" }
};
var mockDummyEntities = MockHelper.CreateDbSet<DerivedDbSet<DummyEntity>, DummyEntity>(data, i => data.FirstOrDefault(k => k.Value2 == (string)i[0]));
var mockContext = new Mock<DummyDbContext>();
mockContext.Setup(c => c.DummyEntities).Returns(mockDummyEntities.Object);
Any suggestions on how to improve this solution are very welcome.
Regards
This is a modified version, based on Andre's answer, that worked for me. Note that I did not need the async references. The code also adds sets for all derived classes (if any).
Usage:
/// <summary>
///
/// </summary>
[TestMethod]
public void SomeTest()
{
//Setup
var mockContext = new Mock<FakeDbContext>();
//SomeClass can be abstract or concrete
mockContext.createFakeDBSet<SomeClass>();
var db = mockContext.Object;
//Setup create(s) if needed on concrete classes
//Mock.Get(db.Set<SomeOtherClass>()).Setup(x => x.Create()).Returns(new SomeOtherClass());
//DO Stuff
var list1 = db.Set<SomeClass>().ToList();
//SomeOtherClass derived from SomeClass
var subList1 = db.Set<SomeOtherClass>().ToList();
CollectionAssert.AreEquivalent(list1.OfType<SomeOtherClass>().ToList(), subList1);
}
Code:
/// <summary>
/// http://stackoverflow.com/questions/21943328/ef6-mocking-derived-dbsets
/// </summary>
public static class MoqSetupExtensions
{
static IEnumerable<Type> domainTypes = AppDomain.CurrentDomain.GetAssemblies().SelectMany(x => x.GetTypes());
public static Mock<DbSet<T>> createFakeDBSet<T>(this Mock<FakeDbContext> db, List<T> list = null, Func<List<T>, object[], T> find = null, bool createDerivedSets = true) where T : class
{
list = list ?? new List<T>();
var data = list.AsQueryable();
//var mockSet = MockHelper.CreateDbSet(list, find);
var mockSet = new Mock<DbSet<T>>() { CallBase = true };
mockSet.As<IQueryable<T>>().Setup(m => m.Provider).Returns(() => { return data.Provider; });
mockSet.As<IQueryable<T>>().Setup(m => m.Expression).Returns(() => { return data.Expression; });
mockSet.As<IQueryable<T>>().Setup(m => m.ElementType).Returns(() => { return data.ElementType; });
mockSet.As<IQueryable<T>>().Setup(m => m.GetEnumerator()).Returns(() => { return list.GetEnumerator(); });
mockSet.Setup(m => m.Add(It.IsAny<T>())).Returns<T>(i => { list.Add(i); return i; });
mockSet.Setup(m => m.AddRange(It.IsAny<IEnumerable<T>>())).Returns<IEnumerable<T>>((i) => { list.AddRange(i); return i; });
mockSet.Setup(m => m.Remove(It.IsAny<T>())).Returns<T>(i => { list.Remove(i); return i; });
if (find != null) mockSet.As<IDbSet<T>>().Setup(m => m.Find(It.IsAny<object[]>())).Returns<object[]>((i) => { return find(list, i); });
//mockSet.Setup(m => m.Create()).Returns(new T());
db.Setup(x => x.Set<T>()).Returns(mockSet.Object);
//Setup all derived classes
if (createDerivedSets)
{
var type = typeof(T);
var concreteTypes = domainTypes.Where(x => type.IsAssignableFrom(x) && type != x).ToList();
var method = typeof(MoqSetupExtensions).GetMethod("createFakeDBSetSubType");
foreach (var item in concreteTypes)
{
var invokeResult = method.MakeGenericMethod(type, item)
.Invoke(null, new object[] { db, mockSet });
}
}
return mockSet;
}
public static Mock<DbSet<SubType>> createFakeDBSetSubType<BaseType, SubType>(this Mock<FakeDbContext> db, Mock<DbSet<BaseType>> baseSet)
where BaseType : class
where SubType : class, BaseType
{
var dbSet = db.Object.Set<BaseType>();
var mockSet = new Mock<DbSet<SubType>>() { CallBase = true };
mockSet.As<IQueryable<SubType>>().Setup(m => m.Provider).Returns(() => { return dbSet.OfType<SubType>().Provider; });
mockSet.As<IQueryable<SubType>>().Setup(m => m.Expression).Returns(() => { return dbSet.OfType<SubType>().Expression; });
mockSet.As<IQueryable<SubType>>().Setup(m => m.ElementType).Returns(() => { return dbSet.OfType<SubType>().ElementType; });
mockSet.As<IQueryable<SubType>>().Setup(m => m.GetEnumerator()).Returns(() => { return dbSet.OfType<SubType>().GetEnumerator(); });
mockSet.Setup(m => m.Add(It.IsAny<SubType>())).Returns<SubType>(i => { dbSet.Add(i); return i; });
mockSet.Setup(m => m.AddRange(It.IsAny<IEnumerable<SubType>>())).Returns<IEnumerable<SubType>>((i) => { dbSet.AddRange(i); return i; });
mockSet.Setup(m => m.Remove(It.IsAny<SubType>())).Returns<SubType>(i => { dbSet.Remove(i); return i; });
mockSet.As<IDbSet<SubType>>().Setup(m => m.Find(It.IsAny<object[]>())).Returns<object[]>((i) => { return dbSet.Find(i) as SubType; });
baseSet.Setup(m => m.Create<SubType>()).Returns(() => { return mockSet.Object.Create(); });
db.Setup(x => x.Set<SubType>()).Returns(mockSet.Object);
return mockSet;
}
}
Microsoft published a very well explained guide that provides an in-memory implementation of DbSet with a complete implementation of all the methods on DbSet.
Check the article at http://msdn.microsoft.com/en-us/data/dn314431.aspx#doubles
I'm trying to add a LINQ or DbContext extension method to get an element (FirstOrDefault) but if one does not already exist then create a new instance with data (FirstOrCreate) instead of returning null.
Is this possible?
For example:
public static class LINQExtension
{
public static TSource FirstOrCreate<TSource>(
this IEnumerable<TSource> source,
Func<TSource, bool> predicate)
{
if (source.First(predicate) != null)
{
return source.First(predicate);
}
else
{
return // ???
}
}
}
A usage could be:
using (var db = new MsBoxContext())
{
var status = db.EntitiesStatus.FirstOrCreate(s => s.Name == "Enabled");
//Here we should get the object if we find one
//and if it doesn't exist create and return a new instance
db.Entities.Add(new Entity()
{
Name = "New Entity",
Status = status
});
}
I hope that you understand my approach.
public static class LINQExtension
{
public static TSource FirstOrCreate<TSource>(
this IQueryable<TSource> source,
Expression<Func<TSource, bool>> predicate,
Func<TSource> defaultValue)
{
return source.FirstOrDefault(predicate) ?? defaultValue();
}
}
Usage:
var status = db.EntitiesStatus.FirstOrCreate(s => s.Name == "Enabled",
() => new EntityStatus {Name = "Enabled"});
However, you must note that this will not work quite like FirstOrDefault().
If you did the following
var listOfStuff = new List<string>() { "Enabled" };
var statuses = from s in listOfStuff
select db.EntitiesStatus.FirstOrCreate(x => x.Name == s,
() => new EntityStatus {Name = s});
you would get O(n) hits to the database, one query per item in the list.
However, I suspect that if you did...
var listOfStuff = new List<string>() { "Enabled" };
var statuses = from s in listOfStuff
select db.EntitiesStatus.FirstOrDefault(x => x.Name == s)
?? new EntityStatus {Name = s};
it is plausible it could work...
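If the per-item round trips are the concern, one alternative is to fetch all existing matches in a single query and create only the missing ones in memory. This is a sketch under assumptions: EntityStatus here is a minimal stand-in for the real entity, StatusLookup and GetOrCreateMany are hypothetical names, and whether Contains on a local list is translated to a single SQL query depends on the provider (EF providers typically translate it to an IN clause).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Minimal stand-in for the real entity, for illustration only.
public class EntityStatus
{
    public string Name { get; set; } = "";
}

public static class StatusLookup
{
    // Resolves many names with one query instead of one query per name.
    // New instances are created in memory only; they are not added to a context.
    public static List<EntityStatus> GetOrCreateMany(
        IQueryable<EntityStatus> statuses, List<string> names)
    {
        if (statuses == null) throw new ArgumentNullException(nameof(statuses));
        if (names == null) throw new ArgumentNullException(nameof(names));

        // Single query for every requested name. Grouping guards against
        // duplicate names in the source.
        var existing = statuses
            .Where(s => names.Contains(s.Name))
            .ToList()
            .GroupBy(s => s.Name)
            .ToDictionary(g => g.Key, g => g.First());

        // Keep the order of the requested names; fill gaps with new instances.
        return names
            .Select(n => existing.TryGetValue(n, out var found)
                ? found
                : new EntityStatus { Name = n })
            .ToList();
    }
}
```

As with the ?? version, the newly created instances still have to be added to the context and saved by the caller.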
Conclusion:
Instead of implementing an extension method, the best solution is to use the ?? operator like this:
var status = db.EntitiesStatus.FirstOrDefault(s => s.Name == "Enabled") ?? new EntityStatus(){Name = "Enabled"};
I am a self-taught programmer and I am really bad at typing, so I was looking for the exact same thing. I ended up writing my own. It took a few steps and revisions before it would work with more than one property. Of course there are some limitations, and I haven't fully tested it, but so far it seems to work for my purposes: keeping the records distinct in the DB and shortening the code (typing time).
public static class DataExtensions
{
public static TEntity InsertIfNotExists<TEntity>(this ObjectSet<TEntity> objectSet, Expression<Func<TEntity, bool>> predicate) where TEntity : class, new()
{
TEntity entity;
#region Check DB
entity = objectSet.FirstOrDefault(predicate);
if (entity != null)
return entity;
#endregion
//NOT in the database... Check the local context so we do not enter duplicates
#region Check Local Context
entity = objectSet.Local().AsQueryable().FirstOrDefault(predicate);
if (entity != null)
return entity;
#endregion
// Does NOT exist: create the entity
entity = new TEntity();
// Parse the expression tree and set properties
// Call a recursive function to get all the properties and values
var body = (BinaryExpression)((LambdaExpression)predicate).Body;
var dict = body.GetDictionary();
// Set values on the new entity
foreach (var item in dict)
{
entity.GetType().GetProperty(item.Key).SetValue(entity, item.Value);
}
// Add the new entity to the context so that it is inserted on SaveChanges
objectSet.AddObject(entity);
return entity;
}
public static Dictionary<string, object> GetDictionary(this BinaryExpression exp)
{
//Recursive function that creates a dictionary of the properties and values from the lambda expression
var result = new Dictionary<string, object>();
if (exp.NodeType == ExpressionType.AndAlso)
{
result.Merge(GetDictionary((BinaryExpression)exp.Left));
result.Merge(GetDictionary((BinaryExpression)exp.Right));
}
else
{
result[((MemberExpression)exp.Left).Member.Name] = exp.Right.GetExpressionVaule();
}
return result;
}
public static object GetExpressionVaule(this Expression exp)
{
if (exp.NodeType == ExpressionType.Constant)
return ((ConstantExpression)exp).Value;
if (exp.Type.IsValueType)
exp = Expression.Convert(exp, typeof(object));
//Taken From http://stackoverflow.com/questions/238413/lambda-expression-tree-parsing
var accessorExpression = Expression.Lambda<Func<object>>(exp);
Func<object> accessor = accessorExpression.Compile();
return accessor();
}
public static IEnumerable<T> Local<T>(this ObjectSet<T> objectSet) where T : class
{
//Taken From http://blogs.msdn.com/b/dsimmons/archive/2009/02/21/local-queries.aspx?Redirected=true
return from stateEntry in objectSet.Context.ObjectStateManager.GetObjectStateEntries(
EntityState.Added |
EntityState.Modified |
EntityState.Unchanged)
where stateEntry.Entity != null && stateEntry.EntitySet == objectSet.EntitySet
select stateEntry.Entity as T;
}
public static void Merge<TKey, TValue>(this Dictionary<TKey, TValue> me, Dictionary<TKey, TValue> merge)
{
//Taken From http://stackoverflow.com/questions/4015204/c-sharp-merging-2-dictionaries
foreach (var item in merge)
{
me[item.Key] = item.Value;
}
}
}
Usage is as simple as:
var status = db.EntitiesStatus.InsertIfNotExists(s => s.Name == "Enabled");
The extension checks the database first. If the entity is not found there, it checks the local context (so you do not add it twice). If it is still not found, it creates the entity, parses the expression tree to get the properties and values from the lambda expression, sets those values on the new entity, adds the entity to the context, and returns it.
A few things to be aware of...
This does not handle all possible uses (it assumes all the expressions in the lambda are ==)
The project I did this in uses an ObjectContext as opposed to a DbContext (I have not switched yet, so I don't know if this would work with DbContext. I assume it would not be difficult to change)
I am self-taught, so there may be many ways to optimize this. If you have any input, please let me know.
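To make the "assumes all the expressions are ==" limitation concrete, here is a standalone sketch of the parsing step that rejects unsupported predicates instead of failing with an InvalidCastException. PredicateParser and ToPropertyValues are hypothetical names, and this is only one possible way to write it, not the code above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical helper, for illustration only.
public static class PredicateParser
{
    // Extracts property/value pairs from predicates of the form
    // x => x.A == v1 && x.B == v2. Any other shape is rejected.
    public static Dictionary<string, object> ToPropertyValues<T>(
        Expression<Func<T, bool>> predicate)
    {
        if (predicate == null) throw new ArgumentNullException(nameof(predicate));
        var result = new Dictionary<string, object>();
        Collect(predicate.Body, result);
        return result;
    }

    private static void Collect(Expression node, Dictionary<string, object> result)
    {
        var binary = node as BinaryExpression;
        if (binary != null && binary.NodeType == ExpressionType.AndAlso)
        {
            // a && b: walk both sides recursively.
            Collect(binary.Left, result);
            Collect(binary.Right, result);
        }
        else if (binary != null
                 && binary.NodeType == ExpressionType.Equal
                 && binary.Left is MemberExpression member)
        {
            // x.Property == value: record the property name and the value.
            result[member.Member.Name] = Evaluate(binary.Right);
        }
        else
        {
            throw new NotSupportedException(
                "Only 'property == value' comparisons joined by && are supported.");
        }
    }

    private static object Evaluate(Expression expression)
    {
        if (expression is ConstantExpression constant) return constant.Value;

        // Captured variables and other expressions: box, compile, and run.
        var boxed = Expression.Convert(expression, typeof(object));
        return Expression.Lambda<Func<object>>(boxed).Compile()();
    }
}
```

With this sketch, a predicate such as s => s.Name != "Enabled" throws NotSupportedException up front. Comparisons where the compiler wraps the property in a conversion (for example some enum or nullable properties) are rejected as well, so a real implementation would likely need to unwrap those.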
What about this extension, which also adds the newly created entity to the DbSet?
public static class DbSetExtensions
{
public static TEntity FirstOrCreate<TEntity>(
this DbSet<TEntity> dbSet,
Expression<Func<TEntity, bool>> predicate,
Func<TEntity> defaultValue)
where TEntity : class
{
var result = predicate != null
? dbSet.FirstOrDefault(predicate)
: dbSet.FirstOrDefault();
if (result == null)
{
result = defaultValue?.Invoke();
if (result != null)
dbSet.Add(result);
}
return result;
}
public static TEntity FirstOrCreate<TEntity>(
this DbSet<TEntity> dbSet,
Func<TEntity> defaultValue)
where TEntity : class
{
return dbSet.FirstOrCreate(null, defaultValue);
}
}
The usage with predicate:
var adminUser = DbContext.Users.FirstOrCreate(u => u.Name == "Admin", () => new User { Name = "Admin" });
or without predicate:
var adminUser = DbContext.Users.FirstOrCreate(() => new User { Name = "Admin" });