How to get the where clause from IQueryable defined as interface - c#
/// <summary>
/// Console driver for the first sample: fills a <c>SampleClass</c> with two items,
/// attaches a Where clause to the queryable it returns and prints that queryable.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        var sample = new SampleClass<ClassString>();
        sample.ClassStrings.AddRange(new[]
        {
            new ClassString { Name1 = "1", Name2 = "1" },
            new ClassString { Name1 = "2", Name2 = "2" }
        });

        // The filter is added by the caller, after Query<T>() has already returned.
        IQueryable<ClassString> filtered = sample.Query<ClassString>().Where(s => s.Name1.Equals("2"));

        Console.WriteLine(filtered);
        Console.ReadLine();
    }
}
/// <summary>
/// Simple data holder with two string properties, used as the element type in the first sample.
/// </summary>
public class ClassString
{
// First name value; the sample's Where clause filters on this property.
public string Name1 { get; set; }
// Second name value.
public string Name2 { get; set; }
}
/// <summary>
/// Common query contract: exposes a parameterless generic method that returns an IQueryable.
/// </summary>
public interface ISampleQ
{
// Returns a queryable of T; callers attach their own operators (e.g. Where) to the result.
IQueryable<T> Query<T>() where T: class , new();
}
/// <summary>
/// In-memory implementation of <c>ISampleQ</c> backed by a list of <typeparamref name="X"/>.
/// </summary>
/// <typeparam name="X">Type of the items stored in <see cref="ClassStrings"/>.</typeparam>
public class SampleClass<X> : ISampleQ
{
    /// <summary>The backing list; created empty by the constructor.</summary>
    public List<X> ClassStrings { get; private set; }

    public SampleClass()
    {
        ClassStrings = new List<X>();
    }

    /// <summary>
    /// Exposes the backing list as a queryable of <typeparamref name="T"/>.
    /// The list is cast to <c>IEnumerable&lt;T&gt;</c>, so the call fails with an
    /// <see cref="InvalidCastException"/> when the stored items cannot be viewed as T.
    /// </summary>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the WHERE expression from here.
        IEnumerable<T> items = (IEnumerable<T>)ClassStrings;

        // A List is not itself an IQueryable, so AsQueryable wraps it in an EnumerableQuery<T>.
        return items.AsQueryable();
    }
}
I looked into solution1, solution2 and solution3, but they do not seem applicable to my question, because the where clause is defined outside the class and is applied through the class's interface. How can I get the expression inside the Query method, given that no variable is passed through to it?
The purpose: I want to retrieve the where clause and inject it back into the destination (which is a DBContext exposed as IQueryable), because we have a common interface such as ISampleQ.
Added new sample code for the same scenario:
// Console driver for the second sample: the same scenario expressed against a context-style API.
internal class Program
{
private static void Main(string[] args)
{
var oracleDbContext = new OracleDbContext();
// The Where clause is attached by the caller, after Query<Person>() has already returned.
var result = oracleDbContext.Query<Person>().Where(person => person.Name.Equals("username"));
// NOTE(review): result is never read; this call only prints an empty line.
Console.WriteLine();
Console.ReadLine();
}
}
/// <summary>
/// Query contract shared by the context classes: a parameterless generic method returning an IQueryable.
/// </summary>
public interface IGenericQuery
{
// Returns a queryable of T; callers attach their own operators (e.g. Where) to the result.
IQueryable<T> Query<T>() where T : class , new();
}
/// <summary>
/// Placeholder for the Oracle-specific implementation of IGenericQuery.
/// </summary>
public class OracleDbContext : IGenericQuery
{
public OracleDbContext()
{
//Will hold all Oracle operations. For brevity, only
//Query is exposed.
}
// Placeholder: currently returns null.
public IQueryable<T> Query<T>() where T : class, new()
{
//Goal: obtain the where predicate here. The where clause is defined outside
//of this class, and IQueryable<T> is common to both OracleDbContext and
//MssqlDbContext. The predicate should be retrieved so that it can be
//re-injected, or extended with a new expression, before the call is made.
//
//For example:
//oracleDbContext.Query<T>(where clause from here)
return null;
}
}
/// <summary>
/// Placeholder for the MSSQL-specific implementation of IGenericQuery.
/// </summary>
public class MssqlDbContext : IGenericQuery
{
public MssqlDbContext()
{
//Will hold all MSSQL operations. For brevity, only
//Query is exposed.
}
// Placeholder: currently returns null.
public IQueryable<T> Query<T>() where T : class, new()
{
//Goal: obtain the where predicate here.
return null;
}
}
/// <summary>
/// Entity used by the second sample.
/// </summary>
public class Person
{
    /// <summary>Numeric identifier.</summary>
    public int Id { get; set; }

    /// <summary>
    /// Person name. Declared as a string because the sample filters with
    /// <c>person.Name.Equals("username")</c>; with an int the comparison could never be true.
    /// </summary>
    public string Name { get; set; }
}
It is quite complex... Now... the Queryable.Where() works this way:
public static IQueryable<TSource> Where<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
return source.Provider.CreateQuery<TSource>(Expression.Call(null, ...
So Queryable.Where calls source.Provider.CreateQuery(), which returns a new IQueryable<>. If you want to be able to "see" a Where() while it is being added (and manipulate it), you must "be" the IQueryable<>.Provider and have your own CreateQuery(), so you must create a class that implements IQueryProvider (and probably a class that implements IQueryable<T>).
Another way (much simpler) is to have a simple query "converter": a method that accepts a IQueryable<> and returns a manipulated IQueryable<>:
var result = c.Query<ClassString>().Where(s => s.Name1.Equals("2")).FixMyQuery();
As I said, the full route is quite long:
namespace Utilities
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// A <see cref="DbContext"/> that routes queries through a caller-supplied expression
/// manipulator by replacing its set properties with <see cref="ProxyDbSet{TEntity}"/> instances
/// and by returning proxies from <c>Set</c>.
/// </summary>
public class ProxyDbContext : DbContext
{
    // Open generic ProxifySets<TContext>(); the constructors close it over the runtime context type.
    protected static readonly MethodInfo ProxifySetsMethod = typeof(ProxyDbContext).GetMethod("ProxifySets", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Builds, once per concrete context type, a compiled delegate that replaces every eligible
    /// DbSet{T}/IDbSet{T} property of the context with a ProxyDbSet{T}.
    /// </summary>
    /// <typeparam name="TContext">Concrete context type whose properties are inspected.</typeparam>
    protected static class ProxyDbContexSetter<TContext> where TContext : ProxyDbContext
    {
        // Stays a no-op when the context exposes no eligible property.
        public static readonly Action<TContext> Do = x => { };

        static ProxyDbContexSetter()
        {
            var properties = typeof(TContext).GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy);
            ParameterExpression context = Expression.Parameter(typeof(TContext), "context");
            FieldInfo manipulatorField = typeof(ProxyDbContext).GetField("Manipulator", BindingFlags.Instance | BindingFlags.Public);
            Expression manipulator = Expression.Field(context, manipulatorField);
            var sets = new List<Expression>();

            foreach (PropertyInfo property in properties)
            {
                if (property.GetMethod == null)
                {
                    continue;
                }

                // The property is reassigned below, so it needs a public setter.
                // A get-only property would make Expression.Assign throw.
                MethodInfo setMethod = property.SetMethod;
                if (setMethod == null || !setMethod.IsPublic)
                {
                    continue;
                }

                Type type = property.PropertyType;
                Type entityType = GetIDbSetTypeArgument(type);
                if (entityType == null)
                {
                    continue;
                }

                // Only properties that can hold a DbSet<T> (DbSet<T> itself or IDbSet<T>) are proxified.
                Type dbSetType = typeof(DbSet<>).MakeGenericType(entityType);
                if (!type.IsAssignableFrom(dbSetType))
                {
                    continue;
                }

                ConstructorInfo constructor = typeof(ProxyDbSet<>)
                    .MakeGenericType(entityType)
                    .GetConstructor(new[]
                    {
                        dbSetType,
                        typeof(Func<bool, Expression, Expression>)
                    });

                // context.Property = new ProxyDbSet<T>((DbSet<T>)context.Property, context.Manipulator)
                MemberExpression property2 = Expression.Property(context, property);
                BinaryExpression assign = Expression.Assign(property2, Expression.New(constructor, Expression.Convert(property2, dbSetType), manipulator));
                sets.Add(assign);
            }

            // Building a block from an empty list can throw; keep the default no-op instead.
            if (sets.Count == 0)
            {
                return;
            }

            Expression<Action<TContext>> lambda = Expression.Lambda<Action<TContext>>(Expression.Block(sets), context);
            Do = lambda.Compile();
        }

        // Gets the T of IDbSet<T>, or null when the type does not implement IDbSet<T>.
        private static Type GetIDbSetTypeArgument(Type type)
        {
            IEnumerable<Type> interfaces = type.IsInterface ?
                new[] { type }.Concat(type.GetInterfaces()) :
                type.GetInterfaces();

            Type argument = (from x in interfaces
                             where x.IsGenericType
                             let gt = x.GetGenericTypeDefinition()
                             where gt == typeof(IDbSet<>)
                             select x.GetGenericArguments()[0]).SingleOrDefault();
            return argument;
        }
    }

    /// <summary>
    /// Expression rewriter handed to every proxy created by this context.
    /// First parameter: true for Execute, false for CreateQuery.
    /// </summary>
    public readonly Func<bool, Expression, Expression> Manipulator;

    /// <summary>
    /// Creates the context with the default <see cref="DbContext"/> configuration.
    /// </summary>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet{TEntity} and IDbSet{TEntity} properties proxified.</param>
    public ProxyDbContext(Func<bool, Expression, Expression> manipulator, bool resetSets = true)
    {
        Manipulator = manipulator;
        if (resetSets)
        {
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>
    /// Creates the context for the given database name or connection string.
    /// </summary>
    /// <param name="nameOrConnectionString">Forwarded to the <see cref="DbContext"/> constructor.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet{TEntity} and IDbSet{TEntity} properties proxified.</param>
    public ProxyDbContext(string nameOrConnectionString, Func<bool, Expression, Expression> manipulator, bool resetSets = true)
        : base(nameOrConnectionString)
    {
        Manipulator = manipulator;
        if (resetSets)
        {
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>
    /// Runs the compiled property replacer for <typeparamref name="TContext"/> on this instance.
    /// </summary>
    protected void ProxifySets<TContext>() where TContext : ProxyDbContext
    {
        ProxyDbContexSetter<TContext>.Do((TContext)this);
    }

    /// <summary>
    /// Returns the base set for <typeparamref name="TEntity"/> wrapped in a <see cref="ProxyDbSet{TEntity}"/>.
    /// </summary>
    public override DbSet<TEntity> Set<TEntity>()
    {
        return new ProxyDbSet<TEntity>(base.Set<TEntity>(), Manipulator);
    }

    /// <summary>
    /// Returns the base non-generic set wrapped in a <see cref="ProxyDbSetNonGeneric{TEntity}"/>
    /// that is closed over <paramref name="entityType"/> through reflection.
    /// </summary>
    public override DbSet Set(Type entityType)
    {
        DbSet set = base.Set(entityType);
        ConstructorInfo constructor = typeof(ProxyDbSetNonGeneric<>)
            .MakeGenericType(entityType)
            .GetConstructor(new[]
            {
                typeof(DbSet),
                typeof(Func<bool, Expression, Expression>)
            });
        return (DbSet)constructor.Invoke(new object[] { set, Manipulator });
    }
}
/// <summary>
/// Proxy for the non-generic DbSet, which EF implements as InternalDbSet{TEntity}.
/// Entity operations are forwarded to the wrapped set, while the IQueryable,
/// IEnumerable and IDbAsyncEnumerable members are served by a proxy queryable
/// whose provider is a ProxyDbProvider.
/// </summary>
/// <typeparam name="TEntity">Element type exposed through the generic query interfaces.</typeparam>
public class ProxyDbSetNonGeneric<TEntity> : DbSet, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
// The wrapped set that receives every forwarded entity operation.
protected readonly DbSet BaseDbSet;
// Queryable that backs the query/enumeration interfaces of this proxy.
protected readonly IQueryable<TEntity> ProxyQueryable;
// Expression rewriter; first parameter: true for Execute, false for CreateQuery.
public readonly Func<bool, Expression, Expression> Manipulator;
// Reflection handle for the non-public "_internalSet" field of DbSet (null when the lookup finds nothing).
protected readonly FieldInfo InternalSetField = typeof(DbSet).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);
/// <summary>
/// Wraps a set, creating a new ProxyDbProvider around the provider of the wrapped set.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSetNonGeneric(DbSet baseDbSet, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
IQueryProvider provider = ((IQueryable)baseDbSet).Provider;
ProxyDbProvider proxyDbProvider = new ProxyDbProvider(provider, manipulator);
ProxyQueryable = proxyDbProvider.CreateQuery<TEntity>(((IQueryable)baseDbSet).Expression);
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>
/// Wraps a set using an already built proxy queryable.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="proxyQueryable">Queryable used for the query/enumeration interfaces.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSetNonGeneric(DbSet baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
ProxyQueryable = proxyQueryable;
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
// Forwarded to the wrapped set.
public override object Add(object entity)
{
return BaseDbSet.Add(entity);
}
// Forwarded to the wrapped set.
public override IEnumerable AddRange(IEnumerable entities)
{
return BaseDbSet.AddRange(entities);
}
// Returns a new proxy over the wrapped set's AsNoTracking() query, reusing the current proxy provider.
public override DbQuery AsNoTracking()
{
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsNoTracking()), Manipulator);
}
// Returns a new proxy over the wrapped set's AsStreaming() query, reusing the current proxy provider.
// Warning 618 (use of an obsolete member) is suppressed around the call.
[Obsolete]
public override DbQuery AsStreaming()
{
#pragma warning disable 618
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
}
// Forwarded to the wrapped set.
public override object Attach(object entity)
{
return BaseDbSet.Attach(entity);
}
// Forwarded to the wrapped set.
public override object Create(Type derivedEntityType)
{
return BaseDbSet.Create(derivedEntityType);
}
// Forwarded to the wrapped set.
public override object Create()
{
return BaseDbSet.Create();
}
// Forwarded to the wrapped set.
public override object Find(params object[] keyValues)
{
return BaseDbSet.Find(keyValues);
}
// Forwarded to the wrapped set.
public override Task<object> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
{
return BaseDbSet.FindAsync(cancellationToken, keyValues);
}
// Forwarded to the wrapped set.
public override Task<object> FindAsync(params object[] keyValues)
{
return BaseDbSet.FindAsync(keyValues);
}
// Returns a new proxy over the wrapped set's Include(path) query, reusing the current proxy provider.
public override DbQuery Include(string path)
{
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.Include(path)), Manipulator);
}
// Forwarded to the wrapped set.
public override IList Local
{
get
{
return BaseDbSet.Local;
}
}
// Forwarded to the wrapped set.
public override object Remove(object entity)
{
return BaseDbSet.Remove(entity);
}
// Forwarded to the wrapped set.
public override IEnumerable RemoveRange(IEnumerable entities)
{
return BaseDbSet.RemoveRange(entities);
}
// Forwarded to the wrapped set; raw SQL queries do not go through the proxy queryable.
public override DbSqlQuery SqlQuery(string sql, params object[] parameters)
{
return BaseDbSet.SqlQuery(sql, parameters);
}
// The members below are served by the proxy queryable rather than by the wrapped set.
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return ProxyQueryable.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)ProxyQueryable).GetEnumerator();
}
Type IQueryable.ElementType
{
get { return ProxyQueryable.ElementType; }
}
Expression IQueryable.Expression
{
get { return ProxyQueryable.Expression; }
}
// Handing out the proxy queryable's provider is what lets operators applied to this set reach it.
IQueryProvider IQueryable.Provider
{
get { return ProxyQueryable.Provider; }
}
IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();
}
// Returns the string form of the proxy queryable.
public override string ToString()
{
return ProxyQueryable.ToString();
}
}
/// <summary>
/// Proxy for DbSet{TEntity}. Entity operations are forwarded to the wrapped set,
/// while the IQueryable, IEnumerable and IDbAsyncEnumerable members are served by a
/// proxy queryable whose provider is a ProxyDbProvider.
/// </summary>
/// <typeparam name="TEntity">Entity type of the set.</typeparam>
public class ProxyDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
// The wrapped set that receives every forwarded entity operation.
protected readonly DbSet<TEntity> BaseDbSet;
// Queryable that backs the query/enumeration interfaces of this proxy.
protected readonly IQueryable<TEntity> ProxyQueryable;
// Expression rewriter; first parameter: true for Execute, false for CreateQuery.
public readonly Func<bool, Expression, Expression> Manipulator;
// Reflection handle for the non-public "_internalSet" field of DbSet<TEntity> (null when the lookup finds nothing).
protected readonly FieldInfo InternalSetField = typeof(DbSet<TEntity>).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);
/// <summary>
/// Wraps a set, creating a new ProxyDbProvider around the provider of the wrapped set.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSet(DbSet<TEntity> baseDbSet, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
IQueryProvider provider = ((IQueryable)baseDbSet).Provider;
ProxyDbProvider proxyDbProvider = new ProxyDbProvider(provider, manipulator);
ProxyQueryable = proxyDbProvider.CreateQuery<TEntity>(((IQueryable)baseDbSet).Expression);
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>
/// Wraps a set using an already built proxy queryable.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="proxyQueryable">Queryable used for the query/enumeration interfaces.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSet(DbSet<TEntity> baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
ProxyQueryable = proxyQueryable;
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
// Forwarded to the wrapped set.
public override TEntity Add(TEntity entity)
{
return BaseDbSet.Add(entity);
}
// Forwarded to the wrapped set.
public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
{
return BaseDbSet.AddRange(entities);
}
// Returns a new proxy over the wrapped set's AsNoTracking() query, reusing the current proxy provider.
public override DbQuery<TEntity> AsNoTracking()
{
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsNoTracking()), Manipulator);
}
// Returns a new proxy over the wrapped set's AsStreaming() query, reusing the current proxy provider.
// Warning 618 (use of an obsolete member) is suppressed around the call.
[Obsolete]
public override DbQuery<TEntity> AsStreaming()
{
#pragma warning disable 618
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
}
// Forwarded to the wrapped set.
public override TEntity Attach(TEntity entity)
{
return BaseDbSet.Attach(entity);
}
// Forwarded to the wrapped set.
public override TDerivedEntity Create<TDerivedEntity>()
{
return BaseDbSet.Create<TDerivedEntity>();
}
// Forwarded to the wrapped set.
public override TEntity Create()
{
return BaseDbSet.Create();
}
// Forwarded to the wrapped set.
public override TEntity Find(params object[] keyValues)
{
return BaseDbSet.Find(keyValues);
}
// Forwarded to the wrapped set.
public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
{
return BaseDbSet.FindAsync(cancellationToken, keyValues);
}
// Forwarded to the wrapped set.
public override Task<TEntity> FindAsync(params object[] keyValues)
{
return BaseDbSet.FindAsync(keyValues);
}
// Returns a new proxy over the wrapped set's Include(path) query, reusing the current proxy provider.
public override DbQuery<TEntity> Include(string path)
{
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.Include(path)), Manipulator);
}
// Forwarded to the wrapped set.
public override ObservableCollection<TEntity> Local
{
get
{
return BaseDbSet.Local;
}
}
// Forwarded to the wrapped set.
public override TEntity Remove(TEntity entity)
{
return BaseDbSet.Remove(entity);
}
// Forwarded to the wrapped set.
public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
{
return BaseDbSet.RemoveRange(entities);
}
// Forwarded to the wrapped set; raw SQL queries do not go through the proxy queryable.
public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
{
return BaseDbSet.SqlQuery(sql, parameters);
}
// The members below are served by the proxy queryable rather than by the wrapped set.
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return ProxyQueryable.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)ProxyQueryable).GetEnumerator();
}
Type IQueryable.ElementType
{
get { return ProxyQueryable.ElementType; }
}
Expression IQueryable.Expression
{
get { return ProxyQueryable.Expression; }
}
// Handing out the proxy queryable's provider is what lets operators applied to this set reach it.
IQueryProvider IQueryable.Provider
{
get { return ProxyQueryable.Provider; }
}
IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();
}
// Returns the string form of the proxy queryable.
public override string ToString()
{
return ProxyQueryable.ToString();
}
// Note that the conversion operator is not virtual! If you do:
// DbSet<Foo> foo = new ProxyDbSet<Foo>(...)
// DbSet foo2 = (DbSet)foo;
// then you'll get a non-proxied DbSet, because the conversion is chosen from the
// static type of the variable (DbSet<Foo>) and not from this class.
// The operator below builds a non-generic proxy over the same wrapped set and manipulator.
public static implicit operator ProxyDbSetNonGeneric<TEntity>(ProxyDbSet<TEntity> entry)
{
return new ProxyDbSetNonGeneric<TEntity>((DbSet)entry.BaseDbSet, entry.Manipulator);
}
}
/// <summary>
/// Query provider that passes every expression through <see cref="Manipulator"/> (when one is set)
/// before handing it to the wrapped provider, and wraps every created query in a proxy queryable
/// so that later operators keep coming back to a <see cref="ProxyDbProvider"/>.
/// </summary>
public class ProxyDbProvider : IQueryProvider, IDbAsyncQueryProvider
{
    protected readonly IQueryProvider BaseQueryProvider;
    public readonly Func<bool, Expression, Expression> Manipulator;

    /// <summary>
    /// Creates a proxy around <paramref name="baseQueryProvider"/>.
    /// </summary>
    /// <param name="baseQueryProvider">Provider that actually creates and executes the queries.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbProvider(IQueryProvider baseQueryProvider, Func<bool, Expression, Expression> manipulator)
    {
        BaseQueryProvider = baseQueryProvider;
        Manipulator = manipulator;
    }

    // Applies the manipulator when one was supplied; otherwise returns the expression untouched.
    private Expression Rewrite(bool executing, Expression expression)
    {
        if (Manipulator == null)
        {
            return expression;
        }

        return Manipulator(executing, expression);
    }

    // Reuses this proxy when the created query still uses the wrapped provider, else wraps the new provider.
    private ProxyDbProvider ProxyFor(IQueryProvider provider)
    {
        if (provider == BaseQueryProvider)
        {
            return this;
        }

        return new ProxyDbProvider(provider, Manipulator);
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        Expression rewritten = Rewrite(false, expression);
        IQueryable<TElement> inner = BaseQueryProvider.CreateQuery<TElement>(rewritten);
        return new ProxyQueryable<TElement>(ProxyFor(inner.Provider), inner);
    }

    protected static readonly MethodInfo CreateQueryNonGenericToGenericMethod = typeof(ProxyDbProvider).GetMethod("CreateQueryNonGenericToGeneric", BindingFlags.Static | BindingFlags.NonPublic);

    public IQueryable CreateQuery(Expression expression)
    {
        Expression rewritten = Rewrite(false, expression);
        IQueryable inner = BaseQueryProvider.CreateQuery(rewritten);
        ProxyDbProvider proxy = ProxyFor(inner.Provider);

        // When the created query is really an IQueryable<T>, build the generic proxy through reflection.
        Type elementType = GetIQueryableTypeArgument(inner.GetType());
        if (elementType != null)
        {
            MethodInfo factory = CreateQueryNonGenericToGenericMethod.MakeGenericMethod(elementType);
            return (IQueryable)factory.Invoke(null, new object[] { proxy, inner });
        }

        return new ProxyQueryable(proxy, inner);
    }

    protected static ProxyQueryable<TElement> CreateQueryNonGenericToGeneric<TElement>(ProxyDbProvider proxy, IQueryable<TElement> query)
    {
        return new ProxyQueryable<TElement>(proxy, query);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return BaseQueryProvider.Execute<TResult>(Rewrite(true, expression));
    }

    public object Execute(Expression expression)
    {
        return BaseQueryProvider.Execute(Rewrite(true, expression));
    }

    // Gets the T of IQueryable<T>, or null when the type does not implement IQueryable<T>.
    protected static Type GetIQueryableTypeArgument(Type type)
    {
        IEnumerable<Type> candidates = type.GetInterfaces();
        if (type.IsInterface)
        {
            // GetInterfaces() does not list the type itself, so put it first.
            candidates = new[] { type }.Concat(candidates);
        }

        return candidates
            .Where(candidate => candidate.IsGenericType)
            .Where(candidate => candidate.GetGenericTypeDefinition() == typeof(IQueryable<>))
            .Select(candidate => candidate.GetGenericArguments()[0])
            .FirstOrDefault();
    }

    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        // Async execution is only possible when the wrapped provider supports it.
        IDbAsyncQueryProvider asyncProvider = BaseQueryProvider as IDbAsyncQueryProvider;
        if (asyncProvider != null)
        {
            return asyncProvider.ExecuteAsync<TResult>(Rewrite(true, expression), cancellationToken);
        }

        throw new NotSupportedException();
    }

    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        // Async execution is only possible when the wrapped provider supports it.
        IDbAsyncQueryProvider asyncProvider = BaseQueryProvider as IDbAsyncQueryProvider;
        if (asyncProvider != null)
        {
            return asyncProvider.ExecuteAsync(Rewrite(true, expression), cancellationToken);
        }

        throw new NotSupportedException();
    }
}
/// <summary>
/// Non-generic queryable wrapper: exposes the wrapped query's element type and expression,
/// but reports a <see cref="ProxyDbProvider"/> as its provider. Enumeration is delegated
/// directly to the wrapped queryable.
/// </summary>
public class ProxyQueryable : IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable BaseQueryable;

    /// <param name="proxyDbProvider">Provider returned by <see cref="Provider"/>.</param>
    /// <param name="baseQueryable">Query that supplies the expression, element type and results.</param>
    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    public IEnumerator GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    public Type ElementType
    {
        get
        {
            return BaseQueryable.ElementType;
        }
    }

    public Expression Expression
    {
        get
        {
            return BaseQueryable.Expression;
        }
    }

    public IQueryProvider Provider
    {
        get
        {
            return ProxyDbProvider;
        }
    }

    public override string ToString()
    {
        return BaseQueryable.ToString();
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        // Async enumeration is only available when the wrapped query supports it.
        IDbAsyncEnumerable asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }
}
/// <summary>
/// Generic queryable wrapper: exposes the wrapped query's element type and expression,
/// but reports a <see cref="ProxyDbProvider"/> as its provider. Enumeration is delegated
/// directly to the wrapped queryable.
/// </summary>
/// <typeparam name="TElement">Element type of the wrapped query.</typeparam>
public class ProxyQueryable<TElement> : IOrderedQueryable<TElement>, IQueryable<TElement>, IEnumerable<TElement>, IDbAsyncEnumerable<TElement>, IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable<TElement> BaseQueryable;

    /// <param name="proxyDbProvider">Provider returned by <see cref="Provider"/>.</param>
    /// <param name="baseQueryable">Query that supplies the expression, element type and results.</param>
    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable<TElement> baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    public IEnumerator<TElement> GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        IEnumerable untyped = BaseQueryable;
        return untyped.GetEnumerator();
    }

    public Type ElementType
    {
        get
        {
            return BaseQueryable.ElementType;
        }
    }

    public Expression Expression
    {
        get
        {
            return BaseQueryable.Expression;
        }
    }

    public IQueryProvider Provider
    {
        get
        {
            return ProxyDbProvider;
        }
    }

    public override string ToString()
    {
        return BaseQueryable.ToString();
    }

    public IDbAsyncEnumerator<TElement> GetAsyncEnumerator()
    {
        // Async enumeration is only available when the wrapped query supports it.
        IDbAsyncEnumerable<TElement> asyncSource = BaseQueryable as IDbAsyncEnumerable<TElement>;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        // Same check as the generic overload, against the non-generic async interface.
        IDbAsyncEnumerable asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }
}
}
An example of a manipulator of Expressions (this one will transform .Where(x => something) into .Where(x => something && something)):
namespace My
{
using System.Linq;
using System.Linq.Expressions;
/// <summary>
/// Sample visitor that turns every <c>Queryable.Where(x => something)</c> into
/// <c>Queryable.Where(x => something &amp;&amp; something)</c>.
/// </summary>
public class MyExpressionManipulator : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        LambdaExpression predicate = GetWherePredicate(node);

        // Important: the full expression is visited again at each step, so avoid
        // replacing the same predicate twice. With a query like
        //   ctx.Where(x => true).Where(x => true).Select(x => 1)
        // the first visit sees
        //   ctx.Where(x => true)
        // and produces
        //   ctx.Where(x => true && true)
        // the second visit sees
        //   ctx.Where(x => true && true).Where(x => true)
        // and must produce
        //   ctx.Where(x => true && true).Where(x => true && true)
        // and not
        //   ctx.Where(x => (true && true) && (true && true)).Where(x => true && true)
        // A body that is already an AndAlso is therefore left alone.
        if (predicate != null && predicate.Body.NodeType != ExpressionType.AndAlso)
        {
            Expression doubledBody = Expression.AndAlso(predicate.Body, predicate.Body);
            LambdaExpression doubledPredicate = Expression.Lambda(doubledBody, predicate.Parameters);

            var newArguments = new Expression[node.Arguments.Count];
            node.Arguments.CopyTo(newArguments, 0);
            newArguments[1] = Expression.Quote(doubledPredicate);

            node = Expression.Call(node.Object, node.Method, newArguments);
        }

        return base.VisitMethodCall(node);
    }

    // Returns the quoted lambda of a two-argument Queryable.Where call, or null for any other call.
    private static LambdaExpression GetWherePredicate(MethodCallExpression call)
    {
        if (call.Method.DeclaringType != typeof(Queryable))
        {
            return null;
        }

        if (call.Method.Name != "Where" || call.Arguments.Count != 2)
        {
            return null;
        }

        Expression second = call.Arguments[1];
        if (second.NodeType != ExpressionType.Quote)
        {
            return null;
        }

        Expression operand = ((UnaryExpression)second).Operand; // inside Expression.Quote
        if (operand.NodeType != ExpressionType.Lambda)
        {
            return null;
        }

        return (LambdaExpression)operand;
    }
}
}
Now... How to use it? The best way is to derive your context (in this case Model1) not from DbContext but from ProxyDbContext, like this:
// Sample context derived from ProxyDbContext; every query is passed to Manipulate.
public partial class Model1 : ProxyDbContext
{
public Model1()
: base("name=Model1", Manipulate)
{
}
/// <summary>
/// Rewrites an expression with a new MyExpressionManipulator.
/// </summary>
/// <param name="executing">true: the returned Expression will be executed directly, false: the returned expression will be returned as IQueryable{T}.</param>
/// <param name="expression">The expression to rewrite.</param>
/// <returns>The expression produced by the visitor.</returns>
private static Expression Manipulate(bool executing, Expression expression)
{
// See the annotation about re-executing the same visitor
// multiple times in MyExpressionManipulator().Visit .
// By running the visitor only when executing == true,
// and simply doing return expression; when executing == false,
// you have the guarantee that an expression won't be
// manipulated multiple times.
// As written now, the expression will be manipulated
// multiple times.
return new MyExpressionManipulator().Visit(expression);
}
// Some tables
public virtual DbSet<Parent> Parent { get; set; }
public virtual IDbSet<Child> Child { get; set; }
Then it is very transparent:
// Usage when Model1 derives from ProxyDbContext (class Model1 : ProxyDbContext {}).
using (var ctx = new Model1())
{
// Your query, written as usual
var res = ctx.Parent.Where(x => x.Id > 100);
// The query is automatically manipulated by your Manipulate method
}
Another way to do this, without subclassing from ProxyDbContext:
// Where Model1 derives directly from DbContext: class Model1 : DbContext {}
using (var ctx = new Model1())
{
    // ProxyDbSet expects a Func<bool, Expression, Expression> (first parameter: true for
    // Execute, false for CreateQuery), so adapt the visitor's one-argument Visit method.
    Func<bool, Expression, Expression> manipulator = (executing, expression) => new MyExpressionManipulator().Visit(expression);

    ctx.Parent = new ProxyDbSet<Parent>(ctx.Parent, manipulator);
    // Child is declared as IDbSet<Child>, while the ProxyDbSet constructor takes a DbSet<Child>.
    ctx.Child = new ProxyDbSet<Child>((DbSet<Child>)ctx.Child, manipulator);

    // Your query
    var res = ctx.Parent.Where(x => x.Id > 100);
}
The ProxyDbContext<> replaces the DbSet<>/IDbSet<> that are present in your context with some ProxyDbSet<>.
In the second example this action is done explicitly, but note that you can create a method to do it, or create a factory for your context (a static method that returns a context with the various DbSet<> "proxied"), or you could put the proxification in the constructor of your context (because the "original" initialization of the DbSet<> happens in the constructor of DbContext, and the body of the constructor of your context is executed after this), or you could create multiple subclasses of your context, each with a constructor that proxifies in a different way...
Note that the first method (subclassing ProxyDbContext<>) "fixes" the Set<>/Set methods that otherwise you'll have to fix yourself by copying the code of the overloads of these two methods from ProxyDbContext<>.
Related
How to correct replace type of Expression?
I have two classes: public class DalMembershipUser { public string UserName { get; set; } //other members } public class MembershipUser { public string UserName { get; set; } //other members } I have function: public IEnumerable<DalMembershipUser> GetMany(Expression<Func<DalMembershipUser, bool>> predicate) { //but here i can use only Func<MembershipUser, bool> //so i made transformation query = query.Where(ExpressionTransformer<DalMembershipUser,MembershipUser>.Tranform(predicate)); } Current implementation: public static class ExpressionTransformer<TFrom, TTo> { public class Visitor : ExpressionVisitor { private ParameterExpression _targetParameterExpression; public Visitor(ParameterExpression parameter) { _targetParameterExpression = parameter; } protected override Expression VisitParameter(ParameterExpression node) { return _targetParameterExpression; } } public static Expression<Func<TTo, bool>> Tranform(Expression<Func<TFrom, bool>> expression) { ParameterExpression parameter = Expression.Parameter(typeof(TTo), expression.Parameters[0].Name); Expression body = expression.Body; new Visitor(parameter).Visit(expression.Body); return Expression.Lambda<Func<TTo, bool>>(body, parameter); } } //Somewhere: .GetMany(u => u.UserName == "username"); Exception: Property 'System.String UserName' is not defined for type 'MembershipUser' at line: new Visitor(parameter).Visit(expression.Body);
Finally it work. But still don't understand why clear parameters creation: Expression.Parameter(typeof(TTo), from.Parameters[i].Name); not work and need to extract. public static class ExpressionHelper { public static Expression<Func<TTo, bool>> TypeConvert<TFrom, TTo>( this Expression<Func<TFrom, bool>> from) { if (from == null) return null; return ConvertImpl<Func<TFrom, bool>, Func<TTo, bool>>(from); } private static Expression<TTo> ConvertImpl<TFrom, TTo>(Expression<TFrom> from) where TFrom : class where TTo : class { // figure out which types are different in the function-signature var fromTypes = from.Type.GetGenericArguments(); var toTypes = typeof(TTo).GetGenericArguments(); if (fromTypes.Length != toTypes.Length) throw new NotSupportedException("Incompatible lambda function-type signatures"); Dictionary<Type, Type> typeMap = new Dictionary<Type, Type>(); for (int i = 0; i < fromTypes.Length; i++) { if (fromTypes[i] != toTypes[i]) typeMap[fromTypes[i]] = toTypes[i]; } // re-map all parameters that involve different types Dictionary<Expression, Expression> parameterMap = new Dictionary<Expression, Expression>(); ParameterExpression[] newParams = GenerateParameterMap<TFrom>(from, typeMap, parameterMap); // rebuild the lambda var body = new TypeConversionVisitor<TTo>(parameterMap).Visit(from.Body); return Expression.Lambda<TTo>(body, newParams); } private static ParameterExpression[] GenerateParameterMap<TFrom>( Expression<TFrom> from, Dictionary<Type, Type> typeMap, Dictionary<Expression, Expression> parameterMap ) where TFrom : class { var newParams = new ParameterExpression[from.Parameters.Count]; for (int i = 0; i < newParams.Length; i++) { Type newType; if (typeMap.TryGetValue(from.Parameters[i].Type, out newType)) { parameterMap[from.Parameters[i]] = newParams[i] = Expression.Parameter(newType, from.Parameters[i].Name); } } return newParams; } class TypeConversionVisitor<T> : ExpressionVisitor { private readonly Dictionary<Expression, Expression> 
parameterMap; public TypeConversionVisitor(Dictionary<Expression, Expression> parameterMap) { this.parameterMap = parameterMap; } protected override Expression VisitParameter(ParameterExpression node) { // re-map the parameter Expression found; if (!parameterMap.TryGetValue(node, out found)) found = base.VisitParameter(node); return found; } public override Expression Visit(Expression node) { LambdaExpression lambda = node as LambdaExpression; if (lambda != null && !parameterMap.ContainsKey(lambda.Parameters.First())) { return new TypeConversionVisitor<T>(parameterMap).Visit(lambda.Body); } return base.Visit(node); } protected override Expression VisitMember(MemberExpression node) { // re-perform any member-binding var expr = Visit(node.Expression); if (expr.Type != node.Type) { if (expr.Type.GetMember(node.Member.Name).Any()) { MemberInfo newMember = expr.Type.GetMember(node.Member.Name).Single(); return Expression.MakeMemberAccess(expr, newMember); } } return base.VisitMember(node); } } }
You need to use the result expression returned by the Visit method. Just replace: Expression body = expression.Body; new Visitor(parameter).Visit(expression.Body); with: Expression body = new Visitor(parameter).Visit(expression.Body);
ObjectSet wrapper not working with linqToEntities subquery
for access control purposes in a intensive DB use system I had to implement an objectset wrapper, where the AC will be checked. The main objective is make this change preserving the existing code for database access, that is implemented with linq to entities all over the classes (there is no centralized layer for database). The ObjectSetWrapper created is like that: public class ObjectSetWrapper<TEntity> : IQueryable<TEntity> where TEntity : EntityObject { private IQueryable<TEntity> QueryableModel; private ObjectSet<TEntity> ObjectSet; public ObjectSetWrapper(ObjectSet<TEntity> objectSetModels) { this.QueryableModel = objectSetModels; this.ObjectSet = objectSetModels; } public ObjectQuery<TEntity> Include(string path) { return this.ObjectSet.Include(path); } public void DeleteObject(TEntity #object) { this.ObjectSet.DeleteObject(#object); } public void AddObject(TEntity #object) { this.ObjectSet.AddObject(#object); } public IEnumerator<TEntity> GetEnumerator() { return QueryableModel.GetEnumerator(); } public Type ElementType { get { return typeof(TEntity); } } public System.Linq.Expressions.Expression Expression { get { return this.QueryableModel.Expression; } } public IQueryProvider Provider { get { return this.QueryableModel.Provider; } } public void Attach(TEntity entity) { this.ObjectSet.Attach(entity); } public void Detach(TEntity entity) { this.ObjectSet.Detach(entity); } IEnumerator IEnumerable.GetEnumerator() { return this.QueryableModel.GetEnumerator(); } } It's really simple and works for simple queries, like that: //db.Product is ObjectSetWrapper<Product> var query = (from item in db.Product where item.Quantity > 0 select new { item.Id, item.Name, item.Value }); var itensList = query.Take(10).ToList(); But when I have subqueries like that: //db.Product is ObjectSetWrapper<Product> var query = (from item in db.Product select new { Id = item.Id, Name = item.Name, SalesQuantity = (from sale in db.Sale where sale.ProductId == item.Id select 
sale.Id).Count() }).OrderByDescending(x => x.SalesQuantity); var productsList = query.Take(10).ToList(); I get NotSupportedException, saying I can't create a constant value of my inner query entity type: Unable to create a constant value of type 'MyNamespace.Model.Sale'. Only primitive types or enumeration types are supported in this context. How can I get my queries working? I don't really need to make my wrapper an ObjectSet type, I just need to use it in queries. Updated I have changed my class signature. Now it's also implementing IObjectSet<>, but I'm getting the same NotSupportedException: public class ObjectSetWrapper<TEntity> : IQueryable<TEntity>, IObjectSet<TEntity> where TEntity : EntityObject
EDIT: The problem is that the following LINQ construction is translated into LINQ expression containing your custom class inside (ObjectSetWrapper). var query = (from item in db.Product select new { Id = item.Id, Name = item.Name, SalesQuantity = (from sale in db.Sale where sale.ProductId == item.Id select sale.Id).Count() }).OrderByDescending(x => x.SalesQuantity); LINQ to Entities tries to convert this expression into SQL statement, but it has no idea how to deal with the custom classes (as well as custom methods). The solution in such cases is to replace IQueryProvider with the custom one, which should intercept the query execution and translate LINQ expression, containing custom classes/methods into valid LINQ to Entities expression (which operates with entities and object sets). Expression conversion is performed using the class, derived from ExpressionVisitor, which performs expression tree traversal, replacing relevant nodes, to the nodes which can be accepted by LINQ to Entities Part 1 - IQueryWrapper // Query wrapper interface - holds and underlying query interface IQueryWrapper { IQueryable UnderlyingQueryable { get; } } Part 2 - Abstract QueryWrapperBase (not generic) abstract class QueryWrapperBase : IQueryProvider, IQueryWrapper { public IQueryable UnderlyingQueryable { get; private set; } class ObjectWrapperReplacer : ExpressionVisitor { public override Expression Visit(Expression node) { if (node == null || !typeof(IQueryWrapper).IsAssignableFrom(node.Type)) return base.Visit(node); var wrapper = EvaluateExpression<IQueryWrapper>(node); return Expression.Constant(wrapper.UnderlyingQueryable); } public static Expression FixExpression(Expression expression) { var replacer = new ObjectWrapperReplacer(); return replacer.Visit(expression); } private T EvaluateExpression<T>(Expression expression) { if (expression is ConstantExpression) return (T)((ConstantExpression)expression).Value; var lambda = Expression.Lambda(expression); return 
(T)lambda.Compile().DynamicInvoke(); } } protected QueryWrapperBase(IQueryable underlyingQueryable) { UnderlyingQueryable = underlyingQueryable; } public abstract IQueryable<TElement> CreateQuery<TElement>(Expression expression); public abstract IQueryable CreateQuery(Expression expression); public TResult Execute<TResult>(Expression expression) { return (TResult)Execute(expression); } public object Execute(Expression expression) { expression = ObjectWrapperReplacer.FixExpression(expression); return typeof(IQueryable).IsAssignableFrom(expression.Type) ? ExecuteQueryable(expression) : ExecuteNonQueryable(expression); } protected object ExecuteNonQueryable(Expression expression) { return UnderlyingQueryable.Provider.Execute(expression); } protected IQueryable ExecuteQueryable(Expression expression) { return UnderlyingQueryable.Provider.CreateQuery(expression); } } Part 3 - Generic QueryWrapper<TElement> class QueryWrapper<TElement> : QueryWrapperBase, IOrderedQueryable<TElement> { private static readonly MethodInfo MethodCreateQueryDef = GetMethodDefinition(q => q.CreateQuery<object>(null)); public QueryWrapper(IQueryable<TElement> underlyingQueryable) : this(null, underlyingQueryable) { } protected QueryWrapper(Expression expression, IQueryable underlyingQueryable) : base(underlyingQueryable) { Expression = expression ?? 
Expression.Constant(this); } public virtual IEnumerator<TElement> GetEnumerator() { return ((IEnumerable<TElement>)Execute<IEnumerable>(Expression)).GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); } public Expression Expression { get; private set; } public Type ElementType { get { return typeof(TElement); } } public IQueryProvider Provider { get { return this; } } public override IQueryable CreateQuery(Expression expression) { var expressionType = expression.Type; var elementType = expressionType .GetInterfaces() .Single(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>)) .GetGenericArguments() .Single(); var createQueryMethod = MethodCreateQueryDef.MakeGenericMethod(elementType); return (IQueryable)createQueryMethod.Invoke(this, new object[] { expression }); } public override IQueryable<TNewElement> CreateQuery<TNewElement>(Expression expression) { return new QueryWrapper<TNewElement>(expression, UnderlyingQueryable); } private static MethodInfo GetMethodDefinition(Expression<Action<QueryWrapper<TElement>>> methodSelector) { var methodCallExpression = (MethodCallExpression)methodSelector.Body; return methodCallExpression.Method.GetGenericMethodDefinition(); } } Part 4 - finally your ObjectSetWrapper public class ObjectSetWrapper<TEntity> : IQueryable<TEntity>, IQueryWrapper where TEntity : class { private IQueryable<TEntity> QueryableModel; private ObjectSet<TEntity> ObjectSet; public ObjectSetWrapper(ObjectSet<TEntity> objectSetModels) { this.QueryableModel = new QueryWrapper<TEntity>(objectSetModels); this.ObjectSet = objectSetModels; } public ObjectQuery<TEntity> Include(string path) { return this.ObjectSet.Include(path); } public void DeleteObject(TEntity #object) { this.ObjectSet.DeleteObject(#object); } public void AddObject(TEntity #object) { this.ObjectSet.AddObject(#object); } public IEnumerator<TEntity> GetEnumerator() { return QueryableModel.GetEnumerator(); } public Type ElementType { get 
{ return typeof(TEntity); } } public System.Linq.Expressions.Expression Expression { get { return this.QueryableModel.Expression; } } public IQueryProvider Provider { get { return this.QueryableModel.Provider; } } public void Attach(TEntity entity) { this.ObjectSet.Attach(entity); } public void Detach(TEntity entity) { this.ObjectSet.Detach(entity); } IEnumerator IEnumerable.GetEnumerator() { return this.QueryableModel.GetEnumerator(); } IQueryable IQueryWrapper.UnderlyingQueryable { get { return this.ObjectSet; } } }
Your inner query fails because you are referencing another dataset when you should be traversing foreign keys: SalesQuantity = item.Sales.Count()
Making a custom class IQueryable
I have been working with the TFS API for VS2010 and had to query FieldCollection which I found isn't supported by LINQ so I wanted to create a custom class to make the Field and FieldCollection queryable by LINQ so I found a basic template and tried to implement it using System; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using Microsoft.TeamFoundation.WorkItemTracking.Client; public class WorkItemFieldCollection : IQueryable<Field>, IQueryProvider { private List<Field> _fieldList = new List<Field>(); #region Constructors /// <summary> /// This constructor is called by the client to create the data source. /// </summary> public WorkItemFieldCollection(FieldCollection fieldCollection) { foreach (Field field in fieldCollection) { _fieldList.Add(field); } } #endregion Constructors #region IQueryable Members Type IQueryable.ElementType { get { return typeof(Field); } } System.Linq.Expressions.Expression IQueryable.Expression { get { return Expression.Constant(this); } } IQueryProvider IQueryable.Provider { get { return this; } } #endregion IQueryable Members #region IEnumerable<Field> Members IEnumerator<Field> IEnumerable<Field>.GetEnumerator() { return (this as IQueryable).Provider.Execute<IEnumerator<Field>>(_expression); } private IList<Field> _field = new List<Field>(); private Expression _expression = null; #endregion IEnumerable<Field> Members #region IEnumerable Members System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { return (IEnumerator<Field>)(this as IQueryable).GetEnumerator(); } private void ProcessExpression(Expression expression) { if (expression.NodeType == ExpressionType.Equal) { ProcessEqualResult((BinaryExpression)expression); } if (expression is UnaryExpression) { UnaryExpression uExp = expression as UnaryExpression; ProcessExpression(uExp.Operand); } else if (expression is LambdaExpression) { ProcessExpression(((LambdaExpression)expression).Body); } else if (expression is 
ParameterExpression) { if (((ParameterExpression)expression).Type == typeof(Field)) { _field = GetFields(); } } } private void ProcessEqualResult(BinaryExpression expression) { if (expression.Right.NodeType == ExpressionType.Constant) { string name = (String)((ConstantExpression)expression.Right).Value; ProceesItem(name); } } private void ProceesItem(string name) { IList<Field> filtered = new List<Field>(); foreach (Field field in GetFields()) { if (string.Compare(field.Name, name, true) == 0) { filtered.Add(field); } } _field = filtered; } private object GetValue(BinaryExpression expression) { if (expression.Right.NodeType == ExpressionType.Constant) { return ((ConstantExpression)expression.Right).Value; } return null; } private IList<Field> GetFields() { return _fieldList; } #endregion IEnumerable Members #region IQueryProvider Members IQueryable<S> IQueryProvider.CreateQuery<S>(System.Linq.Expressions.Expression expression) { if (typeof(S) != typeof(Field)) throw new Exception("Only " + typeof(Field).FullName + " objects are supported."); this._expression = expression; return (IQueryable<S>)this; } IQueryable IQueryProvider.CreateQuery(System.Linq.Expressions.Expression expression) { return (IQueryable<Field>)(this as IQueryProvider).CreateQuery<Field>(expression); } TResult IQueryProvider.Execute<TResult>(System.Linq.Expressions.Expression expression) { MethodCallExpression methodcall = _expression as MethodCallExpression; foreach (var param in methodcall.Arguments) { ProcessExpression(param); } return (TResult)_field.GetEnumerator(); } object IQueryProvider.Execute(System.Linq.Expressions.Expression expression) { return (this as IQueryProvider).Execute<IEnumerator<Field>>(expression); } #endregion IQueryProvider Members } It appeared to compile and was recognized by LINQ but i keep getting an error in the CreateQuery method because it passes in string and not a field IQueryable<S> IQueryProvider.CreateQuery<S>(System.Linq.Expressions.Expression expression) { 
if (typeof(S) != typeof(Field)) throw new Exception("Only " + typeof(Field).FullName + " objects are supported."); this._expression = expression; return (IQueryable<S>)this; } Here is the LINQ query I use. columnFilterList is a List and fields is my custom FieldCollection class (see above). foreach (var name in columnFilterList) { var fieldName = (from x in fields where x.Name == name select x.Name).First } I am sure it is a simple mistake. Could someone tell me what I am doing wrong? Thanks.
If you want an object to be usable by LINQ, implement IEnumerable<T>. IQueryable<T> is overkill for LINQ to Objects. It is designed for converting the expressions into another form. Or if you want, you can do this FieldCollection someFieldCollection = ... IEnumerable<Field> fields = someFieldCollections.Cast<Field>();
In your case of wrapping and building a class around an existing IEnumerable Collection type i.e. List<Field>, you might just use a set of "forward function" wrappers that expose the IQueryable<Field> interface: public class WorkItemFieldCollection : IEnumerable<Field>, IQueryable<Field> { ... #region Implementation of IQueryable<Field> public Type ElementType { get { return this._fieldList.AsQueryable().ElementType; } } public Expression Expression { get { return this._fieldList.AsQueryable().Expression; } } public IQueryProvider Provider { get { return this._fieldList.AsQueryable().Provider; } } #endregion ... } You can now directly query your WorkItemFieldCollection: var workItemFieldCollection = new WorkItemFieldCollection(...); var Fuzz = "someStringId"; var workFieldItem = workItemFieldCollection.FirstOrDefault( c => c.Buzz == Fuzz );
Best way to concat strings and numbers in SQL server using Entity Framework 5?
For some reason Microsoft decided to not support simple concat in EF5. e.g. Select(foo => new { someProp = "hello" + foo.id + "/" + foo.bar } This will throw if foo.id or foo.bar are numbers. The workaround I've found is apparently this pretty piece of code: Select(foo => new { someProp = "hello" + SqlFunctions.StringConvert((double?)foo.id).Trim() + "/" + SqlFunctions.StringConvert((double?)foo.bar).Trim() } Which works fine, but is just horrid to look at. So, is there some decent way to accomplish this with cleaner code? I'm NOT interested in doing this client side, so no .AsEnumerable() answers please.
For those interested. I got so pissed with the lack of this feature that I implemented it myself using an ExpressionVisitor. You can now write code like the one in the original question. using System; using System.Collections; using System.Collections.Generic; using System.Data.Objects.SqlClient; using System.Linq; using System.Linq.Expressions; namespace Crawlr.Web.Code { public static class ObjectSetExExtensions { public static ObjectSetEx<T> Extend<T>(this IQueryable<T> self) where T : class { return new ObjectSetEx<T>(self); } } public class ObjectSetEx<T> : IOrderedQueryable<T> { private readonly QueryProviderEx provider; private readonly IQueryable<T> source; public ObjectSetEx(IQueryable<T> source) { this.source = source; provider = new QueryProviderEx(this.source.Provider); } #region IQueryableEx<T> Members public IEnumerator<T> GetEnumerator() { return source.GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() { return source.GetEnumerator(); } public Type ElementType { get { return source.ElementType; } } public Expression Expression { get { return source.Expression; } } public IQueryProvider Provider { get { return provider; } } #endregion } public class QueryProviderEx : IQueryProvider { private readonly IQueryProvider source; public QueryProviderEx(IQueryProvider source) { this.source = source; } #region IQueryProvider Members public IQueryable<TElement> CreateQuery<TElement>(Expression expression) { Expression newExpression = ExpressionReWriterVisitor.Default.Visit(expression); IQueryable<TElement> query = source.CreateQuery<TElement>(newExpression); return new ObjectSetEx<TElement>(query); } public IQueryable CreateQuery(Expression expression) { Expression newExpression = ExpressionReWriterVisitor.Default.Visit(expression); IQueryable query = source.CreateQuery(newExpression); return query; } public TResult Execute<TResult>(Expression expression) { Expression newExpression = ExpressionReWriterVisitor.Default.Visit(expression); return 
source.Execute<TResult>(newExpression); } public object Execute(Expression expression) { Expression newExpression = ExpressionReWriterVisitor.Default.Visit(expression); return source.Execute(newExpression); } #endregion } public class ExpressionReWriterVisitor : ExpressionVisitor { public static readonly ExpressionReWriterVisitor Default = new ExpressionReWriterVisitor(); protected override Expression VisitUnary(UnaryExpression node) { if (node.NodeType == ExpressionType.Convert && node.Operand.Type == typeof(int) && node.Type==typeof(object)) { var operand = node.Operand; var stringConvertMethod = typeof(SqlFunctions).GetMethod("StringConvert", new Type[] { typeof(double?) }); var trimMethod = typeof(string).GetMethod("Trim",new Type[] {}); var dOperand = Expression.Convert(operand, typeof(double?)); return Expression.Call(Expression.Call(stringConvertMethod, dOperand),trimMethod); } return base.VisitUnary(node); } } } usage: var res = model .FooSet .Extend() //<- applies the magic .Select(foo => new { someProp = "hello" + foo.id + "/" + foo.bar }
When using a repository is it possible for a type to return a Func that the repository uses to test for existing entities?
For example given a Factory with a method public static T Save<T>(T item) where T : Base, new() { /* item.Id == Guid.Empty therefore item is new */ if (item.Id == Guid.Empty && repository.GetAll<T>(t => t.Name == item.Name)) { throw new Exception("Name is not unique"); } } how do I create a property of Base (say MustNotAlreadyExist) so that I can change the method above to public static T Save<T>(T item) where T : Base, new() { /* item.Id == Guid.Empty therefore item is new */ if (item.Id == Guid.Empty && repository.GetAll<T>(t.MustNotAlreadyExist)) { throw new Exception("Name is not unique"); } } public class Base { ... public virtual Expression<Func<T, bool>> MustNotAlreadyExist() { return (b => b.Name == name); /* <- this clearly doesn't work */ } } and then how can I override MustNotAlreadyExist in Account : Base public class Account : Base { ... public override Expression<Func<T, bool>> MustNotAlreadyExist() { return (b => b.Name == name && b.AccountCode == accountCode); /* <- this doesn't work */ } ... }
Try this: public class Account : Base { ... public override Expression<Func<T, bool>> MustNotAlreadyExist() { return (b => b.Name == name && b.AccountCode == accountCode).Any(); } ... } The Any() method will return true if any record matches the predicate. It could be argued that it is outside the responsibility of the repository to check for presence of a record before saving. UPDATE: There is a great article on CodeProject that describes a generic Repository for Entity Framework: http://www.codeproject.com/KB/database/ImplRepositoryPatternEF.aspx This could be applied to a non-Entity Framework data context. Here is an excerpt that provides a very flexible method for checking for an existing value by accepting the name of a field, a value, and a key value. You can apply this to any Entity type and use it to check for the presence of an entity before attempting a save. /// <summary> /// Check if value of specific field is already exist /// </summary> /// <typeparam name="E"></typeparam> /// <param name="fieldName">name of the Field</param> /// <param name="fieldValue">Field value</param> /// <param name="key">Primary key value</param> /// <returns>True or False</returns> public bool TrySameValueExist(string fieldName, object fieldValue, string key) { // First we define the parameter that we are going to use the clause. 
var xParam = Expression.Parameter(typeof(E), typeof(E).Name); MemberExpression leftExprFieldCheck = MemberExpression.Property(xParam, fieldName); Expression rightExprFieldCheck = Expression.Constant(fieldValue); BinaryExpression binaryExprFieldCheck = MemberExpression.Equal(leftExprFieldCheck, rightExprFieldCheck); MemberExpression leftExprKeyCheck = MemberExpression.Property(xParam, this._KeyProperty); Expression rightExprKeyCheck = Expression.Constant(key); BinaryExpression binaryExprKeyCheck = MemberExpression.NotEqual(leftExprKeyCheck, rightExprKeyCheck); BinaryExpression finalBinaryExpr = Expression.And(binaryExprFieldCheck, binaryExprKeyCheck); //Create Lambda Expression for the selection Expression<Func<E, bool>> lambdaExpr = Expression.Lambda<Func<E, bool>>(finalBinaryExpr, new ParameterExpression[] { xParam }); //Searching .... return ((IRepository<E, C>)this).TryEntity(new Specification<E>(lambdaExpr)); } /// <summary> /// Check if Entities exist with Condition /// </summary> /// <param name="selectExpression">Selection Condition</param> /// <returns>True or False</returns> public bool TryEntity(ISpecification<E> selectSpec) { return _ctx.CreateQuery<E>("[" + typeof(E).Name + "]").Any<E> (selectSpec.EvalPredicate); }
I am not sure if your problem is solvable, since you need to access both the repository and the new item to be checked. The new item to be checked is not available in a separate method. However, you can outsource the call to GetAll so that your code becomes something similar to (not tested) public static T Save<T>(T item) where T : Base, new() { if (item.Id == Guid.Empty && (Check(repository, item))) { throw new Exception("Name is not unique"); } } public class Base { ... public Func<Enumerable<T>, T, bool> Check { get; set;} public Base() { Check = (col, newItem) => (null != col.FirstOrDefault<T>( item => item.Name == newItem.Name)); } }
OK, here is the answer, this is a combination of the code posted by Dave Swersky and a little but of common sense. public interface IUniqueable<T> { Expression<Func<T, bool>> Unique { get; } } public class Base, IUniqueable<Base> { ... public Expression<Func<Base, bool>> Unique { get { var xParam = Expression.Parameter(typeof(Base), typeof(Base).Name); MemberExpression leftExprFieldCheck = MemberExpression.Property(xParam, "Name"); Expression rightExprFieldCheck = Expression.Constant(this.Name); BinaryExpression binaryExprFieldCheck = MemberExpression.Equal(leftExprFieldCheck, rightExprFieldCheck); return Expression.Lambda<Func<Base, bool>>(binaryExprFieldCheck, new ParameterExpression[] { xParam }); } } ... } public class Account : Base, IUniqueable<Account> { ... public new Expression<Func<Account, bool>> Unique { get { var xParam = Expression.Parameter(typeof(Account), typeof(Account).Name); MemberExpression leftExprNameCheck = MemberExpression.Property(xParam, "Name"); Expression rightExprNameCheck = Expression.Constant(this.Name); BinaryExpression binaryExprNameCheck = MemberExpression.Equal(leftExprNameCheck, rightExprNameCheck); MemberExpression leftExprFieldCheck = MemberExpression.Property(xParam, "AccountCode"); Expression rightExprFieldCheck = Expression.Constant(this.AccountCode); BinaryExpression binaryExprFieldCheck = MemberExpression.Equal(leftExprFieldCheck, rightExprFieldCheck); BinaryExpression binaryExprAllCheck = Expression.OrElse(binaryExprNameCheck, binaryExprFieldCheck); return Expression.Lambda<Func<Account, bool>>(binaryExprAllCheck, new ParameterExpression[] { xParam }); } } ... } public static class Manager { ... 
public static T Save<T>(T item) where T : Base, new() { if (!item.IsValid) { throw new ValidationException("Unable to save item, item is not valid", item.GetRuleViolations()); } if (item.Id == Guid.Empty && repository.GetAll<T>().Any(((IUniqueable<T>)item).Unique)) { throw new Exception("Item is not unique"); } return repository.Save<T>(item); } ... } Essentially by implementing the IUniqueable interface for a specific type I can return a different Expression for each type. All good :-)