I have this static class:
namespace Dapper
{
    /// <summary>
    /// Excerpt of the static extension-method class; only two of its many members are shown.
    /// </summary>
    public static class SqlMapper
    {
        /// <summary>
        /// Packs the arguments into a <see cref="CommandDefinition"/> and hands it to
        /// <c>ExecuteScalarImpl</c>.
        /// </summary>
        public static T ExecuteScalar<T>(this IDbConnection cnn, string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null)
        {
            var definition = new CommandDefinition(
                sql,
                param,
                transaction,
                commandTimeout,
                commandType,
                CommandFlags.Buffered,
                default(CancellationToken));

            return ExecuteScalarImpl<T>(cnn, ref definition);
        }

        /// <summary>
        /// Packs the arguments into a <see cref="CommandDefinition"/> and forwards to the
        /// <c>ExecuteAsync</c> overload that accepts one.
        /// </summary>
        public static Task<int> ExecuteAsync(this IDbConnection cnn, string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null)
        {
            var definition = new CommandDefinition(
                sql,
                param,
                transaction,
                commandTimeout,
                commandType,
                CommandFlags.Buffered,
                default(CancellationToken));

            return cnn.ExecuteAsync(definition);
        }
        //...
        //Goes on for another 200 different static functions
    }
}
I want to create a wrapper class that holds a default value for commandTimeout. I don't want it as a global parameter; I want this class to be built in the bootstrapper with this value:
using System.Threading.Tasks;
using Dapper;
/// <summary>
/// Instance wrapper around Dapper's extension methods that substitutes a configured
/// default command timeout whenever the caller does not supply one, and logs failures
/// before rethrowing them.
/// </summary>
public class SqlWrapper : ISqlWrapper
{
    private readonly ILogger _logger;
    private readonly int _commandTimeoutInSec;

    /// <param name="logger">Receives an error entry for every failed command.</param>
    /// <param name="commandTimeoutInSec">Timeout used when a call passes no <c>commandTimeout</c>.</param>
    public SqlWrapper(ILogger logger, int commandTimeoutInSec)
    {
        _logger = logger;
        _commandTimeoutInSec = commandTimeoutInSec;
    }

    /// <summary>
    /// Forwards to Dapper's <c>ExecuteScalarAsync&lt;T&gt;</c> on the supplied connection.
    /// </summary>
    /// <param name="cnn">Connection owned by the caller; it is used as-is and not disposed here.</param>
    /// <param name="commandTimeout">Per-call timeout; falls back to the configured default when null.</param>
    /// <returns>The value produced by the underlying Dapper call.</returns>
    public async Task<T> ExecuteScalarAsync<T>(IDbConnection cnn, string sql, object param = null,
        int? commandTimeout = null, CommandType? commandType = null)
    {
        try
        {
            // The connection is passed in by the caller, so it is neither wrapped in a new
            // SqlConnection nor disposed: its lifetime is not this class's to manage.
            return await cnn.ExecuteScalarAsync<T>(
                sql,
                param,
                commandTimeout: commandTimeout ?? _commandTimeoutInSec,
                commandType: commandType);
        }
        catch (Exception ex)
        {
            _logger.WriteError(
                $"Job execute failed with error:", ex);
            throw;
        }
    }

    /// <summary>
    /// Forwards to Dapper's <c>ExecuteAsync</c> on the supplied connection.
    /// </summary>
    /// <param name="cnn">Connection owned by the caller; it is used as-is and not disposed here.</param>
    /// <param name="commandTimeout">Per-call timeout; falls back to the configured default when null.</param>
    /// <returns>The integer produced by the underlying Dapper call.</returns>
    public async Task<int> ExecuteAsync(IDbConnection cnn, string sql, object param = null, int? commandTimeout = null, CommandType? commandType = null)
    {
        try
        {
            return await cnn.ExecuteAsync(
                sql,
                param,
                commandTimeout: commandTimeout ?? _commandTimeoutInSec,
                commandType: commandType);
        }
        catch (Exception ex)
        {
            _logger.WriteError(
                $"Job execute failed with error:", ex);
            throw;
        }
    }
    //...
    //Goes on for the rest 200 different static functions in SqlMapper
}
The thing is, I feel it's silly to implement a wrapper for 200 functions just to pass a default parameter. Is there a way to receive the name of the function as a template argument or a parameter, and then pass the call on to the function of the same name in SqlMapper?
You may apply the following technique to your scenario. Imagine we had this static class:
/// <summary>
/// Console helpers that write a line of text a given number of times.
/// </summary>
public static class Printer
{
    /// <summary>Writes <paramref name="output"/> to the console <paramref name="numberOfTimes"/> times.</summary>
    public static void Print(string output, int numberOfTimes)
    {
        WriteRepeated(output, numberOfTimes);
    }

    /// <summary>Writes <paramref name="output"/> to the console <paramref name="numberOfTimes"/> times.</summary>
    public static void Show(string output, int numberOfTimes)
    {
        WriteRepeated(output, numberOfTimes);
    }

    // Shared countdown loop; a zero or negative count writes nothing.
    private static void WriteRepeated(string text, int count)
    {
        var remaining = count;
        while (remaining > 0)
        {
            Console.WriteLine(text);
            remaining--;
        }
    }
}
And let's imagine we wanted to give the numberOfTimes a default value, then we can do this:
/// <summary>
/// Supplies a fixed repeat count (10) to any two-argument print action.
/// </summary>
public class DefaultPrinter
{
    private int defaultTimes = 10;

    /// <summary>
    /// Invokes <paramref name="action"/> with <paramref name="output"/> and the default count.
    /// </summary>
    public void ExecuteMethod(Action<string, int> action, string output)
    {
        var times = defaultTimes;
        action.Invoke(output, times);
    }
}
Now the users will not have to indicate how many times the print should occur because the above class will do it for a default of 10 times.
Usage
// The static methods already match Action<string, int>, so they can be passed directly.
var printer = new DefaultPrinter();
printer.ExecuteMethod(Printer.Print, "One");
printer.ExecuteMethod(Printer.Show, "Two");
Related
I have an ASP.NET Core 3.1 Web API project which uses Dapper for database operations. I am trying to use Polly to add resiliency to the methods that connect to the database while fetching data.
I referenced the below article for my POC:
https://concurrentflows.hashnode.dev/basic-dapper-resiliency-using-polly
This is my code:
/// <summary>
/// Asynchronous, Dapper-shaped data-access operations. No connection parameter is exposed;
/// the implementation is expected to supply it.
/// </summary>
public interface ISqlDapperClient
{
    /// <summary>Runs a command and yields an integer result.</summary>
    Task<int> ExecuteAsync(
        string sql,
        object param = null,
        IDbTransaction transaction = null,
        int? commandTimeout = null,
        CommandType? commandType = null);

    /// <summary>Runs a command and yields a single value of type <typeparamref name="T"/>.</summary>
    Task<T> ExecuteScalarAsync<T>(
        string sql,
        object param = null,
        IDbTransaction transaction = null,
        int? commandTimeout = null,
        CommandType? commandType = null);

    /// <summary>Runs a query and yields a sequence of <typeparamref name="T"/>.</summary>
    Task<IEnumerable<T>> QueryAsync<T>(
        string sql,
        object param = null,
        IDbTransaction transaction = null,
        int? commandTimeout = null,
        CommandType? commandType = null);

    /// <summary>Runs a query and yields a single <typeparamref name="T"/>.</summary>
    Task<T> QueryFirstOrDefaultAsync<T>(
        string sql,
        object param = null,
        IDbTransaction transaction = null,
        int? commandTimeout = null,
        CommandType? commandType = null);

    /// <summary>Runs a query whose rows are combined from two types via <paramref name="map"/>.</summary>
    Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TReturn>(
        string sql,
        Func<TFirst, TSecond, TReturn> map,
        object param = null,
        IDbTransaction transaction = null,
        bool buffered = true,
        string splitOn = "Id",
        int? commandTimeout = null,
        CommandType? commandType = null);

    /// <summary>Runs a query whose rows are combined from three types via <paramref name="map"/>.</summary>
    Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TThird, TReturn>(
        string sql,
        Func<TFirst, TSecond, TThird, TReturn> map,
        object param = null,
        IDbTransaction transaction = null,
        bool buffered = true,
        string splitOn = "Id",
        int? commandTimeout = null,
        CommandType? commandType = null);
}
/// <summary>
/// <see cref="ISqlDapperClient"/> implementation that runs every Dapper call through the
/// injected resiliency policy on a single SQL connection.
/// </summary>
/// <remarks>
/// The connection is opened asynchronously on first use instead of being blocked on inside
/// the constructor, and it is released when the instance is disposed.
/// </remarks>
public class SqlDapperClient : ISqlDapperClient, IDisposable
{
    private readonly ILogger<SqlDapperClient> logger;
    private readonly string _dbConnection;
    private readonly IConfiguration _configuration;
    private readonly IDBAuthTokenService _dbTokenService;
    // Opened once, on first use; every later call awaits the same task.
    // NOTE: a failed open is cached for the lifetime of this instance.
    private readonly Lazy<Task<SqlConnection>> connection;
    private readonly IAsyncPolicy resiliencyPolicy;
    //private AsyncRetryPolicy retryPolicy;

    /// <exception cref="ArgumentNullException">Any of the injected dependencies is null.</exception>
    public SqlDapperClient(ILogger<SqlDapperClient> logger, IAsyncPolicy resiliencyPolicy, IConfiguration configuration, IDBAuthTokenService dbTokenService)
    {
        this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
        this.resiliencyPolicy = resiliencyPolicy ?? throw new ArgumentNullException(nameof(resiliencyPolicy));
        _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
        // Read the connectionstring
        _dbConnection = _configuration[C.VaultKeys.DataDBConnString] ?? _configuration[C.AppKeys.LocalDataDBConn];
        _dbTokenService = dbTokenService ?? throw new ArgumentNullException(nameof(dbTokenService));
        // Defer the async open: calling .Result here would block a thread inside the constructor.
        connection = new Lazy<Task<SqlConnection>>(OpenConnectionWithRetryAsync);
    }

    /// <summary>
    /// Creates a connection carrying a freshly fetched access token and opens it.
    /// </summary>
    /// <returns>The opened connection.</returns>
    private async Task<SqlConnection> OpenConnectionWithRetryAsync()
    {
        var accessToken = await _dbTokenService.GetTokenAsync();
        var conn = new SqlConnection(_dbConnection) { AccessToken = accessToken };
        try
        {
            await conn.OpenAsync();
        }
        catch
        {
            // Do not leak the connection object when opening fails.
            conn.Dispose();
            throw;
        }
        return conn;
    }

    public Task<int> ExecuteAsync(string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.ExecuteAsync(s, p, transaction, commandTimeout, commandType), sql, param);

    public Task<T> ExecuteScalarAsync<T>(string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.ExecuteScalarAsync<T>(s, p, transaction, commandTimeout, commandType), sql, param);

    public Task<T> QueryFirstOrDefaultAsync<T>(string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.QueryFirstOrDefaultAsync<T>(s, p, transaction, commandTimeout, commandType), sql, param);

    public Task<IEnumerable<T>> QueryAsync<T>(string sql, object param = null, IDbTransaction transaction = null, int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.QueryAsync<T>(s, p, transaction, commandTimeout, commandType), sql, param);

    public Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TReturn>(string sql, Func<TFirst, TSecond, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.QueryAsync(s, map, p, transaction, buffered, splitOn, commandTimeout, commandType), sql, param);

    public Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TThird, TReturn>(string sql, Func<TFirst, TSecond, TThird, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null) => ExecuteWithResiliency((s, p, c) => c.QueryAsync(s, map, p, transaction, buffered, splitOn, commandTimeout, commandType), sql, param);

    /// <summary>
    /// Runs <paramref name="connectionFunc"/> under the resiliency policy, passing the shared
    /// connection and a Polly context that carries the connection, logger, SQL and parameters.
    /// </summary>
    /// <param name="operation">Filled in by the compiler with the calling member's name.</param>
    private async Task<T> ExecuteWithResiliency<T>(Func<string, object, SqlConnection, Task<T>> connectionFunc, string sql, object param = null, [CallerMemberName] string operation = "")
    {
        var conn = await connection.Value;
        return await resiliencyPolicy.ExecuteAsync(ctx => connectionFunc(sql, param, conn), ContextHelper.NewContext(conn, logger, sql, param, operation));
    }

    /// <summary>
    /// Disposes the connection if it was successfully opened. A connection whose open is
    /// still in flight at this point is not awaited.
    /// </summary>
    public void Dispose()
    {
        if (connection.IsValueCreated && connection.Value.Status == TaskStatus.RanToCompletion)
        {
            // The task has already completed, so reading Result cannot block.
            connection.Value.Result.Dispose();
        }
    }
}
SqlResiliencyPolicy:
/// <summary>
/// Builds the composite Polly policy used for SQL calls: a timeout wrapped around two
/// wait-and-retry policies and a circuit breaker, each keyed on <c>SqlException.Number</c>.
/// </summary>
public static class SqlResiliencyPolicy
{
    // Error numbers handled by the "transient" retry policy.
    private static readonly ISet<int> transientNumbers = new HashSet<int> { 40613, 40197, 40501, 49918, 40549, 40550, 1205 };

    // Error numbers handled by the "network" retry policy.
    private static readonly ISet<int> networkingNumbers = new HashSet<int> { 258, -2, 10060, 0, 64, 26, 40, 10053 };

    // Error numbers handled by the circuit breaker.
    private static readonly ISet<int> constraintViolationNumbers = new HashSet<int> { 2627, 547, 2601 };

    /// <summary>
    /// Creates the wrapped policy.
    /// </summary>
    /// <param name="maxTimeout">Timeout for the outermost policy; two minutes when null.</param>
    /// <param name="transientRetries">Retry count for the transient-error policy.</param>
    /// <param name="networkRetries">Retry count for the network-error policy.</param>
    /// <returns>timeout → transient retry → network retry → circuit breaker, outermost first.</returns>
    public static IAsyncPolicy GetSqlResiliencyPolicy(TimeSpan? maxTimeout = null, int transientRetries = 3, int networkRetries = 3)
    {
        // Both retry policies wait 2^attempt seconds between tries.
        Func<int, TimeSpan> exponentialBackoff = attempt => TimeSpan.FromSeconds(Math.Pow(2, attempt));

        var timeout = Policy.TimeoutAsync(maxTimeout ?? TimeSpan.FromMinutes(2));

        var transientRetry = Policy
            .Handle<SqlException>(ex => transientNumbers.Contains(ex.Number))
            .WaitAndRetryAsync(
                transientRetries,
                exponentialBackoff,
                (ex, delay, ctx) => ctx.GetLogger()?.LogWarning(ex, "{#Operation} Encountered Transient SqlException. Params:{#Param} Sql:{#Sql}", ctx.OperationKey, ctx[ContextHelper.ParamContextKey], ctx[ContextHelper.SqlContextKey]));

        var networkRetry = Policy
            .Handle<SqlException>(ex => networkingNumbers.Contains(ex.Number))
            .WaitAndRetryAsync(
                networkRetries,
                exponentialBackoff,
                (ex, delay, ctx) =>
                {
                    ctx.GetLogger()?.LogWarning(ex, "{#Operation} Encountered a Network Error. Params:{#Param} Sql:{#Sql}", ctx.OperationKey, ctx[ContextHelper.ParamContextKey], ctx[ContextHelper.SqlContextKey]);

                    // Clear the pool for the connection carried in the context, when there is one.
                    if (ctx.TryGetConnection(out var connection))
                    {
                        SqlConnection.ClearPool(connection);
                    }
                });

        var constraintBreaker = Policy
            .Handle<SqlException>(ex => constraintViolationNumbers.Contains(ex.Number))
            .CircuitBreakerAsync(
                1,
                TimeSpan.MaxValue,
                (ex, breakDuration, ctx) => ctx.GetLogger()?.LogError(ex, "{#Operation} Encountered a Constraint Violation. Params:{#Param} Sql:{#Sql}", ctx.OperationKey, ctx[ContextHelper.ParamContextKey], ctx[ContextHelper.SqlContextKey]),
                ctx =>
                {
                });

        return timeout
            .WrapAsync(transientRetry)
            .WrapAsync(networkRetry)
            .WrapAsync(constraintBreaker);
    }
}
ContextHelper:
/// <summary>
/// Helpers for stashing the connection, logger, SQL text and parameters in a
/// <see cref="Polly.Context"/> and reading them back inside policy callbacks.
/// </summary>
public static class ContextHelper
{
    public static readonly string LoggerContextKey = nameof(LoggerContextKey);
    public static readonly string SqlContextKey = nameof(SqlContextKey);
    public static readonly string ParamContextKey = nameof(ParamContextKey);
    public static readonly string ConnectionContextKey = nameof(ConnectionContextKey);

    /// <summary>
    /// Creates a context whose operation key is <paramref name="operationKey"/> and whose
    /// entries hold the other four arguments under the keys declared above.
    /// </summary>
    public static Polly.Context NewContext(SqlConnection connection, ILogger logger, string sql, object param, string operationKey)
    {
        return new Polly.Context(operationKey, new Dictionary<string, object>()
        {
            { ConnectionContextKey, connection },
            { LoggerContextKey, logger },
            { SqlContextKey, sql },
            { ParamContextKey, param }
        });
    }

    /// <summary>
    /// Returns the logger stored in the context, or null when the entry is absent or is
    /// not an <see cref="ILogger"/>.
    /// </summary>
    public static ILogger GetLogger(this Polly.Context ctx)
    {
        // TryGetValue instead of the indexer: a context that was not built by NewContext
        // has no such key, and the indexer would throw KeyNotFoundException.
        return ctx.TryGetValue(LoggerContextKey, out var stored) ? stored as ILogger : null;
    }

    /// <summary>
    /// Reads the connection stored in the context.
    /// </summary>
    /// <returns>True when a non-null <see cref="SqlConnection"/> was found; never throws for a missing entry.</returns>
    public static bool TryGetConnection(this Polly.Context ctx, out SqlConnection connection)
    {
        connection = ctx.TryGetValue(ConnectionContextKey, out var stored) ? stored as SqlConnection : null;
        return connection != null;
    }
}
Startup.cs (for DI setup)
/// <summary>
/// Registers the resiliency policy and the Dapper client as scoped services.
/// </summary>
/// <param name="services">Collection to add the registrations to.</param>
/// <param name="connectionString">Accepted but not used by this method.</param>
public static void AddSqlDapperClient(this IServiceCollection services, string connectionString)
{
    // Service type spelled out; it is the factory's return type either way.
    services.AddScoped<IAsyncPolicy>(provider => SqlResiliencyPolicy.GetSqlResiliencyPolicy());

    services.AddScoped<ISqlDapperClient, SqlDapperClient>();
}
On running the code at runtime I am getting an error :
Unable to resolve service for type 'Polly.IAsyncPolicy' while attempting to activate 'SqlDapperClient'
Can anyone help me here by providing their guidance?
You need to register your service for IAsyncPolicy with specific scope in Startup.cs file.
eg.
// Polly's AsyncPolicy is abstract, so the container cannot instantiate it as an
// implementation type. Register a factory that builds the concrete policy instead.
services.AddScoped<IAsyncPolicy>(_ => SqlResiliencyPolicy.GetSqlResiliencyPolicy());
I was trying to implement a wrapper in C# for SQL Server.
The normal workflow without wrapper is fetching the data into a datatable using direct SQL query and then mapping the columns by names into entities.
But for a wrapper it would be better to accept a mapping function that describes which column maps to which field of the objects in the resulting enumerable.
So, something like this :
/// <summary>Plain data holder for a user's first and last name.</summary>
public class UserInfo
{
    /// <summary>The user's first name.</summary>
    public string FirstName { get; set; }

    /// <summary>The user's last name.</summary>
    public string LastName { get; set; }
}
enumerableList = dbManager.Execute("** sql query **", /* some method to specify mapping */);
The enumerable will then contain the result from the database, mapped by the execute method. But I am unsure how to specify the mapping?
Even if I do then how to deal with the different data types for each column in the mapping?
If I understand correctly, you want something like this:
/// <summary>
/// Builds a stored-procedure command for <paramref name="query"/> and returns every row
/// of its reader, converted with <paramref name="projection"/>.
/// </summary>
/// <param name="sql">Helper that executes the command and exposes the reader.</param>
/// <param name="query">Text passed to <c>GetSqlCommand</c> with <see cref="CommandType.StoredProcedure"/>.</param>
/// <param name="parameters">Parameters attached to the command.</param>
/// <param name="projection">Maps the reader's current row to a <typeparamref name="T"/>.</param>
public static List<T> ReadRows<T>(
    this SqlHelper sql,
    string query,
    SqlParameter[] parameters,
    Func<SqlDataReader, T> projection)
{
    var storedProcedure = GetSqlCommand(query, CommandType.StoredProcedure, parameters);

    return sql.ExecuteReader(storedProcedure, reader =>
    {
        var rows = reader.Select(projection);
        return rows.ToList();
    });
}
And use like:
var members = _unitOfWork.SqlHelper.ReadRows(
    "spGetMembersByUserCompanies",
    parameters,
    _memberProjection);

// Maps the reader's current row to a MemberVm, reading each column by name.
readonly Func<SqlDataReader, MemberVm> _memberProjection = row =>
{
    return new MemberVm
    {
        InvitationId = row.Get<int?>("InvitationId"),
        UserName = row.Get<string>("UserName"),
        RoleName = row.Get<string>("RoleName"),
        // Stored as an int column and cast to the enum.
        InvitationStatus = (InvitationStatus)row.Get<int>("InvitationStatus"),
        LogoUrl = row.Get<string>("LogoUrl")
    };
};
This is a piece of my own code. I hope it is a start towards resolving your problem.
Implementing such a wrapper from bare bones is not that easy. But it is possible. There is an ADO wrapper library in Github : ADOWrapper
The implementation is pretty straightforward.
Short Answer
How to specify mapping between columns? - Use Func
How to deal with the different data types? You can write an extension
Long Answer
Make a generic method that takes as input the query and a Func Delegate (and optional third parameter to pass query parameters as dictionary)
/// <summary>
/// Runs <paramref name="query"/> and returns one mapped item per row.
/// </summary>
/// <param name="query">SQL text for the command.</param>
/// <param name="map">Converts the reader's current row into a <typeparamref name="T"/>.</param>
/// <param name="parameters">Optional name/value pairs passed on to <c>CreateCommand</c>.</param>
/// <returns>The mapped rows in the order they were read; empty when there are none.</returns>
/// <exception cref="ArgumentNullException"><paramref name="map"/> is null.</exception>
public ICollection<T> Execute<T>(string query, Func<IDataReader, T> map, IDictionary<string, object> parameters = null)
{
    if (map == null)
    {
        throw new ArgumentNullException(nameof(map));
    }

    ICollection<T> collection = new List<T>();
    using (SqlConnection connection = CreateConnection())
    {
        connection.Open();
        using (SqlCommand command = CreateCommand(connection, query, parameters))
        // ExecuteReader is synchronous and this method is not async, so there is nothing to await.
        using (IDataReader reader = command.ExecuteReader())
        {
            while (reader.Read())
            {
                collection.Add(map(reader));
            }
        }
        // No explicit Close(): leaving the using block disposes the connection.
    }
    return collection;
}
Implementation of AddParameter AND CreateCommand:
/// <summary>
/// Creates a parameter on <paramref name="command"/> and appends it to the command's
/// parameter collection.
/// </summary>
/// <param name="command">Command that creates and receives the parameter.</param>
/// <param name="parameter">Parameter name.</param>
/// <param name="value">Parameter value; null is sent as a database NULL.</param>
private void AddParameter(IDbCommand command, string parameter, object value)
{
    IDbDataParameter param = command.CreateParameter();
    param.ParameterName = parameter;
    // ADO.NET represents a database NULL as DBNull.Value; a CLR null leaves the
    // parameter without a value instead of sending NULL.
    param.Value = value ?? DBNull.Value;
    command.Parameters.Add(param);
}
/// <summary>
/// Creates a command on <paramref name="connection"/> with <paramref name="query"/> as its
/// text and one parameter per entry of <paramref name="parameters"/>.
/// </summary>
/// <param name="parameters">Optional name/value pairs; null or empty adds no parameters.</param>
private SqlCommand CreateCommand(
    SqlConnection connection,
    string query,
    IDictionary<string, object> parameters = null)
{
    SqlCommand sqlCommand = connection.CreateCommand();
    sqlCommand.CommandText = query;

    // Nothing to bind: hand back the bare command.
    if (parameters == null || parameters.Count <= 0)
    {
        return sqlCommand;
    }

    foreach (var entry in parameters)
    {
        AddParameter(sqlCommand, entry.Key, entry.Value);
    }

    return sqlCommand;
}
You can call the method like this :
/// <summary>Target type for the mapping example: one object per result row.</summary>
public class UserInfo
{
    /// <summary>Value mapped from the first-name column.</summary>
    public string FirstName { get; set; }

    /// <summary>Value mapped from the last-name column.</summary>
    public string LastName { get; set; }
}
// Map each row to a UserInfo by reading its columns by name.
var enumerableList = manager.Execute("** query **",
    (reader) =>
    {
        return new UserInfo()
        {
            FirstName = reader.Get<string>("FirstName"),
            // Column name must match exactly; a trailing space makes the lookup fail.
            LastName = reader.Get<string>("LastName"),
        };
    });
The Get method makes it easy to manage different data types being fetched from column. But it is not an inbuilt method. So you need to write an extension for Data Reader:
/// <summary>
/// Convenience extensions for reading typed values and column names from an
/// <see cref="IDataReader"/>.
/// </summary>
public static class DataReaderExtension
{
    /// <summary>
    /// Reads the value of <paramref name="column"/> from the current row.
    /// </summary>
    /// <typeparam name="T">
    /// Type to cast the stored value to. Unconstrained, so nullable value types such as
    /// <c>int?</c> are allowed.
    /// </typeparam>
    /// <returns>The cast value, or <c>default(T)</c> when the column holds a database NULL.</returns>
    /// <exception cref="ArgumentException">
    /// The reader reported the column as missing by throwing <see cref="IndexOutOfRangeException"/>;
    /// that exception is kept as the inner exception.
    /// </exception>
    public static T Get<T>(this IDataReader reader, string column)
    {
        int index;
        try
        {
            index = reader.GetOrdinal(column);
        }
        catch (IndexOutOfRangeException ex)
        {
            // Only the ordinal lookup is guarded, so other failures are not mislabelled
            // as a missing column.
            throw new ArgumentException($"Column, '{column}' not found.", ex);
        }

        return reader.IsDBNull(index) ? default(T) : (T)reader[index];
    }

    /// <summary>
    /// Lists the reader's column names in ordinal order.
    /// </summary>
    /// <returns>A list of names; empty when the reader is null or has no fields.</returns>
    public static IEnumerable<string> GetColumns(this IDataReader reader)
    {
        if (reader == null || reader.FieldCount <= 0)
        {
            return new List<string>();
        }

        return Enumerable.Range(0, reader.FieldCount)
            .Select(index => reader.GetName(index))
            .ToList();
    }
}
Below is my extension for multi mapping (one to many relationship) in dapper
/// <summary>
/// Runs a two-type multi-mapping query and groups the child of every row under a single
/// parent instance per key.
/// </summary>
/// <param name="parentKeySelector">Extracts the grouping key from a parent row.</param>
/// <param name="childSelector">Returns the list on the parent that collects its children.</param>
/// <param name="buffered">Passed through to Dapper; the rows are consumed here either way.</param>
/// <returns>The distinct parents, each with its children added to the selected list.</returns>
/// <remarks>
/// The child of every row is added as-is. NOTE(review): presumably an outer join can yield
/// a null child for parents without children; confirm whether callers need to filter those.
/// </remarks>
public static IEnumerable<TParent> QueryParentChild<TParent, TChild, TParentKey>(
    this IDbConnection connection,
    string sql,
    Func<TParent, TParentKey> parentKeySelector,
    Func<TParent, IList<TChild>> childSelector,
    dynamic param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null)
{
    Dictionary<TParentKey, TParent> cache = new Dictionary<TParentKey, TParent>();

    IEnumerable<TParent> rows = connection.Query<TParent, TChild, TParent>(
        sql,
        (parent, child) =>
        {
            // Compute the key once per row and do a single dictionary lookup.
            TParentKey key = parentKeySelector(parent);
            TParent cachedParent;
            if (!cache.TryGetValue(key, out cachedParent))
            {
                cachedParent = parent;
                cache.Add(key, parent);
            }

            childSelector(cachedParent).Add(child);
            return cachedParent;
        },
        param as object, transaction, buffered, splitOn, commandTimeout, commandType);

    // With buffered == false the result is lazy: the mapping callback only runs while the
    // sequence is enumerated, so it must be drained before the cache is returned.
    foreach (TParent unused in rows)
    {
    }

    return cache.Values;
}
Now I want to convert this to an async method. I have tried many ways but got some errors. Please let me know what changes need to be made.
Did you try something like this, see below:
/// <summary>
/// Asynchronously runs a two-type multi-mapping query and groups the child of every row
/// under a single parent instance per key.
/// </summary>
/// <param name="parentKeySelector">Extracts the grouping key from a parent row.</param>
/// <param name="childSelector">Returns the list on the parent that collects its children.</param>
/// <param name="buffered">Passed through to Dapper; the rows are consumed here either way.</param>
/// <returns>The distinct parents, each with its children added to the selected list.</returns>
public static async Task<IEnumerable<TParent>> QueryParentChildAsync<TParent,
    TChild,
    TParentKey>(
    this IDbConnection connection,
    string sql,
    Func<TParent, TParentKey> parentKeySelector,
    Func<TParent, IList<TChild>> childSelector,
    dynamic param = null,
    IDbTransaction transaction = null,
    bool buffered = true,
    string splitOn = "Id",
    int? commandTimeout = null,
    CommandType? commandType = null)
{
    var cache = new Dictionary<TParentKey, TParent>();

    IEnumerable<TParent> rows = await connection.QueryAsync<TParent, TChild, TParent>(
        sql,
        (parent, child) =>
        {
            var key = parentKeySelector(parent);
            TParent cachedParent;
            if (!cache.TryGetValue(key, out cachedParent))
            {
                cachedParent = parent;
                cache.Add(key, parent);
            }

            childSelector(cachedParent).Add(child);
            return cachedParent;
        },
        param as object,
        transaction,
        buffered,
        splitOn,
        commandTimeout,
        commandType).ConfigureAwait(false);

    // With buffered == false the awaited result is still lazy: the mapping callback only
    // runs while the sequence is enumerated, so drain it before returning the cache.
    foreach (TParent unused in rows)
    {
    }

    return cache.Values;
}
I have this class, in which I am wrapping dapper calls in order to do something like
var results = SqlWrapper.ExecuteQuery<Product,Customer>("SELECT id FROM Products; SELECT id FROM Customers;");
Where
results[0] = List<Product>
results[1] = List<Customer>
I support 1, 2 or 3 output types, but would like to support an arbitrary number. The class is also ugly and full of copy-and-pasted code. I allow a connection to be reused by optionally passing one in, but the code just seems unclean. What I would really like is a way to define something like params T[] for type arguments, but as I understand it that doesn't work. Is there any way this code can be cleaned up or shortened?
using System.Collections.Generic;
using System.Data.SqlClient;
using System.Linq;
using Dapper;
namespace SqlWrapper
{
    /// <summary>
    /// Static helpers over Dapper that run on a caller-supplied connection when one is
    /// given and otherwise on a temporary connection built from the constant below.
    /// </summary>
    public static class SqlWrapper
    {
        // NOTE(review): credentials are embedded in source here.
        private const string SqlConnectionString = "Server=localhost;Database=TTDS;User Id=sa;Password=sa;";

        /// <summary>Runs a single-result query and returns its rows as a list.</summary>
        public static List<T> ExecuteQuery<T>(string sql, object param = null, SqlConnection sqlConnection = null)
        {
            if (sqlConnection == null)
            {
                using (var ownedConnection = new SqlConnection(SqlConnectionString))
                {
                    ownedConnection.Open();
                    return ownedConnection.Query<T>(sql, param).ToList();
                }
            }

            return sqlConnection.Query<T>(sql, param).ToList();
        }

        /// <summary>Runs a two-result query; element 0 is a List of T1, element 1 a List of T2.</summary>
        public static List<dynamic> ExecuteQuery<T1, T2>(string sql, object param = null, SqlConnection sqlConnection = null)
        {
            if (sqlConnection == null)
            {
                using (var ownedConnection = new SqlConnection(SqlConnectionString))
                {
                    return MultiQuery<T1, T2>(ownedConnection, sql, param);
                }
            }

            return MultiQuery<T1, T2>(sqlConnection, sql, param);
        }

        /// <summary>Runs a three-result query; elements 0..2 are lists of T1, T2 and T3.</summary>
        public static List<dynamic> ExecuteQuery<T1, T2, T3>(string sql, object param = null,
            SqlConnection sqlConnection = null)
        {
            if (sqlConnection == null)
            {
                using (var ownedConnection = new SqlConnection(SqlConnectionString))
                {
                    return MultiQuery<T1, T2, T3>(ownedConnection, sql, param);
                }
            }

            return MultiQuery<T1, T2, T3>(sqlConnection, sql, param);
        }

        // Reads two consecutive result grids, in order, into one list.
        private static List<dynamic> MultiQuery<T1, T2>(SqlConnection sqlConnection, string sql, object param = null)
        {
            var results = new List<dynamic>();
            using (var grid = sqlConnection.QueryMultiple(sql, param))
            {
                var first = grid.Read<T1>().ToList();
                results.Add(first);
                var second = grid.Read<T2>().ToList();
                results.Add(second);
            }
            return results;
        }

        // Reads three consecutive result grids, in order, into one list.
        private static List<dynamic> MultiQuery<T1, T2, T3>(SqlConnection sqlConnection, string sql, object param = null)
        {
            var results = new List<dynamic>();
            using (var grid = sqlConnection.QueryMultiple(sql, param))
            {
                var first = grid.Read<T1>().ToList();
                results.Add(first);
                var second = grid.Read<T2>().ToList();
                results.Add(second);
                var third = grid.Read<T3>().ToList();
                results.Add(third);
            }
            return results;
        }

        /// <summary>Executes a statement that returns no rows, with an optional timeout.</summary>
        public static void ExecuteNonQuery(SqlConnection sqlConnection, string sql, object param, int? timeout = null)
        {
            if (sqlConnection == null)
            {
                using (var ownedConnection = new SqlConnection(SqlConnectionString))
                {
                    ownedConnection.Open();
                    ownedConnection.Execute(sql, param, commandTimeout: timeout);
                }
                return;
            }

            sqlConnection.Execute(sql, param, commandTimeout: timeout);
        }
    }
}
Here is some untested code that demonstrates a couple of ideas that I had.
While "using" is pretty awesome, you can pare down your code some if you optionally create the connection first and then, if necessary, dispose the sqlConnection in a finally block.
If you return a Tuple<List<T>,List<U>,List<V>> you can have strongly typed return values that you can easily use
If you call your most complex function from those that are less complex, you can minimize your duplicated code.
/// <summary>
/// Static helpers over Dapper returning strongly typed result lists. The one- and two-type
/// overloads delegate to the three-type overload, using <c>NoResult</c> as a marker for
/// result sets that should not be read.
/// </summary>
public static class SqlWrapper
{
    // NOTE(review): credentials are embedded in source here.
    private const string SqlConnectionString = "Server=localhost;Database=TTDS;User Id=sa;Password=sa;";

    // Marker type argument meaning "do not read a result set for this slot".
    private class NoResult { }

    /// <summary>Runs a single-result query and returns its rows.</summary>
    public static List<T1> ExecuteQuery<T1>(string sql, object param = null, SqlConnection sqlConnection = null)
    {
        return ExecuteQuery<T1, NoResult, NoResult>(sql, param, sqlConnection).Item1;
    }

    /// <summary>Runs a two-result query and returns both row lists.</summary>
    public static Tuple<List<T1>, List<T2>> ExecuteQuery<T1, T2>(string sql, object param = null, SqlConnection sqlConnection = null)
    {
        var result = ExecuteQuery<T1, T2, NoResult>(sql, param, sqlConnection);
        return Tuple.Create(result.Item1, result.Item2);
    }

    /// <summary>
    /// Runs a query with up to three result sets on <paramref name="sqlConnection"/>, or on
    /// a temporary connection when none is supplied.
    /// </summary>
    /// <returns>
    /// The row lists in order. A slot whose type argument is the <c>NoResult</c> marker is
    /// left null.
    /// </returns>
    public static Tuple<List<T1>, List<T2>, List<T3>> ExecuteQuery<T1, T2, T3>(string sql, object param = null, SqlConnection sqlConnection = null)
    {
        // Non-null only when this call created the connection and must dispose it.
        SqlConnection ownedConnection = null;
        try
        {
            if (sqlConnection == null)
            {
                // Created inside the try so that a failing Open() still reaches the finally.
                ownedConnection = new SqlConnection(SqlConnectionString);
                ownedConnection.Open();
                sqlConnection = ownedConnection;
            }

            using (var grid = sqlConnection.QueryMultiple(sql, param))
            {
                List<T1> list1 = grid.Read<T1>().ToList();
                List<T2> list2 = null;
                List<T3> list3 = null;
                if (typeof(T2) != typeof(NoResult))
                {
                    list2 = grid.Read<T2>().ToList();
                }
                if (typeof(T3) != typeof(NoResult))
                {
                    list3 = grid.Read<T3>().ToList();
                }
                return Tuple.Create(list1, list2, list3);
            }
        }
        finally
        {
            if (ownedConnection != null)
            {
                ownedConnection.Dispose();
            }
        }
    }

    /// <summary>
    /// Executes a statement that returns no rows on <paramref name="sqlConnection"/>, or on
    /// a temporary connection when it is null.
    /// </summary>
    public static void ExecuteNonQuery(SqlConnection sqlConnection, string sql, object param, int? timeout = null)
    {
        SqlConnection ownedConnection = null;
        try
        {
            if (sqlConnection == null)
            {
                ownedConnection = new SqlConnection(SqlConnectionString);
                ownedConnection.Open();
                sqlConnection = ownedConnection;
            }

            sqlConnection.Execute(sql, param, commandTimeout: timeout);
        }
        finally
        {
            if (ownedConnection != null)
            {
                ownedConnection.Dispose();
            }
        }
    }
}
Is it possible to make xUnit test work when you don't specify optional parameter values in InlineDataAttribute?
Example:
// The second data row omits the optional argument, which is what triggers the error.
[Theory]
[InlineData(1, true)]   // works
[InlineData(2)]         // error
void Test(int num, bool fast = true)
{
}
Yes it is. There are many ways to do it by redefining some original xunit attributes.
The following code is one of them, which would give you some idea.
/// <summary>
/// Theory attribute that lets data rows omit trailing optional parameters: every omitted
/// position is filled with <see cref="Type.Missing"/> before the command's parameters are used.
/// </summary>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public class OptionalTheoryAttribute : TheoryAttribute
{
    /// <summary>
    /// Takes the base commands and pads their parameter arrays. If padding throws, the list
    /// is replaced by a single command that reports the exception when executed.
    /// </summary>
    protected override IEnumerable<ITestCommand> EnumerateTestCommands(IMethodInfo method)
    {
        var commands = (List<ITestCommand>)base.EnumerateTestCommands(method);
        try
        {
            return TransferToSupportOptional(commands, method);
        }
        catch (Exception ex)
        {
            commands.Clear();
            commands.Add(new LambdaTestCommand(method, () =>
            {
                var message = String.Format(
                    "An exception was thrown while getting data for theory {0}.{1}:\r\n{2}",
                    method.TypeName, method.Name, ex);
                throw new InvalidOperationException(message);
            }));
            return commands;
        }
    }

    // Overwrites the Parameters property of each TheoryCommand with its padded array.
    private static IEnumerable<ITestCommand> TransferToSupportOptional(
        IEnumerable<ITestCommand> testCommands, IMethodInfo method)
    {
        var parameterInfos = method.MethodInfo.GetParameters();
        var parametersProperty = typeof(TheoryCommand).GetProperty("Parameters");

        // Snapshot the theory commands first, then rewrite each one through reflection.
        foreach (var theoryCommand in testCommands.OfType<TheoryCommand>().ToList())
        {
            parametersProperty.SetValue(theoryCommand, GetParameterValues(theoryCommand, parameterInfos));
        }

        return testCommands;
    }

    // Supplied values followed by one placeholder per omitted parameter.
    private static object[] GetParameterValues(TheoryCommand testCommand, ParameterInfo[] parameterInfos)
    {
        var supplied = testCommand.Parameters;
        var placeholders = GetOptionalValues(testCommand, parameterInfos);
        return supplied.Concat(placeholders).ToArray();
    }

    // Yields Type.Missing for every parameter position past the supplied ones, checking
    // that each such parameter is declared optional.
    private static IEnumerable<object> GetOptionalValues(TheoryCommand command, ParameterInfo[] parameterInfos)
    {
        var suppliedCount = command.Parameters.Length;
        var missingCount = parameterInfos.Length - command.Parameters.Length;

        return Enumerable.Range(suppliedCount, missingCount)
            .Select(position =>
            {
                EnsureIsOptional(parameterInfos[position]);
                return Type.Missing;
            });
    }

    // Throws when a parameter without a supplied value is not optional.
    private static void EnsureIsOptional(ParameterInfo parameterInfo)
    {
        if (parameterInfo.IsOptional)
        {
            return;
        }

        throw new ArgumentException(string.Format(
            "The parameter '{0}' should be optional or specified from data attribute.",
            parameterInfo));
    }
}
/// <summary>
/// Test command whose body is a delegate: the result is a pass when the delegate returns
/// and a failure carrying the exception when it throws.
/// </summary>
internal class LambdaTestCommand : TestCommand
{
    private readonly Assert.ThrowsDelegate body;

    public LambdaTestCommand(IMethodInfo method, Assert.ThrowsDelegate lambda)
        : base(method, null, 0)
    {
        body = lambda;
    }

    // Always false for this command.
    public override bool ShouldCreateInstance
    {
        get { return false; }
    }

    /// <summary>Runs the delegate and converts the outcome into a method result.</summary>
    public override MethodResult Execute(object testClass)
    {
        MethodResult outcome;
        try
        {
            body();
            outcome = new PassedResult(testMethod, DisplayName);
        }
        catch (Exception ex)
        {
            outcome = new FailedResult(testMethod, ex, DisplayName);
        }
        return outcome;
    }
}
/// <summary>
/// Exercises <c>[OptionalTheory]</c> with one data row that omits the optional argument
/// and one that supplies it.
/// </summary>
public class OptionalTheoryTest
{
    [OptionalTheory]
    [InlineData(1)]
    [InlineData(1, true)]
    public void TestMethod(int num, bool fast = true)
    {
        // Both rows are expected to arrive with num == 1 and fast == true.
        Assert.Equal(1, num);
        Assert.True(fast);
    }
}