I have an ExpressionVisitor that I apply to EF Core's IQueryable<T>. Everything works fine except the Include methods, probably because they require IQueryable<T>.Provider to be an EntityQueryProvider.
Whenever I try to use Include now, it results in multiple queries, which in turn produce the error "A second operation started on this context before a previous operation completed. Any instance members are not guaranteed to be thread safe.".
How can I wire up my ExpressionVisitor so it still works with EF Core's Include functionality?
My issue is similar to this one except for EF Core instead of EF.
I hook up my ExpressionVisitor by calling it on the DbSet:
// Wrap a no-tracking view of the set so every query passes through the Translator's visitor.
var untrackedSet = _dbSet.AsNoTracking();
return new Translator<TEntity>(untrackedSet);
This is my Translator class:
/// <summary>
/// Queryable wrapper whose provider (<see cref="TranslatorProvider{T}"/>) rewrites the
/// expression tree before handing it to the wrapped source query.
/// </summary>
/// <typeparam name="T">Element type produced by this query.</typeparam>
public class Translator<T> : IOrderedQueryable<T>
{
    private readonly Expression _expression;
    private readonly TranslatorProvider<T> _provider;

    /// <summary>Creates the root query; its expression is a constant referring to this instance.</summary>
    /// <param name="source">Query that ultimately executes the translated expression.</param>
    public Translator(IQueryable source)
    {
        _expression = Expression.Constant(this);
        _provider = new TranslatorProvider<T>(source);
    }

    /// <summary>Creates a derived query for an expression composed on top of the root.</summary>
    /// <param name="source">Query that ultimately executes the translated expression.</param>
    /// <param name="expression">Composed LINQ expression; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public Translator(IQueryable source, Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        _expression = expression;
        _provider = new TranslatorProvider<T>(source);
    }

    public Type ElementType => typeof(T);

    public Expression Expression => _expression;

    public IQueryProvider Provider => _provider;

    public IEnumerator<T> GetEnumerator()
    {
        var results = (IEnumerable<T>)_provider.ExecuteEnumerable(_expression);
        return results.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() => _provider.ExecuteEnumerable(_expression).GetEnumerator();
}
And this is my TranslatorProvider<T> class (I've taken out the non-relevant Visit methods to shorten the post):
/// <summary>
/// <see cref="IQueryProvider"/> for <see cref="Translator{T}"/>. It runs this
/// <see cref="ExpressionVisitor"/> over each expression (replacing Translator constants with
/// the wrapped source's expression) and then hands the result to the source query's provider.
/// </summary>
/// <typeparam name="T">Element type of the <see cref="Translator{T}"/> this provider was created for.</typeparam>
public class TranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    private readonly IQueryable _source;

    /// <summary>Creates a provider that executes against <paramref name="source"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public TranslatorProvider(IQueryable source)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        _source = source;
    }

    /// <summary>Wraps <paramref name="expression"/> in a new <see cref="Translator{TElement}"/> over the same source.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        return new Translator<TElement>(_source, expression);
    }

    /// <summary>Non-generic counterpart of <see cref="CreateQuery{TElement}(Expression)"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        Type elementType = GetElementType(expression.Type);
        return (IQueryable)Activator.CreateInstance(
            typeof(Translator<>).MakeGenericType(elementType), _source, expression);
    }

    /// <summary>Executes a scalar query and casts the result to <typeparamref name="TResult"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        return (TResult)Execute(expression);
    }

    /// <summary>Translates <paramref name="expression"/> and executes it on the source provider.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public object Execute(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        var translated = Visit(expression);
        return _source.Provider.Execute(translated);
    }

    /// <summary>Translates <paramref name="expression"/> and returns the source provider's query for enumeration.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        var translated = Visit(expression);
        return _source.Provider.CreateQuery(translated);
    }

    /// <summary>
    /// Replaces constants holding any closed <see cref="Translator{T}"/> with the expression of
    /// the source that translator wraps. The check is on the generic definition, so the root
    /// constant is still unwrapped after element-type-changing operators such as Select.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        var queryable = node.Value as IQueryable;
        if (queryable != null
            && node.Type.IsGenericType
            && node.Type.GetGenericTypeDefinition() == typeof(Translator<>))
        {
            // A constant created by another translator (possibly of a different element type or
            // wrapping a different source) is unwrapped by that translator's own provider.
            var owner = queryable.Provider as ExpressionVisitor;
            if (owner != null && !ReferenceEquals(owner, this))
            {
                return owner.Visit(node);
            }

            return _source.Expression;
        }

        return base.VisitConstant(node);
    }

    // Element type of a query expression's type: the T of the IQueryable<T> it is or implements.
    // Falls back to the type's first generic argument (previous behaviour) when none is found.
    private static Type GetElementType(Type sequenceType)
    {
        Type queryableType =
            sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IQueryable<>)
                ? sequenceType
                : sequenceType.GetInterfaces().FirstOrDefault(
                    i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IQueryable<>));

        return queryableType != null
            ? queryableType.GetGenericArguments()[0]
            : sequenceType.GetGenericArguments().First();
    }
}
Update (EF Core 3.x):
The internal query pipeline infrastructure has changed. The new extension point for preprocessing query expressions is the Process method of the QueryTranslationPreprocessor class. Plugging it in requires replacing the IQueryTranslationPreprocessorFactory service, e.g.:
using System.Linq.Expressions;
namespace Microsoft.EntityFrameworkCore.Query
{
    /// <summary>
    /// Relational query translation preprocessor that runs a custom expression rewrite
    /// before EF Core's own preprocessing pipeline (EF Core 3.x extension point).
    /// </summary>
    public class CustomQueryTranslationPreprocessor : RelationalQueryTranslationPreprocessor
    {
        public CustomQueryTranslationPreprocessor(
            QueryTranslationPreprocessorDependencies dependencies,
            RelationalQueryTranslationPreprocessorDependencies relationalDependencies,
            QueryCompilationContext queryCompilationContext)
            : base(dependencies, relationalDependencies, queryCompilationContext)
        {
        }

        /// <summary>Applies <see cref="Preprocess"/> and then the standard relational preprocessing.</summary>
        public override Expression Process(Expression query) => base.Process(Preprocess(query));

        // Hook for the custom rewrite; currently the identity transform.
        private Expression Preprocess(Expression query)
        {
            // query = new YourExpressionVisitor().Visit(query);
            return query;
        }
    }

    /// <summary>
    /// Factory registered in place of EF Core's <see cref="IQueryTranslationPreprocessorFactory"/>
    /// so that every compiled query uses <see cref="CustomQueryTranslationPreprocessor"/>.
    /// </summary>
    public class CustomQueryTranslationPreprocessorFactory : IQueryTranslationPreprocessorFactory
    {
        public CustomQueryTranslationPreprocessorFactory(
            QueryTranslationPreprocessorDependencies dependencies,
            RelationalQueryTranslationPreprocessorDependencies relationalDependencies)
        {
            Dependencies = dependencies;
            RelationalDependencies = relationalDependencies;
        }

        /// <summary>Core preprocessor dependencies supplied by EF Core's service provider.</summary>
        protected QueryTranslationPreprocessorDependencies Dependencies { get; }

        /// <summary>Relational preprocessor dependencies; read-only like <see cref="Dependencies"/>.</summary>
        protected RelationalQueryTranslationPreprocessorDependencies RelationalDependencies { get; }

        /// <summary>Creates the custom preprocessor for one query compilation.</summary>
        public QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext)
            => new CustomQueryTranslationPreprocessor(Dependencies, RelationalDependencies, queryCompilationContext);
    }
}
and
// Swap EF Core's IQueryTranslationPreprocessorFactory for CustomQueryTranslationPreprocessorFactory (EF Core 3.x).
optionsBuilder.ReplaceService<IQueryTranslationPreprocessorFactory, CustomQueryTranslationPreprocessorFactory>();
Original:
Apparently custom query providers don't fit in the current EF Core queryable pipeline, since several methods (Include, AsNoTracking etc.) require provider to be EntityQueryProvider.
At the time of writing (EF Core 2.1.2), the query translation process involves several services: IAsyncQueryProvider, IQueryCompiler, IQueryModelGenerator and more. All of them are replaceable, but the easiest interception point I see is the ParseQuery method of the IQueryModelGenerator service.
So, forget about custom IQueryable / IQueryProvider implementation, use the following class and plug your expression visitor inside Preprocess method:
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Remotion.Linq;
using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation;
/// <summary>
/// EF Core 2.x query model generator that lets a custom expression visitor rewrite each
/// LINQ expression before it is parsed into a Remotion.Linq <see cref="QueryModel"/>.
/// </summary>
internal class CustomQueryModelGenerator : QueryModelGenerator
{
    public CustomQueryModelGenerator(
        INodeTypeProviderFactory nodeTypeProviderFactory,
        IEvaluatableExpressionFilter evaluatableExpressionFilter,
        ICurrentDbContext currentDbContext)
        : base(nodeTypeProviderFactory, evaluatableExpressionFilter, currentDbContext)
    {
    }

    /// <summary>Rewrites <paramref name="query"/> with <see cref="Preprocess"/>, then parses it normally.</summary>
    public override QueryModel ParseQuery(Expression query)
    {
        Expression rewritten = Preprocess(query);
        return base.ParseQuery(rewritten);
    }

    // Hook for the custom visitor; currently the identity transform.
    private Expression Preprocess(Expression query)
    {
        // return new YourExpressionVisitor().Visit(query);
        return query;
    }
}
and replace the corresponding EF Core service inside your derived context OnConfiguring override:
// Swap EF Core's IQueryModelGenerator for CustomQueryModelGenerator (EF Core 2.x).
optionsBuilder.ReplaceService<IQueryModelGenerator, CustomQueryModelGenerator>();
The drawback is that this is using EF Core "internal" stuff, so you should keep monitoring for changes in the future updates.
I'm trying to create a wrapper around QueryableBase and INhQueryProvider that would receive a collection in the constructor and query it in-memory instead of going to a database. This is so I can mock the behavior of NHibernate's ToFuture() and properly unit test my classes.
The problem is that I'm facing a stack overflow due to infinite recursion and I'm struggling to find the reason.
Here's my implementation:
/// <summary>
/// Queryable proxy derived from <see cref="QueryableBase{T}"/> whose root instance is backed by
/// an <see cref="NhQueryProviderProxy{T}"/> over in-memory data.
/// </summary>
public class NHibernateQueryableProxy<T> : QueryableBase<T>, IOrderedQueryable<T>
{
    /// <summary>Root query over <paramref name="data"/>.</summary>
    public NHibernateQueryableProxy(IQueryable<T> data)
        : base(new NhQueryProviderProxy<T>(data))
    {
    }

    public NHibernateQueryableProxy(IQueryParser queryParser, IQueryExecutor executor)
        : base(queryParser, executor)
    {
    }

    public NHibernateQueryableProxy(IQueryProvider provider)
        : base(provider)
    {
    }

    /// <summary>Derived query for an expression composed on top of another proxy.</summary>
    public NHibernateQueryableProxy(IQueryProvider provider, Expression expression)
        : base(provider, expression)
    {
    }

    /// <summary>Executes the query through <see cref="IQueryProvider.Execute{TResult}"/> and enumerates the result.</summary>
    public new IEnumerator<T> GetEnumerator()
    {
        IEnumerable<T> results = Provider.Execute<IEnumerable<T>>(Expression);
        return results.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
/// <summary>
/// INhQueryProvider stand-in that answers queries from an in-memory <see cref="IQueryable{T}"/>
/// instead of a database. Before delegating, it replaces the root proxy constant with the
/// wrapped data's expression. Without that replacement, the wrapped provider enumerates the
/// proxy, which calls back into this provider and recurses forever.
/// </summary>
internal class NhQueryProviderProxy<T> : INhQueryProvider
{
    private readonly IQueryable<T> data;
    private readonly IQueryProvider provider;

    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public NhQueryProviderProxy(IQueryable<T> data)
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        this.data = data;
        provider = data.Provider;
    }

    /// <summary>Creates a proxy whose element type matches <paramref name="expression"/> (it may differ from T after a Select).</summary>
    public IQueryable CreateQuery(Expression expression)
    {
        Type elementType = GetElementType(expression.Type);
        return (IQueryable)Activator.CreateInstance(
            typeof(NHibernateQueryableProxy<>).MakeGenericType(elementType), this, expression);
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new NHibernateQueryableProxy<TElement>(this, expression);
    }

    public object Execute(Expression expression)
    {
        return provider.Execute(RewriteRoot(expression));
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return provider.Execute<TResult>(RewriteRoot(expression));
    }

    public object ExecuteFuture(Expression expression)
    {
        return provider.Execute(RewriteRoot(expression));
    }

    public void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary<string, Tuple<object, IType>> parameters)
    {
        throw new NotImplementedException();
    }

    // Swaps every proxy constant that belongs to this provider for the wrapped data source.
    private Expression RewriteRoot(Expression expression)
    {
        return new RootReplacer(this).Visit(expression);
    }

    // T of the IEnumerable<T> that sequenceType is or implements; falls back to this provider's T.
    private static Type GetElementType(Type sequenceType)
    {
        Type enumerableType =
            sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>)
                ? sequenceType
                : sequenceType.GetInterfaces().FirstOrDefault(
                    i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));

        return enumerableType != null ? enumerableType.GetGenericArguments()[0] : typeof(T);
    }

    // Expression visitor that removes references to this provider's proxies from a tree.
    private sealed class RootReplacer : ExpressionVisitor
    {
        private readonly NhQueryProviderProxy<T> owner;

        public RootReplacer(NhQueryProviderProxy<T> owner)
        {
            this.owner = owner;
        }

        protected override Expression VisitConstant(ConstantExpression node)
        {
            var queryable = node.Value as IQueryable;
            if (queryable == null || !ReferenceEquals(queryable.Provider, owner))
            {
                return base.VisitConstant(node);
            }

            // The root proxy's Expression is this very constant: substitute the real data.
            if (ReferenceEquals(queryable.Expression, node))
            {
                return owner.data.Expression;
            }

            // A derived proxy captured as a constant: inline its own (rewritten) tree.
            return Visit(queryable.Expression);
        }
    }
}
Edit: I've partly figured out the problem. One of the arguments of the expression is my custom queryable. When the provider executes this expression, it causes an infinite call loop between CreateQuery and Execute. Is it possible to replace all references to my custom queryable with the queryable wrapped by this class?
After a while I decided to give it another try and I guess I've managed to mock it. I didn't test it with real case scenarios but I don't think many tweaks will be necessary. Most of this code is either taken from or based on this tutorial. There are some caveats related to IEnumerable when dealing with those queries.
We need to implement QueryableBase since NHibernate asserts the type when using ToFuture.
/// <summary>
/// Queryable passed to NHibernate's ToFuture(). It derives from <see cref="QueryableBase{T}"/>
/// because NHibernate asserts that type; execution is served in memory by
/// <see cref="NhQueryProviderProxy{T}"/>.
/// </summary>
public class NHibernateQueryableProxy<T> : QueryableBase<T>
{
    /// <summary>Root query over the in-memory <paramref name="data"/>.</summary>
    public NHibernateQueryableProxy(IQueryable<T> data)
        : base(new NhQueryProviderProxy<T>(data))
    {
    }

    /// <summary>Derived query created by the provider for a composed expression.</summary>
    public NHibernateQueryableProxy(IQueryProvider provider, Expression expression)
        : base(provider, expression)
    {
    }
}
Now we need to mock a QueryProvider since that's what LINQ queries depend on and it needs to implement INhQueryProvider because ToFuture() also uses it.
/// <summary>
/// In-memory replacement for NHibernate's query provider. Queries are built as
/// <see cref="NHibernateQueryableProxy{T}"/> instances and executed against the wrapped
/// <see cref="IQueryable{T}"/> after the proxy root is swapped for that data source.
/// </summary>
public class NhQueryProviderProxy<T> : INhQueryProvider
{
    private readonly IQueryable<T> _data;

    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public NhQueryProviderProxy(IQueryable<T> data)
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        _data = data;
    }

    // These two CreateQuery methods get called by LINQ extension methods to build up the query
    // and by ToFuture to return a queried collection and allow us to apply more filters
    public IQueryable CreateQuery(Expression expression)
    {
        Type elementType = TypeSystem.GetElementType(expression.Type);
        return (IQueryable)Activator.CreateInstance(typeof(NHibernateQueryableProxy<>)
            .MakeGenericType(elementType), new object[] { this, expression });
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new NHibernateQueryableProxy<TElement>(this, expression);
    }

    // Those two Execute methods are called by terminal methods like .ToList() and .ToArray()
    public object Execute(Expression expression)
    {
        return ExecuteInMemoryQuery(expression, false);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        // Compare against the IEnumerable<> definition rather than the type's name string,
        // which would also match unrelated types named "IEnumerable`1".
        Type resultType = typeof(TResult);
        bool isEnumerable = resultType.IsGenericType
            && resultType.GetGenericTypeDefinition() == typeof(IEnumerable<>);
        return (TResult)ExecuteInMemoryQuery(expression, isEnumerable);
    }

    public object ExecuteFuture(Expression expression)
    {
        // Return a new NHibernateQueryableProxy (via CreateQuery) so more operators,
        // and another ToFuture, can be applied to the result.
        return CreateQuery(expression);
    }

    // Rewrites the proxy root to the real data, then either builds a query (sequence results)
    // or executes it immediately (scalar results).
    private object ExecuteInMemoryQuery(Expression expression, bool isEnumerable)
    {
        var newExpr = new ExpressionTreeModifier<T>(_data).Visit(expression);
        if (isEnumerable)
        {
            return _data.Provider.CreateQuery(newExpr);
        }

        return _data.Provider.Execute(newExpr);
    }

    public void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary<string, Tuple<object, IType>> parameters)
    {
        throw new NotImplementedException();
    }
}
The expression tree visitor will change the type of the query for us:
/// <summary>
/// Replaces the <see cref="NHibernateQueryableProxy{T}"/> root of a query with the expression
/// tree of the real data source, so the data's own provider can execute the query.
/// </summary>
internal class ExpressionTreeModifier<T> : ExpressionVisitor
{
    private readonly IQueryable<T> _queryableData;

    /// <exception cref="ArgumentNullException"><paramref name="queryableData"/> is null.</exception>
    internal ExpressionTreeModifier(IQueryable<T> queryableData)
    {
        if (queryableData == null)
        {
            throw new ArgumentNullException(nameof(queryableData));
        }

        _queryableData = queryableData;
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        // Every query built by the proxy provider is rooted in a constant of type
        // NHibernateQueryableProxy<T>. Substitute the data source's own Expression, the same
        // node Queryable operators embed, rather than a constant wrapping the queryable.
        if (c.Type == typeof(NHibernateQueryableProxy<T>))
        {
            return _queryableData.Expression;
        }

        return c;
    }
}
And we also need a helper (taken from the tutorial) to get the type being queried:
/// <summary>Helpers for discovering the element type of a sequence type.</summary>
internal static class TypeSystem
{
    /// <summary>
    /// Returns the T of the IEnumerable&lt;T&gt; that <paramref name="seqType"/> is or implements,
    /// or <paramref name="seqType"/> itself when no such interface exists (strings count as scalars).
    /// </summary>
    internal static Type GetElementType(Type seqType)
    {
        Type enumerableType = FindIEnumerable(seqType);
        return enumerableType == null ? seqType : enumerableType.GetGenericArguments()[0];
    }

    // Searches arrays, the type's own generic arguments, its interfaces and finally its base
    // types (in that order) for a matching closed IEnumerable<>.
    private static Type FindIEnumerable(Type seqType)
    {
        if (seqType == null || seqType == typeof(string))
        {
            return null;
        }

        if (seqType.IsArray)
        {
            return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType());
        }

        if (seqType.IsGenericType)
        {
            Type fromArguments = seqType.GetGenericArguments()
                .Select(argument => typeof(IEnumerable<>).MakeGenericType(argument))
                .FirstOrDefault(candidate => candidate.IsAssignableFrom(seqType));
            if (fromArguments != null)
            {
                return fromArguments;
            }
        }

        Type fromInterfaces = seqType.GetInterfaces()
            .Select(iface => FindIEnumerable(iface))
            .FirstOrDefault(found => found != null);
        if (fromInterfaces != null)
        {
            return fromInterfaces;
        }

        Type baseType = seqType.BaseType;
        return baseType != null && baseType != typeof(object) ? FindIEnumerable(baseType) : null;
    }
}
To test the above code, I ran the following snippet:
// Exercise the in-memory proxy with a chained fluent query and a filtered projection.
var numbers = new NHibernateQueryableProxy<int>(Enumerable.Range(1, 10000).AsQueryable());

var fluentQuery = numbers
    .Where(x => x > 1 && x < 4321443)
    .Take(1000)
    .Skip(3)
    .Union(new[] { 4235, 24543, 52 })
    .GroupBy(x => x.ToString().Length)
    .ToFuture()
    .ToList();

var linqQuery = numbers
    .Where(n => n > 40 && n < 50)
    .Select(n => n.ToString())
    .ToFuture()
    .ToList();
As I said, no complex scenarios were tested but I guess only a few tweaks will be necessary for real-world usages.
class Program
{
    // Builds a two-item sample, composes a Where over Query<ClassString>() and prints the query.
    static void Main(string[] args)
    {
        var sample = new SampleClass<ClassString>
        {
            ClassStrings =
            {
                new ClassString { Name1 = "1", Name2 = "1" },
                new ClassString { Name1 = "2", Name2 = "2" }
            }
        };

        var result = sample.Query<ClassString>().Where(s => s.Name1.Equals("2"));
        Console.WriteLine(result);
        Console.ReadLine();
    }
}
/// <summary>Sample entity with two string properties.</summary>
public class ClassString
{
    private string _name1;
    private string _name2;

    /// <summary>First sample value.</summary>
    public string Name1
    {
        get { return _name1; }
        set { _name1 = value; }
    }

    /// <summary>Second sample value.</summary>
    public string Name2
    {
        get { return _name2; }
        set { _name2 = value; }
    }
}
/// <summary>Common query entry point shared by the sample data sources.</summary>
public interface ISampleQ
{
    /// <summary>Returns a queryable over items of type <typeparamref name="T"/>.</summary>
    IQueryable<T> Query<T>() where T : class, new();
}
/// <summary>In-memory sample data source holding items of type <typeparamref name="X"/>.</summary>
/// <typeparam name="X">Element type stored in <see cref="ClassStrings"/>.</typeparam>
public class SampleClass<X> : ISampleQ
{
    /// <summary>Backing list of sample items.</summary>
    public List<X> ClassStrings { get; private set; }

    public SampleClass()
    {
        ClassStrings = new List<X>();
    }

    /// <summary>Returns the stored items as an <see cref="IQueryable{T}"/>.</summary>
    /// <remarks>
    /// Operators the caller chains afterwards (e.g. <c>.Where(...)</c>) are applied to the returned
    /// query after this method has returned, so they cannot be observed here. Intercepting them
    /// requires a custom IQueryProvider, or post-processing of the finished query.
    /// </remarks>
    /// <exception cref="InvalidCastException">The stored <typeparamref name="X"/> items cannot be viewed as <typeparamref name="T"/>.</exception>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        var items = ClassStrings as IEnumerable<T>;
        if (items == null)
        {
            throw new InvalidCastException(string.Format(
                "SampleClass<{0}> holds {0} items and cannot be queried as {1}.",
                typeof(X).Name,
                typeof(T).Name));
        }

        return new EnumerableQuery<T>(items);
    }
}
I looked into solution1, solution2 and solution3, but they don't seem applicable to my question, because the where clause is defined outside the class and applied through the class's interface. How can I get the expression inside the Query method, given that nothing is passed into it?
The purpose: I want to retrieve the expression and inject it back into the destination (a DbContext exposed as IQueryable), because we share a common interface, ISampleQ.
Added new sample codes but same scenario:
internal class Program
{
    // Composes a Where over the Oracle context's Query<Person>() to illustrate the question.
    private static void Main(string[] args)
    {
        var context = new OracleDbContext();
        var result = context
            .Query<Person>()
            .Where(person => person.Name.Equals("username"));

        Console.WriteLine();
        Console.ReadLine();
    }
}
/// <summary>Query entry point shared by the Oracle and SQL Server contexts.</summary>
public interface IGenericQuery
{
    /// <summary>Returns a queryable over entities of type <typeparamref name="T"/>.</summary>
    IQueryable<T> Query<T>() where T : class, new();
}
/// <summary>Oracle-backed implementation of <see cref="IGenericQuery"/> (sample skeleton).</summary>
public class OracleDbContext : IGenericQuery
{
    public OracleDbContext()
    {
        // Will hold all Oracle operations here. For brevity, only Query is exposed.
    }

    /// <summary>Returns the query root for <typeparamref name="T"/>.</summary>
    /// <remarks>
    /// A Where(...) chained by the caller is applied to the returned query after this method has
    /// returned, so it cannot be read here. To inspect or re-inject it, return an IQueryable whose
    /// Provider intercepts CreateQuery/Execute.
    /// </remarks>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        // Placeholder until Oracle access is wired up. Return an empty query rather than null,
        // because Queryable.Where and other operators throw ArgumentNullException on a null source.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>SQL Server-backed implementation of <see cref="IGenericQuery"/> (sample skeleton).</summary>
public class MssqlDbContext : IGenericQuery
{
    public MssqlDbContext()
    {
        // Will hold all MSSQL operations here. For brevity, only Query is exposed.
    }

    /// <summary>Returns the query root for <typeparamref name="T"/>.</summary>
    /// <remarks>
    /// A Where(...) chained by the caller is applied to the returned query after this method has
    /// returned, so it cannot be read here.
    /// </remarks>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        // Placeholder until SQL Server access is wired up. Return an empty query rather than null,
        // because Queryable.Where and other operators throw ArgumentNullException on a null source.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>Sample entity queried by the context examples.</summary>
public class Person
{
    /// <summary>Identifier.</summary>
    public int Id { get; set; }

    /// <summary>
    /// Person's name. It is a string because callers filter with <c>Name.Equals("username")</c>;
    /// with an int property that comparison silently always returns false.
    /// </summary>
    public string Name { get; set; }
}
It is quite complex. First, here is how Queryable.Where() works:
public static IQueryable<TSource> Where<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
return source.Provider.CreateQuery<TSource>(Expression.Call(null, ...
So Queryable.Where calls source.Provider.CreateQuery(), which returns a new IQueryable<>. If you want to be able to "see" a Where() while it is being added (and manipulate it), you must "be" the IQueryable<>.Provider and have your own CreateQuery(). In other words, you must create a class that implements IQueryProvider (and probably a class that implements IQueryable<T>).
Another way (much simpler) is to have a simple query "converter": a method that accepts a IQueryable<> and returns a manipulated IQueryable<>:
// FixMyQuery stands for your own method that takes an IQueryable<> and returns a rewritten IQueryable<>.
var result = c.Query<ClassString>().Where(s => s.Name1.Equals("2")).FixMyQuery();
As I said, the full route is quite long:
namespace Utilities
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// EF6 <see cref="DbContext"/> whose sets route queries through a <see cref="ProxyDbProvider"/>,
/// so <see cref="Manipulator"/> can rewrite expression trees before EF translates them.
/// </summary>
public class ProxyDbContext : DbContext
{
    /// <summary>Open generic ProxifySets method, closed over the runtime context type by the constructors.</summary>
    protected static readonly MethodInfo ProxifySetsMethod = typeof(ProxyDbContext).GetMethod("ProxifySets", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Per-context-type cache of a compiled delegate that wraps every public, publicly settable,
    /// non-indexed DbSet/IDbSet property of <typeparamref name="TContext"/> in a <see cref="ProxyDbSet{TEntity}"/>.
    /// </summary>
    protected static class ProxyDbContexSetter<TContext> where TContext : ProxyDbContext
    {
        /// <summary>Compiled proxifier; remains a no-op when the context has no eligible set properties.</summary>
        public static readonly Action<TContext> Do = x => { };

        static ProxyDbContexSetter()
        {
            var properties = typeof(TContext).GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy);
            ParameterExpression context = Expression.Parameter(typeof(TContext), "context");
            FieldInfo manipulatorField = typeof(ProxyDbContext).GetField("Manipulator", BindingFlags.Instance | BindingFlags.Public);
            Expression manipulator = Expression.Field(context, manipulatorField);
            var sets = new List<Expression>();

            foreach (PropertyInfo property in properties)
            {
                // Expression.Property cannot target indexers; skip them and write-less properties.
                if (property.GetMethod == null || property.GetIndexParameters().Length != 0)
                {
                    continue;
                }

                // Expression.Assign throws for a get-only property (SetMethod == null), so only
                // properties with a public setter are rewritten.
                MethodInfo setMethod = property.SetMethod;
                if (setMethod == null || !setMethod.IsPublic)
                {
                    continue;
                }

                Type type = property.PropertyType;
                Type entityType = GetIDbSetTypeArgument(type);
                if (entityType == null)
                {
                    continue;
                }

                Type dbSetType = typeof(DbSet<>).MakeGenericType(entityType);

                // The property must be able to hold a DbSet<T> (and therefore a ProxyDbSet<T>).
                if (!type.IsAssignableFrom(dbSetType))
                {
                    continue;
                }

                ConstructorInfo constructor = typeof(ProxyDbSet<>)
                    .MakeGenericType(entityType)
                    .GetConstructor(new[]
                    {
                        dbSetType,
                        typeof(Func<bool, Expression, Expression>)
                    });

                // context.Prop = new ProxyDbSet<T>((DbSet<T>)context.Prop, context.Manipulator);
                MemberExpression propertyAccess = Expression.Property(context, property);
                BinaryExpression assign = Expression.Assign(propertyAccess, Expression.New(constructor, Expression.Convert(propertyAccess, dbSetType), manipulator));
                sets.Add(assign);
            }

            // Expression.Block rejects an empty expression list. Keep the no-op Do for contexts
            // without eligible set properties instead of failing the type initializer.
            if (sets.Count == 0)
            {
                return;
            }

            Expression<Action<TContext>> lambda = Expression.Lambda<Action<TContext>>(Expression.Block(sets), context);
            Do = lambda.Compile();
        }

        // Gets the T of IDbSet<T> implemented by type (or type itself when it is IDbSet<T>); null if none.
        private static Type GetIDbSetTypeArgument(Type type)
        {
            IEnumerable<Type> interfaces = type.IsInterface ?
                new[] { type }.Concat(type.GetInterfaces()) :
                type.GetInterfaces();

            Type argument = (from x in interfaces
                             where x.IsGenericType
                             let gt = x.GetGenericTypeDefinition()
                             where gt == typeof(IDbSet<>)
                             select x.GetGenericArguments()[0]).SingleOrDefault();
            return argument;
        }
    }

    /// <summary>Expression rewriter; first argument is true for Execute, false for CreateQuery.</summary>
    public readonly Func<bool, Expression, Expression> Manipulator;

    /// <summary>Creates a context whose queries pass through <paramref name="manipulator"/>.</summary>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet{TEntity} and IDbSet{TEntity} properties proxified.</param>
    public ProxyDbContext(Func<bool, Expression, Expression> manipulator, bool resetSets = true)
    {
        Manipulator = manipulator;

        if (resetSets)
        {
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>Creates a context for <paramref name="nameOrConnectionString"/> whose queries pass through <paramref name="manipulator"/>.</summary>
    /// <param name="nameOrConnectionString">Either the database name or a connection string.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet{TEntity} and IDbSet{TEntity} properties proxified.</param>
    public ProxyDbContext(string nameOrConnectionString, Func<bool, Expression, Expression> manipulator, bool resetSets = true)
        : base(nameOrConnectionString)
    {
        Manipulator = manipulator;

        if (resetSets)
        {
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>Runs the cached proxifier for the concrete context type.</summary>
    protected void ProxifySets<TContext>() where TContext : ProxyDbContext
    {
        ProxyDbContexSetter<TContext>.Do((TContext)this);
    }

    /// <summary>Returns a proxied set so ad-hoc Set&lt;T&gt;() queries also pass through the manipulator.</summary>
    public override DbSet<TEntity> Set<TEntity>()
    {
        return new ProxyDbSet<TEntity>(base.Set<TEntity>(), Manipulator);
    }

    /// <summary>Non-generic counterpart of <see cref="Set{TEntity}"/>.</summary>
    public override DbSet Set(Type entityType)
    {
        DbSet set = base.Set(entityType);
        ConstructorInfo constructor = typeof(ProxyDbSetNonGeneric<>)
            .MakeGenericType(entityType)
            .GetConstructor(new[]
            {
                typeof(DbSet),
                typeof(Func<bool, Expression, Expression>)
            });

        return (DbSet)constructor.Invoke(new object[] { set, Manipulator });
    }
}
/// <summary>
/// Non-generic <see cref="DbSet"/> proxy (EF implements non-generic sets as InternalDbSet&lt;T&gt;).
/// Query members go through <see cref="ProxyQueryable"/>; change-tracking members delegate
/// directly to <see cref="BaseDbSet"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type of the wrapped set.</typeparam>
/// <remarks>
/// NOTE(review): AsNoTracking/AsStreaming/Include each build from <see cref="BaseDbSet"/>, and the
/// returned proxy keeps that same BaseDbSet, so chaining e.g. AsNoTracking().Include(path)
/// rebuilds from the original set — confirm whether that is intended.
/// </remarks>
public class ProxyDbSetNonGeneric<TEntity> : DbSet, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
    protected readonly DbSet BaseDbSet;
    protected readonly IQueryable<TEntity> ProxyQueryable;
    public readonly Func<bool, Expression, Expression> Manipulator;
    protected readonly FieldInfo InternalSetField = typeof(DbSet).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>Wraps <paramref name="baseDbSet"/>, routing its root query through a new <see cref="ProxyDbProvider"/>.</summary>
    /// <param name="baseDbSet">Set being proxied.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSetNonGeneric(DbSet baseDbSet, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        var rootQuery = (IQueryable)baseDbSet;
        var proxyProvider = new ProxyDbProvider(rootQuery.Provider, manipulator);
        ProxyQueryable = proxyProvider.CreateQuery<TEntity>(rootQuery.Expression);
        Manipulator = manipulator;
        CopyInternalSet(baseDbSet);
    }

    /// <summary>Wraps <paramref name="baseDbSet"/> with an already-built proxy query.</summary>
    /// <param name="baseDbSet">Set being proxied.</param>
    /// <param name="proxyQueryable">Query used for enumeration and IQueryable members.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSetNonGeneric(DbSet baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        ProxyQueryable = proxyQueryable;
        Manipulator = manipulator;
        CopyInternalSet(baseDbSet);
    }

    // Copies EF's private _internalSet from the wrapped set when the field exists, presumably so
    // inherited DbSet members see the wrapped set's state (verify against the EF6 version in use).
    private void CopyInternalSet(DbSet source)
    {
        if (InternalSetField != null)
        {
            InternalSetField.SetValue(this, InternalSetField.GetValue(source));
        }
    }

    public override object Add(object entity) => BaseDbSet.Add(entity);

    public override IEnumerable AddRange(IEnumerable entities) => BaseDbSet.AddRange(entities);

    public override DbQuery AsNoTracking() =>
        new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsNoTracking()), Manipulator);

    [Obsolete]
    public override DbQuery AsStreaming()
    {
#pragma warning disable 618
        return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
    }

    public override object Attach(object entity) => BaseDbSet.Attach(entity);

    public override object Create(Type derivedEntityType) => BaseDbSet.Create(derivedEntityType);

    public override object Create() => BaseDbSet.Create();

    public override object Find(params object[] keyValues) => BaseDbSet.Find(keyValues);

    public override Task<object> FindAsync(CancellationToken cancellationToken, params object[] keyValues) => BaseDbSet.FindAsync(cancellationToken, keyValues);

    public override Task<object> FindAsync(params object[] keyValues) => BaseDbSet.FindAsync(keyValues);

    public override DbQuery Include(string path) =>
        new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.Include(path)), Manipulator);

    public override IList Local => BaseDbSet.Local;

    public override object Remove(object entity) => BaseDbSet.Remove(entity);

    public override IEnumerable RemoveRange(IEnumerable entities) => BaseDbSet.RemoveRange(entities);

    public override DbSqlQuery SqlQuery(string sql, params object[] parameters) => BaseDbSet.SqlQuery(sql, parameters);

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator() => ProxyQueryable.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)ProxyQueryable).GetEnumerator();

    Type IQueryable.ElementType => ProxyQueryable.ElementType;

    Expression IQueryable.Expression => ProxyQueryable.Expression;

    IQueryProvider IQueryable.Provider => ProxyQueryable.Provider;

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator() => ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();

    public override string ToString() => ProxyQueryable.ToString();
}
/// <summary>
/// Generic <see cref="DbSet{TEntity}"/> proxy. Query members go through <see cref="ProxyQueryable"/>;
/// change-tracking members delegate directly to <see cref="BaseDbSet"/>.
/// </summary>
/// <remarks>
/// NOTE(review): AsNoTracking/AsStreaming/Include each build from <see cref="BaseDbSet"/>, and the
/// returned proxy keeps that same BaseDbSet, so chaining e.g. AsNoTracking().Include(path)
/// rebuilds from the original set — confirm whether that is intended.
/// </remarks>
public class ProxyDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
    protected readonly DbSet<TEntity> BaseDbSet;
    protected readonly IQueryable<TEntity> ProxyQueryable;
    public readonly Func<bool, Expression, Expression> Manipulator;
    protected readonly FieldInfo InternalSetField = typeof(DbSet<TEntity>).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>Wraps <paramref name="baseDbSet"/>, routing its root query through a new <see cref="ProxyDbProvider"/>.</summary>
    /// <param name="baseDbSet">Set being proxied.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSet(DbSet<TEntity> baseDbSet, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        var rootQuery = (IQueryable)baseDbSet;
        var proxyProvider = new ProxyDbProvider(rootQuery.Provider, manipulator);
        ProxyQueryable = proxyProvider.CreateQuery<TEntity>(rootQuery.Expression);
        Manipulator = manipulator;
        CopyInternalSet(baseDbSet);
    }

    /// <summary>Wraps <paramref name="baseDbSet"/> with an already-built proxy query.</summary>
    /// <param name="baseDbSet">Set being proxied.</param>
    /// <param name="proxyQueryable">Query used for enumeration and IQueryable members.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSet(DbSet<TEntity> baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        ProxyQueryable = proxyQueryable;
        Manipulator = manipulator;
        CopyInternalSet(baseDbSet);
    }

    // Copies EF's private _internalSet from the wrapped set when the field exists, presumably so
    // inherited DbSet<T> members see the wrapped set's state (verify against the EF6 version in use).
    private void CopyInternalSet(DbSet<TEntity> source)
    {
        if (InternalSetField != null)
        {
            InternalSetField.SetValue(this, InternalSetField.GetValue(source));
        }
    }

    public override TEntity Add(TEntity entity) => BaseDbSet.Add(entity);

    public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities) => BaseDbSet.AddRange(entities);

    public override DbQuery<TEntity> AsNoTracking() =>
        new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsNoTracking()), Manipulator);

    [Obsolete]
    public override DbQuery<TEntity> AsStreaming()
    {
#pragma warning disable 618
        return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
    }

    public override TEntity Attach(TEntity entity) => BaseDbSet.Attach(entity);

    public override TDerivedEntity Create<TDerivedEntity>() => BaseDbSet.Create<TDerivedEntity>();

    public override TEntity Create() => BaseDbSet.Create();

    public override TEntity Find(params object[] keyValues) => BaseDbSet.Find(keyValues);

    public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues) => BaseDbSet.FindAsync(cancellationToken, keyValues);

    public override Task<TEntity> FindAsync(params object[] keyValues) => BaseDbSet.FindAsync(keyValues);

    public override DbQuery<TEntity> Include(string path) =>
        new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.Include(path)), Manipulator);

    public override ObservableCollection<TEntity> Local => BaseDbSet.Local;

    public override TEntity Remove(TEntity entity) => BaseDbSet.Remove(entity);

    public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities) => BaseDbSet.RemoveRange(entities);

    public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters) => BaseDbSet.SqlQuery(sql, parameters);

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator() => ProxyQueryable.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)ProxyQueryable).GetEnumerator();

    Type IQueryable.ElementType => ProxyQueryable.ElementType;

    Expression IQueryable.Expression => ProxyQueryable.Expression;

    IQueryProvider IQueryable.Provider => ProxyQueryable.Provider;

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator() => ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();

    public override string ToString() => ProxyQueryable.ToString();

    // Note that the operator isn't virtual! If you do:
    // DbSet<Foo> foo = new ProxyDbSet<Foo>(...)
    // DbSet foo2 = (DbSet)foo;
    // Then you'll have a non-proxied DbSet!
    public static implicit operator ProxyDbSetNonGeneric<TEntity>(ProxyDbSet<TEntity> entry) =>
        new ProxyDbSetNonGeneric<TEntity>((DbSet)entry.BaseDbSet, entry.Manipulator);
}
public class ProxyDbProvider : IQueryProvider, IDbAsyncQueryProvider
{
/// <summary>Provider being wrapped; it receives the manipulated expressions.</summary>
protected readonly IQueryProvider BaseQueryProvider;

/// <summary>Optional expression rewriter (may be null); first argument is true for Execute, false for CreateQuery.</summary>
public readonly Func<bool, Expression, Expression> Manipulator;

/// <summary>Wraps <paramref name="baseQueryProvider"/> so expressions pass through <paramref name="manipulator"/> first.</summary>
/// <param name="baseQueryProvider">Provider that actually builds and executes queries.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbProvider(IQueryProvider baseQueryProvider, Func<bool, Expression, Expression> manipulator)
{
    Manipulator = manipulator;
    BaseQueryProvider = baseQueryProvider;
}
/// <summary>
/// Manipulates <paramref name="expression"/> (CreateQuery mode), builds the query on the base
/// provider and wraps it; a new proxy provider is created only when the base query reports a different provider.
/// </summary>
public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
{
    Expression rewritten = Manipulator == null ? expression : Manipulator(false, expression);
    IQueryable<TElement> inner = BaseQueryProvider.CreateQuery<TElement>(rewritten);
    IQueryProvider innerProvider = inner.Provider;

    ProxyDbProvider owner;
    if (innerProvider == BaseQueryProvider)
    {
        owner = this;
    }
    else
    {
        owner = new ProxyDbProvider(innerProvider, Manipulator);
    }

    return new ProxyQueryable<TElement>(owner, inner);
}
/// <summary>Reflection handle to CreateQueryNonGenericToGeneric, closed over the runtime element type by CreateQuery(Expression).</summary>
protected static readonly MethodInfo CreateQueryNonGenericToGenericMethod = typeof(ProxyDbProvider).GetMethod("CreateQueryNonGenericToGeneric", BindingFlags.Static | BindingFlags.NonPublic);

/// <summary>
/// Non-generic CreateQuery: manipulates the expression, builds the base query and wraps it in a
/// ProxyQueryable&lt;T&gt; when the result implements IQueryable&lt;T&gt;, otherwise in a non-generic ProxyQueryable.
/// </summary>
public IQueryable CreateQuery(Expression expression)
{
    Expression rewritten = Manipulator == null ? expression : Manipulator(false, expression);
    IQueryable inner = BaseQueryProvider.CreateQuery(rewritten);
    IQueryProvider innerProvider = inner.Provider;
    ProxyDbProvider owner = innerProvider == BaseQueryProvider ? this : new ProxyDbProvider(innerProvider, Manipulator);

    Type elementType = GetIQueryableTypeArgument(inner.GetType());
    if (elementType != null)
    {
        // Prefer the strongly typed wrapper when the element type is known.
        return (IQueryable)CreateQueryNonGenericToGenericMethod
            .MakeGenericMethod(elementType)
            .Invoke(null, new object[] { owner, inner });
    }

    return new ProxyQueryable(owner, inner);
}
protected static ProxyQueryable<TElement> CreateQueryNonGenericToGeneric<TElement>(ProxyDbProvider proxy, IQueryable<TElement> query)
{
return new ProxyQueryable<TElement>(proxy, query);
}
public TResult Execute<TResult>(Expression expression)
{
Expression expression2 = Manipulator != null ? Manipulator(true, expression) : expression;
return BaseQueryProvider.Execute<TResult>(expression2);
}
public object Execute(Expression expression)
{
Expression expression2 = Manipulator != null ? Manipulator(true, expression) : expression;
return BaseQueryProvider.Execute(expression2);
}
// Gets the T of IQueryablelt;T>
protected static Type GetIQueryableTypeArgument(Type type)
{
IEnumerable<Type> interfaces = type.IsInterface ?
new[] { type }.Concat(type.GetInterfaces()) :
type.GetInterfaces();
Type argument = (from x in interfaces
where x.IsGenericType
let gt = x.GetGenericTypeDefinition()
where gt == typeof(IQueryable<>)
select x.GetGenericArguments()[0]).FirstOrDefault();
return argument;
}
public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
{
var asyncQueryProvider = BaseQueryProvider as IDbAsyncQueryProvider;
if (asyncQueryProvider == null)
{
throw new NotSupportedException();
}
Expression expression2 = Manipulator != null ? Manipulator(true, expression) : expression;
return asyncQueryProvider.ExecuteAsync<TResult>(expression2, cancellationToken);
}
public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
{
var asyncQueryProvider = BaseQueryProvider as IDbAsyncQueryProvider;
if (asyncQueryProvider == null)
{
throw new NotSupportedException();
}
Expression expression2 = Manipulator != null ? Manipulator(true, expression) : expression;
return asyncQueryProvider.ExecuteAsync(expression2, cancellationToken);
}
}
/// <summary>
/// Non-generic queryable wrapper: enumeration, element type and expression come from the base
/// query, while <see cref="Provider"/> reports the ProxyDbProvider so further composition is proxied.
/// </summary>
public class ProxyQueryable : IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable BaseQueryable;

    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    public Type ElementType => BaseQueryable.ElementType;

    public Expression Expression => BaseQueryable.Expression;

    public IQueryProvider Provider => ProxyDbProvider;

    public IEnumerator GetEnumerator() => BaseQueryable.GetEnumerator();

    public override string ToString() => BaseQueryable.ToString();

    // Async enumeration works only when the base query supports it (EF6 DbQuery does).
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }
}
/// <summary>
/// Typed queryable wrapper: everything except <see cref="Provider"/> is delegated to the base
/// query; Provider returns the ProxyDbProvider so that operators chained on this query are proxied.
/// </summary>
public class ProxyQueryable<TElement> : IOrderedQueryable<TElement>, IQueryable<TElement>, IEnumerable<TElement>, IDbAsyncEnumerable<TElement>, IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable<TElement> BaseQueryable;

    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable<TElement> baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    public Type ElementType => BaseQueryable.ElementType;

    public Expression Expression => BaseQueryable.Expression;

    public IQueryProvider Provider => ProxyDbProvider;

    public IEnumerator<TElement> GetEnumerator() => BaseQueryable.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)BaseQueryable).GetEnumerator();

    public override string ToString() => BaseQueryable.ToString();

    // Typed async enumeration; requires the base query to be an IDbAsyncEnumerable<TElement>.
    public IDbAsyncEnumerator<TElement> GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable<TElement>;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }

    // Untyped async enumeration; requires the base query to be an IDbAsyncEnumerable.
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }
}
}
An example of an expression manipulator (this one transforms .Where(x => something) into .Where(x => something && something)):
namespace My
{
    using System.Linq;
    using System.Linq.Expressions;

    /// <summary>
    /// Sample visitor that rewrites every <c>Queryable.Where(source, x =&gt; p)</c>
    /// into <c>Queryable.Where(source, x =&gt; p &amp;&amp; p)</c>.
    /// </summary>
    public class MyExpressionManipulator : ExpressionVisitor
    {
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            return base.VisitMethodCall(DuplicateWherePredicate(node));
        }

        // Returns `call` with its predicate doubled when it is a two-argument Queryable.Where whose
        // quoted lambda body is not already an AndAlso; otherwise returns `call` unchanged.
        private static MethodCallExpression DuplicateWherePredicate(MethodCallExpression call)
        {
            bool isTwoArgumentWhere = call.Method.DeclaringType == typeof(Queryable)
                                      && call.Method.Name == "Where"
                                      && call.Arguments.Count == 2;
            if (!isTwoArgumentWhere || call.Arguments[1].NodeType != ExpressionType.Quote)
            {
                return call;
            }

            Expression quoted = ((UnaryExpression)call.Arguments[1]).Operand;
            if (quoted.NodeType != ExpressionType.Lambda)
            {
                return call;
            }

            var predicate = (LambdaExpression)quoted;

            // The whole tree is re-visited at each step: for
            //     ctx.Where(x => true).Where(x => true).Select(x => 1)
            // the inner Where is seen again after it became `true && true`. Skipping bodies that are
            // already AndAlso prevents doubling them a second time.
            if (predicate.Body.NodeType == ExpressionType.AndAlso)
            {
                return call;
            }

            Expression doubledBody = Expression.AndAlso(predicate.Body, predicate.Body);
            Expression[] newArguments = call.Arguments.ToArray();
            newArguments[1] = Expression.Quote(Expression.Lambda(doubledBody, predicate.Parameters));
            return call.Update(call.Object, newArguments);
        }
    }
}
Now... How to use it? The best way is to derive your context (in this case Model1) not from DbContext but from ProxyDbContext, like this:
public partial class Model1 : ProxyDbContext
{
/// <summary>
/// Uses the "Model1" connection string and hands <see cref="Manipulate"/> to ProxyDbContext,
/// which (per the surrounding description) proxies the context's sets with it.
/// </summary>
public Model1() : base("name=Model1", Manipulate) { }
/// <summary>
/// Expression rewriter passed to ProxyDbContext; applies <see cref="MyExpressionManipulator"/>.
/// </summary>
/// <param name="executing">true: the returned Expression will be executed directly, false: the returned expression will be returned as IQueryable&lt;&gt;.</param>
/// <param name="expression">The expression built by the LINQ operators.</param>
/// <returns>The rewritten expression.</returns>
private static Expression Manipulate(bool executing, Expression expression)
{
    // This runs for both CreateQuery (executing == false) and Execute (executing == true), so the
    // same subtree may be visited several times; MyExpressionManipulator skips predicates that are
    // already AndAlso so they are not doubled again.
    // NOTE(review): returning `expression` untouched when executing == false would avoid the
    // repeated visits, but ProxyQueryable<TElement>.GetEnumerator enumerates its base query directly
    // without going through Execute, so a plain foreach/ToList would then never be manipulated.
    var visitor = new MyExpressionManipulator();
    return visitor.Visit(expression);
}
// Some tables
// ProxyDbContext is described as replacing both DbSet<> and IDbSet<> properties like these
// with ProxyDbSet<> instances, so queries on them go through Manipulate.
public virtual DbSet<Parent> Parent { get; set; }
public virtual IDbSet<Child> Child { get; set; }
Then using it is completely transparent:
// Usage when the context itself derives from ProxyDbContext.
// Where Model1: class Model1 : ProxyDbContext {}
using (var ctx = new Model1())
{
// Your query
var res = ctx.Parent.Where(x => x.Id > 100);
// The query is automatically manipulated by your Manipulate method
// (res is only composed here; enumerate it, e.g. with ToList(), to execute it).
}
Another way to do this, without subclassing ProxyDbContext:
// Usage with a plain context: the sets are proxied by hand instead of by ProxyDbContext.
// Where Model1: class Model1 : DbContext {}
using (var ctx = new Model1())
{
Func<Expression, Expression> manipulator = new MyExpressionManipulator().Visit;
// NOTE(review): ProxyDbProvider.Manipulator is a Func<bool, Expression, Expression>; confirm that
// ProxyDbSet<T> has a constructor accepting this Func<Expression, Expression>.
ctx.Parent = new ProxyDbSet<Parent>(ctx.Parent, manipulator);
ctx.Child = new ProxyDbSet<Child>(ctx.Child, manipulator);
// Your query
var res = ctx.Parent.Where(x => x.Id > 100);
}
The ProxyDbContext<> replaces the DbSet<>/IDbSet<> that are present in your context with some ProxyDbSet<>.
In the second example this is done explicitly, but note that there are other options. You can write a method that does it. You can create a factory for your context (a static method that returns a context with the various DbSet<> "proxied"). You can put the proxification in your context's constructor: the "original" initialization of the DbSet<> happens in the DbContext constructor, and the body of your context's constructor runs after it. Or you can create multiple subclasses of your context, each with a constructor that proxifies in a different way...
Note that the first approach (subclassing ProxyDbContext<>) also "fixes" the Set<>/Set methods. Otherwise you would have to fix them yourself by copying the code of those two methods' overloads from ProxyDbContext<>.
For access-control purposes in a database-intensive system, I had to implement an ObjectSet wrapper in which access control (AC) is checked.
The main objective is to make this change while preserving the existing database-access code, which is written with LINQ to Entities all over the classes (there is no centralized data layer).
The ObjectSetWrapper I created looks like this:
/// <summary>
/// Wraps an <see cref="ObjectSet{TEntity}"/> as an <see cref="IQueryable{TEntity}"/> and forwards
/// the common ObjectSet operations; intended as the single hook point for access-control checks
/// (none are implemented here).
/// </summary>
public class ObjectSetWrapper<TEntity> : IQueryable<TEntity> where TEntity : EntityObject
{
    // Both fields reference the same ObjectSet; QueryableModel is the LINQ-facing view of it.
    private readonly IQueryable<TEntity> QueryableModel;
    private readonly ObjectSet<TEntity> ObjectSet;

    public ObjectSetWrapper(ObjectSet<TEntity> objectSetModels)
    {
        this.QueryableModel = objectSetModels;
        this.ObjectSet = objectSetModels;
    }

    public ObjectQuery<TEntity> Include(string path)
    {
        return this.ObjectSet.Include(path);
    }

    // `object` is a C# keyword, so the parameter uses the verbatim identifier @object
    // (the "#object" spelling does not compile).
    public void DeleteObject(TEntity @object)
    {
        this.ObjectSet.DeleteObject(@object);
    }

    public void AddObject(TEntity @object)
    {
        this.ObjectSet.AddObject(@object);
    }

    public IEnumerator<TEntity> GetEnumerator()
    {
        return QueryableModel.GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof(TEntity); }
    }

    public System.Linq.Expressions.Expression Expression
    {
        get { return this.QueryableModel.Expression; }
    }

    public IQueryProvider Provider
    {
        get { return this.QueryableModel.Provider; }
    }

    public void Attach(TEntity entity)
    {
        this.ObjectSet.Attach(entity);
    }

    public void Detach(TEntity entity)
    {
        this.ObjectSet.Detach(entity);
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return this.QueryableModel.GetEnumerator();
    }
}
It's really simple and works for simple queries like this one:
//db.Product is ObjectSetWrapper<Product>
// Simple case: per the question, this query runs fine through the wrapper.
var query = (from item in db.Product where item.Quantity > 0 select new { item.Id, item.Name, item.Value });
var itensList = query.Take(10).ToList();
But when I have subqueries like this one:
//db.Product is ObjectSetWrapper<Product>
var query = (from item in db.Product
select new
{
Id = item.Id,
Name = item.Name,
// db.Sale is referenced inside the projection, so the wrapper object itself ends up in the
// expression tree; per the question this throws NotSupportedException ("Unable to create a constant value").
SalesQuantity = (from sale in db.Sale where sale.ProductId == item.Id select sale.Id).Count()
}).OrderByDescending(x => x.SalesQuantity);
var productsList = query.Take(10).ToList();
I get a NotSupportedException saying that it cannot create a constant value of my inner query's entity type:
Unable to create a constant value of type 'MyNamespace.Model.Sale'.
Only primitive types or enumeration types are supported in this
context.
How can I get my queries working? I don't really need to make my wrapper an ObjectSet type, I just need to use it in queries.
Updated
I have changed my class signature. Now it's also implementing IObjectSet<>, but I'm getting the same NotSupportedException:
public class ObjectSetWrapper<TEntity> : IQueryable<TEntity>, IObjectSet<TEntity> where TEntity : EntityObject
EDIT:
The problem is that the following LINQ construction is translated into a LINQ expression that contains your custom class (ObjectSetWrapper):
var query = (from item in db.Product
select new
{
Id = item.Id,
Name = item.Name,
// Here db.Sale (an ObjectSetWrapper) appears inside the expression tree that LINQ to Entities must translate.
SalesQuantity = (from sale in db.Sale where sale.ProductId == item.Id select sale.Id).Count()
}).OrderByDescending(x => x.SalesQuantity);
LINQ to Entities tries to convert this expression into a SQL statement, but it has no idea how to deal with custom classes (or custom methods).
The solution in such cases is to replace the IQueryProvider with a custom one. It intercepts query execution and translates the LINQ expression containing custom classes/methods into a valid LINQ to Entities expression (one that operates on entities and object sets).
The conversion is performed by a class derived from ExpressionVisitor. It traverses the expression tree and replaces the relevant nodes with nodes that LINQ to Entities can accept.
Part 1 - IQueryWrapper
/// <summary>
/// Implemented by query wrappers so that the real (underlying) query can be recovered
/// before an expression is handed to the underlying LINQ provider.
/// </summary>
internal interface IQueryWrapper
{
    /// <summary>The wrapped query that the underlying provider can translate.</summary>
    IQueryable UnderlyingQueryable { get; }
}
Part 2 - Abstract QueryWrapperBase (not generic)
/// <summary>
/// Non-generic base of the query wrappers. It is the IQueryProvider for wrapped queries and,
/// before anything reaches the underlying provider, replaces every IQueryWrapper-typed node in
/// the expression tree with a constant holding that wrapper's UnderlyingQueryable.
/// </summary>
abstract class QueryWrapperBase : IQueryProvider, IQueryWrapper
{
    public IQueryable UnderlyingQueryable { get; private set; }

    // Swaps each node whose static type implements IQueryWrapper for Constant(wrapper.UnderlyingQueryable).
    class ObjectWrapperReplacer : ExpressionVisitor
    {
        public override Expression Visit(Expression node)
        {
            if (node == null || !typeof(IQueryWrapper).IsAssignableFrom(node.Type))
            {
                return base.Visit(node);
            }

            IQueryWrapper wrapper = EvaluateExpression<IQueryWrapper>(node);
            if (wrapper == null)
            {
                // Report the offending node instead of failing with a NullReferenceException below.
                throw new System.InvalidOperationException(
                    "The query wrapper expression '" + node + "' evaluated to null.");
            }

            return Expression.Constant(wrapper.UnderlyingQueryable);
        }

        public static Expression FixExpression(Expression expression)
        {
            return new ObjectWrapperReplacer().Visit(expression);
        }

        // Gets the runtime value of a node: constants are read directly; any other node (such as
        // a member access) is wrapped in a parameterless lambda, compiled and invoked.
        private static T EvaluateExpression<T>(Expression expression)
        {
            var constant = expression as ConstantExpression;
            if (constant != null)
            {
                return (T)constant.Value;
            }

            LambdaExpression thunk = Expression.Lambda(expression);
            return (T)thunk.Compile().DynamicInvoke();
        }
    }

    /// <param name="underlyingQueryable">The real query wrapped by this instance; must not be null.</param>
    protected QueryWrapperBase(IQueryable underlyingQueryable)
    {
        if (underlyingQueryable == null)
        {
            throw new System.ArgumentNullException(nameof(underlyingQueryable));
        }

        UnderlyingQueryable = underlyingQueryable;
    }

    public abstract IQueryable<TElement> CreateQuery<TElement>(Expression expression);

    public abstract IQueryable CreateQuery(Expression expression);

    public TResult Execute<TResult>(Expression expression)
    {
        return (TResult)Execute(expression);
    }

    /// <summary>
    /// Removes wrapper nodes from the expression, then either creates a (lazy) query on the
    /// underlying provider for sequence-typed expressions or executes scalar expressions directly.
    /// </summary>
    public object Execute(Expression expression)
    {
        expression = ObjectWrapperReplacer.FixExpression(expression);
        return typeof(IQueryable).IsAssignableFrom(expression.Type)
            ? ExecuteQueryable(expression)
            : ExecuteNonQueryable(expression);
    }

    protected object ExecuteNonQueryable(Expression expression)
    {
        return UnderlyingQueryable.Provider.Execute(expression);
    }

    protected IQueryable ExecuteQueryable(Expression expression)
    {
        return UnderlyingQueryable.Provider.CreateQuery(expression);
    }
}
Part 3 - Generic QueryWrapper<TElement>
/// <summary>
/// Generic query wrapper. Operators composed on it produce new QueryWrapper instances that share
/// the same UnderlyingQueryable; enumeration runs through QueryWrapperBase.Execute, which strips
/// the wrappers from the expression first.
/// </summary>
class QueryWrapper<TElement> : QueryWrapperBase, IOrderedQueryable<TElement>
{
    // Open generic definition of CreateQuery<T>; the non-generic CreateQuery closes it at runtime.
    private static readonly MethodInfo MethodCreateQueryDef = GetMethodDefinition(q => q.CreateQuery<object>(null));

    public QueryWrapper(IQueryable<TElement> underlyingQueryable) : this(null, underlyingQueryable)
    {
    }

    protected QueryWrapper(Expression expression, IQueryable underlyingQueryable) : base(underlyingQueryable)
    {
        // A root wrapper is represented by a constant referring to itself; QueryWrapperBase later
        // replaces that constant with the underlying query.
        Expression = expression ?? Expression.Constant(this);
    }

    public virtual IEnumerator<TElement> GetEnumerator()
    {
        return ((IEnumerable<TElement>)Execute<IEnumerable>(Expression)).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    public Expression Expression { get; private set; }

    public Type ElementType
    {
        get { return typeof(TElement); }
    }

    public IQueryProvider Provider
    {
        get { return this; }
    }

    /// <summary>Non-generic entry point: infers the element type and forwards to CreateQuery&lt;T&gt;.</summary>
    /// <exception cref="ArgumentException">The expression type does not expose exactly one IEnumerable&lt;T&gt;.</exception>
    public override IQueryable CreateQuery(Expression expression)
    {
        Type elementType = GetSequenceElementType(expression.Type);
        MethodInfo createQueryMethod = MethodCreateQueryDef.MakeGenericMethod(elementType);
        return (IQueryable)createQueryMethod.Invoke(this, new object[] { expression });
    }

    public override IQueryable<TNewElement> CreateQuery<TNewElement>(Expression expression)
    {
        return new QueryWrapper<TNewElement>(expression, UnderlyingQueryable);
    }

    private static MethodInfo GetMethodDefinition(Expression<Action<QueryWrapper<TElement>>> methodSelector)
    {
        var methodCallExpression = (MethodCallExpression)methodSelector.Body;
        return methodCallExpression.Method.GetGenericMethodDefinition();
    }

    // Returns T for the single IEnumerable<T> that sequenceType is or implements. Type.GetInterfaces()
    // does not include the type itself, so without the IsInterface branch an expression typed exactly
    // as IEnumerable<T> would find no match.
    private static Type GetSequenceElementType(Type sequenceType)
    {
        IEnumerable<Type> candidates = sequenceType.IsInterface
            ? new[] { sequenceType }.Concat(sequenceType.GetInterfaces())
            : sequenceType.GetInterfaces();
        Type[] sequenceInterfaces = candidates
            .Where(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            .ToArray();
        if (sequenceInterfaces.Length != 1)
        {
            throw new ArgumentException(
                "Expected the expression type '" + sequenceType + "' to expose exactly one IEnumerable<T>.",
                "expression");
        }

        return sequenceInterfaces[0].GetGenericArguments()[0];
    }
}
Part 4 - finally your ObjectSetWrapper
/// <summary>
/// ObjectSet wrapper whose LINQ view is a QueryWrapper, so wrapper instances appearing inside
/// query expressions are replaced with the real ObjectSet before LINQ to Entities sees them.
/// </summary>
public class ObjectSetWrapper<TEntity> : IQueryable<TEntity>, IQueryWrapper where TEntity : class
{
    // QueryableModel wraps ObjectSet in a QueryWrapper; ObjectSet is used for the forwarded operations.
    private readonly IQueryable<TEntity> QueryableModel;
    private readonly ObjectSet<TEntity> ObjectSet;

    public ObjectSetWrapper(ObjectSet<TEntity> objectSetModels)
    {
        this.QueryableModel = new QueryWrapper<TEntity>(objectSetModels);
        this.ObjectSet = objectSetModels;
    }

    public ObjectQuery<TEntity> Include(string path)
    {
        return this.ObjectSet.Include(path);
    }

    // `object` is a C# keyword, so the parameter uses the verbatim identifier @object
    // (the "#object" spelling does not compile).
    public void DeleteObject(TEntity @object)
    {
        this.ObjectSet.DeleteObject(@object);
    }

    public void AddObject(TEntity @object)
    {
        this.ObjectSet.AddObject(@object);
    }

    public IEnumerator<TEntity> GetEnumerator()
    {
        return QueryableModel.GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof(TEntity); }
    }

    public System.Linq.Expressions.Expression Expression
    {
        get { return this.QueryableModel.Expression; }
    }

    public IQueryProvider Provider
    {
        get { return this.QueryableModel.Provider; }
    }

    public void Attach(TEntity entity)
    {
        this.ObjectSet.Attach(entity);
    }

    public void Detach(TEntity entity)
    {
        this.ObjectSet.Detach(entity);
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return this.QueryableModel.GetEnumerator();
    }

    // What ObjectWrapperReplacer substitutes for this wrapper inside query expressions.
    IQueryable IQueryWrapper.UnderlyingQueryable
    {
        get { return this.ObjectSet; }
    }
}
Your inner query fails because you are referencing another dataset when you should be traversing foreign keys:
SalesQuantity = item.Sales.Count()
For some reason Microsoft decided to not support simple concat in EF5.
e.g.
Select(foo => new
{
someProp = "hello" + foo.id + "/" + foo.bar
}
This will throw if foo.id or foo.bar are numbers.
The workaround I've found is apparently this pretty piece of code:
Select(foo => new
{
someProp = "hello" +
SqlFunctions.StringConvert((double?)foo.id).Trim() +
"/" +
SqlFunctions.StringConvert((double?)foo.bar).Trim()
}
Which works fine, but is just horrid to look at.
So, is there some decent way to accomplish this with cleaner code?
I'm NOT interested in doing this client side, so no .AsEnumerable() answers please.
For those interested.
I got so pissed with the lack of this feature that I implemented it myself using an ExpressionVisitor.
You can now write code like the one in the original question.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Objects.SqlClient;
using System.Linq;
using System.Linq.Expressions;
namespace Crawlr.Web.Code
{
public static class ObjectSetExExtensions
{
    /// <summary>
    /// Wraps <paramref name="self"/> in an ObjectSetEx, whose provider runs every expression
    /// through ExpressionReWriterVisitor before delegating to the original provider.
    /// </summary>
    public static ObjectSetEx<T> Extend<T>(this IQueryable<T> self) where T : class => new ObjectSetEx<T>(self);
}
/// <summary>
/// Queryable decorator: enumerates, and reports the type and expression of, the wrapped source,
/// but hands out a QueryProviderEx so that operators composed on it are rewritten.
/// </summary>
public class ObjectSetEx<T> : IOrderedQueryable<T>
{
    private readonly QueryProviderEx provider;
    private readonly IQueryable<T> source;

    public ObjectSetEx(IQueryable<T> source)
    {
        this.source = source;
        this.provider = new QueryProviderEx(source.Provider);
    }

    public Type ElementType => source.ElementType;

    public Expression Expression => source.Expression;

    public IQueryProvider Provider => provider;

    public IEnumerator<T> GetEnumerator() => source.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();
}
/// <summary>
/// IQueryProvider decorator: every expression is rewritten with ExpressionReWriterVisitor.Default
/// before being passed to the wrapped provider, and created queries are re-wrapped in ObjectSetEx
/// so that further composition keeps going through the rewriter.
/// </summary>
public class QueryProviderEx : IQueryProvider
{
    private readonly IQueryProvider source;

    public QueryProviderEx(IQueryProvider source)
    {
        this.source = source;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        IQueryable<TElement> query = source.CreateQuery<TElement>(Rewrite(expression));
        return new ObjectSetEx<TElement>(query);
    }

    /// <summary>
    /// Non-generic overload. Like the generic one, it re-wraps the result in ObjectSetEx&lt;T&gt;
    /// (when the created query exposes IQueryable&lt;ElementType&gt;). Otherwise operators composed
    /// on the result would silently bypass the rewriter.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        IQueryable query = source.CreateQuery(Rewrite(expression));
        Type typedQueryable = typeof(IQueryable<>).MakeGenericType(query.ElementType);
        if (!typedQueryable.IsInstanceOfType(query))
        {
            return query;
        }

        Type wrapperType = typeof(ObjectSetEx<>).MakeGenericType(query.ElementType);
        return (IQueryable)Activator.CreateInstance(wrapperType, query);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return source.Execute<TResult>(Rewrite(expression));
    }

    public object Execute(Expression expression)
    {
        return source.Execute(Rewrite(expression));
    }

    // Single place where the rewriter is applied.
    private static Expression Rewrite(Expression expression)
    {
        return ExpressionReWriterVisitor.Default.Visit(expression);
    }
}
/// <summary>
/// Rewrites boxing conversions of small integral values to object, which the C# compiler emits
/// for <c>"text" + number</c>, into <c>SqlFunctions.StringConvert((double?)value, 11).Trim()</c>.
/// LINQ to Entities can translate that form.
/// </summary>
public class ExpressionReWriterVisitor : ExpressionVisitor
{
    public static readonly ExpressionReWriterVisitor Default = new ExpressionReWriterVisitor();

    // Total width passed to StringConvert (T-SQL STR). STR's default width is 10, which is one
    // character short for int.MinValue ("-2147483648"); that value would come back as asterisks.
    private const int StringConvertLength = 11;

    private static readonly System.Reflection.MethodInfo StringConvertMethod =
        typeof(SqlFunctions).GetMethod("StringConvert", new[] { typeof(double?), typeof(int?) });

    private static readonly System.Reflection.MethodInfo TrimMethod =
        typeof(string).GetMethod("Trim", Type.EmptyTypes);

    protected override Expression VisitUnary(UnaryExpression node)
    {
        if (node.NodeType == ExpressionType.Convert
            && node.Type == typeof(object)
            && IsRewrittenOperandType(node.Operand.Type))
        {
            // Visit the operand first so that conversions nested inside it are rewritten too.
            Expression operand = Visit(node.Operand);
            Expression asDouble = Expression.Convert(operand, typeof(double?));
            Expression length = Expression.Constant(StringConvertLength, typeof(int?));
            return Expression.Call(Expression.Call(StringConvertMethod, asDouble, length), TrimMethod);
        }

        return base.VisitUnary(node);
    }

    // int plus the narrower integral types short and byte: every value converts to double exactly
    // and its text fits within StringConvertLength.
    private static bool IsRewrittenOperandType(Type type)
    {
        return type == typeof(int) || type == typeof(short) || type == typeof(byte);
    }
}
}
usage:
var res = model
.FooSet
.Extend() //<- applies the magic
.Select(foo => new
{
someProp = "hello" + foo.id + "/" + foo.bar
}