自己写一个类似 Dapper 的 ORM 框架(C#)

Dapper 本质上就是一组 Connection 的扩展方法,这里也用同样的方式实现。写它是为了练习反射,原创~

使用技术:泛型、反射、表达式树...

客户端调用:

        /// <summary>
        /// Demo client: exercises Insert, Delete, Update and Select against the fanDB database.
        /// </summary>
        static void Main(string[] args)
        {
            // NOTE(review): credentials are hard-coded for the demo only; load them from configuration in real code.
            // The using block guarantees the connection is disposed when the demo is done.
            using (var connection = new SqlConnection("Data Source=.;User Id=sa;Password=123456;Database=fanDB;"))
            {
                // Insert a single row, then a batch.
                connection.Insert(new Person() { Name = "fan11", Age = 1 });
                connection.Insert(new List<Person> {
                    new Person() { Name = "fan432", Age = 24 },
                    new Person() { Name = "fan", Age = 4 }
                });
                // Delete by key, and by an entity that carries the key.
                connection.Delete<Person>(5);
                connection.Delete(new Person() { ID = 6 });
                // Update every non-key column of row 17.
                connection.Update(new Person() { ID = 17, Name = "fanfan", Age = 18 });
                // Query with a predicate translated into a WHERE clause.
                // Note: && binds tighter than ||, so "Age > 3" only constrains the EndsWith branch.
                var list = connection.Select<Person>(p => p.Name == "fan" || p.Name.Contains("fan1") || p.Name.StartsWith("fan") || p.Name.EndsWith("fan") && p.Age > 3);
            }

            Console.ReadKey();
        }

ORM:

    /// <summary>
    /// Dapper-style CRUD extension methods on <see cref="SqlConnection"/>.
    /// </summary>
    /// <remarks>
    /// Mapping conventions: the table name is the entity type name, every public instance property maps
    /// to the column of the same name, and the property named <c>ID</c> is the key (left out of INSERT and
    /// of the UPDATE SET list; used in the UPDATE/DELETE WHERE clause).
    /// Like Dapper, a connection passed in closed is opened for the call and closed again afterwards;
    /// a connection that is already open is left open.
    /// </remarks>
    public static class ORM
    {
        private const string ID_NAME = "ID";
        // Composite format strings instead of "@PLACEHOLDER" tokens: a chain of string.Replace calls could
        // rewrite text inserted by an earlier Replace (e.g. a parameter "@WHEREx" hit by Replace("@WHERE", ...)).
        private const string INSERT_SQL = "INSERT INTO {0}({1}) VALUES({2})";
        private const string INSERT_DEFAULT_SQL = "INSERT INTO {0} DEFAULT VALUES";
        private const string SELECT_SQL = "SELECT * FROM {0} WHERE {1}";
        private const string DELETE_SQL = "DELETE FROM {0} WHERE {1}";
        private const string UPDATE_SQL = "UPDATE {0} SET {1} WHERE {2}";
        private static readonly ConcurrentDictionary<Type, PropertyInfo[]> PROPERTIES_CACHE = new ConcurrentDictionary<Type, PropertyInfo[]>();

        /// <summary>
        /// Inserts <paramref name="entity"/> into the table named after <typeparamref name="T"/>.
        /// Every readable property except <c>ID</c> is written; null values are sent as SQL NULL.
        /// </summary>
        /// <returns>The affected-row count reported by the server.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public static int Insert<T>(this SqlConnection connection, T entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }
            var columns = GetColumnInfos(entity).Where(c => c.Name != ID_NAME).ToList();
            string tableName = Quote(typeof(T).Name);
            // "INSERT INTO t() VALUES()" is invalid, so an entity with only a key uses DEFAULT VALUES.
            string sql = columns.Count == 0
                ? string.Format(INSERT_DEFAULT_SQL, tableName)
                : string.Format(INSERT_SQL,
                    tableName,
                    string.Join(",", columns.Select(c => Quote(c.Name))),
                    string.Join(",", columns.Select(c => "@" + c.Name)));
            var paras = columns.Select(c => CreateParameter("@" + c.Name, c.Value)).ToArray();
            return ExecuteNonQuery(connection, sql, paras);
        }

        /// <summary>
        /// Inserts every entity of <paramref name="list"/>, one INSERT statement per entity.
        /// </summary>
        /// <returns>The sum of the affected-row counts.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
        public static int Insert<T>(this SqlConnection connection, List<T> list)
        {
            if (list == null)
            {
                throw new ArgumentNullException(nameof(list));
            }
            // Open once for the whole batch rather than once per entity.
            bool opened = OpenConnection(connection);
            try
            {
                int result = 0;
                foreach (var entity in list)
                {
                    result += connection.Insert(entity);
                }
                return result;
            }
            finally
            {
                CloseIfOpened(connection, opened);
            }
        }

        /// <summary>
        /// Returns the rows of the table named after <typeparamref name="T"/> that match <paramref name="whereExp"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="whereExp"/> is null.</exception>
        public static List<T> Select<T>(this SqlConnection connection, Expression<Func<T, bool>> whereExp) where T : new()
        {
            if (whereExp == null)
            {
                throw new ArgumentNullException(nameof(whereExp));
            }
            // WhereBuilder stores the lambda parameters of the expression it is translating in an instance
            // field, so a shared static instance would race between concurrent callers: use one per query.
            var wherePart = new WhereBuilder('[', ']').ToSql(whereExp);
            var paras = wherePart.Parameters.Select(p => CreateParameter("@" + p.Key, p.Value)).ToArray();
            string sql = string.Format(SELECT_SQL, Quote(typeof(T).Name), wherePart.Sql);

            var list = new List<T>();
            bool opened = OpenConnection(connection);
            try
            {
                using (var command = CreateCommand(connection, sql, paras))
                using (var reader = command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        list.Add(ReaderToEntity<T>(reader));
                    }
                }
            }
            finally
            {
                CloseIfOpened(connection, opened);
            }
            return list;
        }

        /// <summary>
        /// Deletes the row whose <c>ID</c> column equals <paramref name="ID"/>.
        /// </summary>
        /// <returns>The affected-row count reported by the server.</returns>
        public static int Delete<T>(this SqlConnection connection, int ID)
        {
            string sql = string.Format(DELETE_SQL, Quote(typeof(T).Name), $"{Quote(ID_NAME)}=@{ID_NAME}");
            var paras = new[] { CreateParameter("@" + ID_NAME, ID) };
            return ExecuteNonQuery(connection, sql, paras);
        }

        /// <summary>
        /// Deletes the row whose key equals the <c>ID</c> property of <paramref name="entity"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The entity type has no public <c>ID</c> property.</exception>
        public static int Delete<T>(this SqlConnection connection, T entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }
            var idProperty = entity.GetType().GetProperty(ID_NAME);
            if (idProperty == null)
            {
                throw new InvalidOperationException($"Type '{entity.GetType().Name}' has no public '{ID_NAME}' property.");
            }
            int id = (int)idProperty.GetValue(entity);
            return connection.Delete<T>(id);
        }

        /// <summary>
        /// Updates every non-key column of the row whose <c>ID</c> equals the entity's <c>ID</c> property.
        /// </summary>
        /// <returns>The affected-row count reported by the server.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public static int Update<T>(this SqlConnection connection, T entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }
            var columns = GetColumnInfos(entity);
            string setList = string.Join(",", columns
                .Where(c => c.Name != ID_NAME)
                .Select(c => $"{Quote(c.Name)}=@{c.Name}"));
            string sql = string.Format(UPDATE_SQL, Quote(typeof(T).Name), setList, $"{Quote(ID_NAME)}=@{ID_NAME}");
            // All columns are bound, including ID, which feeds the WHERE clause.
            var paras = columns.Select(c => CreateParameter("@" + c.Name, c.Value)).ToArray();
            return ExecuteNonQuery(connection, sql, paras);
        }

        /// <summary>
        /// Materializes the current row of <paramref name="reader"/> as a new <typeparamref name="T"/>,
        /// copying each column into the writable property of the same name.
        /// </summary>
        private static T ReaderToEntity<T>(SqlDataReader reader) where T : new()
        {
            // Held as object so that, when T is a struct, SetValue mutates this single boxed copy.
            object entity = new T();
            foreach (var property in GetPropertys<T>())
            {
                if (!property.CanWrite)
                {
                    continue;
                }
                var value = reader[property.Name];
                // SQL NULL arrives as DBNull, which no CLR property accepts; pass null instead
                // (reflection substitutes a zero-initialized value for value-type properties).
                property.SetValue(entity, value is DBNull ? null : value);
            }
            return (T)entity;
        }

        /// <summary>
        /// Returns (and caches per type) the public, non-indexed instance properties of <typeparamref name="T"/>.
        /// </summary>
        private static PropertyInfo[] GetPropertys<T>()
        {
            // The parameterless GetProperties() also returns static properties, and indexers cannot be
            // read without arguments; neither can map to a column.
            return PROPERTIES_CACHE.GetOrAdd(typeof(T), t => t
                .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                .Where(p => p.GetIndexParameters().Length == 0)
                .ToArray());
        }

        /// <summary>
        /// Snapshots name, property type FullName and current value of every readable mapped property of <paramref name="entity"/>.
        /// </summary>
        private static List<ColumnInfo> GetColumnInfos<T>(T entity)
        {
            return GetPropertys<T>()
                .Where(p => p.CanRead)
                .Select(p => new ColumnInfo(p.Name, p.PropertyType.FullName, p.GetValue(entity)))
                .ToList();
        }

        /// <summary>
        /// Maps a CLR type FullName to a <see cref="DbType"/>. Only String, Int32 and Decimal are mapped;
        /// anything else falls back to String. Not currently called by this class.
        /// </summary>
        private static DbType GetDbType(string typeName)
        {
            DbType type = DbType.String;
            switch (typeName)
            {
                case "System.String":
                    type = DbType.String; break;
                case "System.Int32":
                    type = DbType.Int32; break;
                case "System.Decimal":
                    type = DbType.Decimal; break;
                    // Extend with other types (Guid, DateTime, ...) as needed.
            }

            return type;
        }

        /// <summary>
        /// Opens <paramref name="connection"/> when it is closed.
        /// </summary>
        /// <returns>true when this call opened the connection, so the caller must close it again.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
        private static bool OpenConnection(IDbConnection connection)
        {
            if (connection == null)
            {
                throw new ArgumentNullException(nameof(connection));
            }
            if (connection.State == ConnectionState.Closed)
            {
                connection.Open();
                return true;
            }
            return false;
        }

        /// <summary>Closes <paramref name="connection"/> only if <see cref="OpenConnection"/> opened it.</summary>
        private static void CloseIfOpened(IDbConnection connection, bool opened)
        {
            if (opened)
            {
                connection.Close();
            }
        }

        /// <summary>
        /// Runs <paramref name="sql"/> as a non-query, opening/closing the connection as needed.
        /// </summary>
        private static int ExecuteNonQuery(SqlConnection connection, string sql, SqlParameter[] paras)
        {
            bool opened = OpenConnection(connection);
            try
            {
                using (var command = CreateCommand(connection, sql, paras))
                {
                    return command.ExecuteNonQuery();
                }
            }
            finally
            {
                CloseIfOpened(connection, opened);
            }
        }

        /// <summary>Creates a text command for <paramref name="sql"/> with <paramref name="paras"/> attached.</summary>
        private static SqlCommand CreateCommand(SqlConnection connection, string sql, SqlParameter[] paras)
        {
            var command = connection.CreateCommand();
            command.CommandType = CommandType.Text;
            command.CommandText = sql;
            command.Parameters.AddRange(paras);
            return command;
        }

        /// <summary>
        /// Creates a parameter, binding a null value as <see cref="DBNull.Value"/>.
        /// </summary>
        /// <remarks>
        /// A SqlParameter whose Value is a CLR null is treated as "not supplied" by SqlClient, so the
        /// statement would fail instead of writing NULL.
        /// </remarks>
        private static SqlParameter CreateParameter(string name, object value)
        {
            return new SqlParameter(name, value ?? DBNull.Value);
        }

        /// <summary>
        /// Wraps an identifier in [ ] (doubling any ']') so reserved words such as User or Order work as names.
        /// </summary>
        private static string Quote(string name) => "[" + name.Replace("]", "]]") + "]";
    }
    /// <summary>
    /// Snapshot of one mapped entity property: its name (used as the column name), its type name,
    /// and the value it held when the snapshot was taken.
    /// </summary>
    public class ColumnInfo
    {
        /// <summary>Property name, used verbatim as the column name.</summary>
        public string Name { get; set; }

        /// <summary>Type name of the property (the ORM passes the property type's FullName, e.g. "System.Int32").</summary>
        public string TypeName { get; set; }

        /// <summary>Property value at snapshot time; may be null.</summary>
        public object Value { get; set; }

        /// <summary>Creates a snapshot from the given name, type name and value.</summary>
        public ColumnInfo(string name, string typeName, object value)
        {
            Value = value;
            TypeName = typeName;
            Name = name;
        }
    }

WhereBuilder:将表达式树转换成 WHERE 子句(代码取自第三方)

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;

/// <summary>
/// 生成Where条件的SQL语句
/// Generating SQL from expression trees
/// </summary>
public class WhereBuilder
{
    // Method handles compared against in Recurse; looked up once instead of on every call.
    private static readonly MethodInfo StringContainsMethod = typeof(string).GetMethod("Contains", new[] { typeof(string) });
    private static readonly MethodInfo StringStartsWithMethod = typeof(string).GetMethod("StartsWith", new[] { typeof(string) });
    private static readonly MethodInfo StringEndsWithMethod = typeof(string).GetMethod("EndsWith", new[] { typeof(string) });

    private readonly char _columnBeginChar = '[';
    private readonly char _columnEndChar = ']';
    // Lambda parameters of the expression being translated. A member accessed on one of these is an
    // entity column; anything else is evaluated and bound as a parameter value.
    private System.Collections.ObjectModel.ReadOnlyCollection<ParameterExpression> expressParameterNameCollection;

    /// <summary>
    /// Quotes column names with the same character on both sides, e.g. '`' for MySQL.
    /// </summary>
    /// <remarks>
    /// This parameter deliberately has no default value. With one, <c>new WhereBuilder()</c> matched both
    /// constructors equally and did not compile. A parameterless call now selects the '[' / ']' constructor.
    /// </remarks>
    public WhereBuilder(char columnChar)
    {
        this._columnBeginChar = this._columnEndChar = columnChar;
    }

    /// <summary>
    /// Quotes column names with separate opening/closing characters, e.g. '[' and ']' for SQL Server.
    /// </summary>
    public WhereBuilder(char columnBeginChar = '[', char columnEndChar = ']')
    {
        this._columnBeginChar = columnBeginChar;
        this._columnEndChar = columnEndChar;
    }

    /// <summary>
    /// Translates a LINQ predicate into SQL, numbering parameters from 1.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public WherePart ToSql<T>(Expression<Func<T, bool>> expression)
    {
        var i = 1;
        return ToSql(ref i, expression);
    }

    /// <summary>
    /// Translates a LINQ predicate into SQL.
    /// </summary>
    /// <param name="i">Seed: the number of the next parameter. It is advanced past every parameter emitted.</param>
    /// <param name="expression">The predicate to translate.</param>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    /// <exception cref="NotSupportedException">Thrown for expression, node or method kinds this builder does not handle.</exception>
    public WherePart ToSql<T>(ref int i, Expression<Func<T, bool>> expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }
        if (expression.Parameters.Count > 0)
        {
            this.expressParameterNameCollection = expression.Parameters;
        }
        return Recurse(ref i, expression.Body, isUnary: true);
    }

    /// <summary>
    /// Recursively translates one node of the expression tree.
    /// </summary>
    /// <param name="i">Seed: the number of the next parameter.</param>
    /// <param name="expression">The node to translate.</param>
    /// <param name="isUnary">True when the node stands alone as a predicate. A bare boolean then becomes "= true/1".</param>
    /// <param name="prefix">Text prepended to string values (LIKE wildcard).</param>
    /// <param name="postfix">Text appended to string values (LIKE wildcard).</param>
    private WherePart Recurse(ref int i, Expression expression, bool isUnary = false, string prefix = null, string postfix = null)
    {
        // Unary operators: conversions, NOT, negation.
        if (expression is UnaryExpression unary)
        {
            if (unary.NodeType == ExpressionType.Convert)
            {
                // A conversion applied to an entity property, e.g. the lifted (int?)m.Age in m.Age == nullableVar
                // or (int)m.Status for enums. Translate the property itself: evaluating it would fail because
                // the lambda parameter is unbound outside the predicate.
                bool operandUsesParameter = false;
                this.IsContainsParameterExpress(unary.Operand, ref operandUsesParameter);
                if (operandUsesParameter)
                {
                    return Recurse(ref i, unary.Operand, isUnary, prefix, postfix);
                }
                // A conversion of a plain value, e.g. m.Birthday == DateTime.Now.
                return WherePart.IsParameter(i++, Wrap(GetValue(expression), prefix, postfix));
            }
            // e.g. !m.IsDeleted
            return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, true));
        }

        if (expression is BinaryExpression binary)
        {
            // Comparing with a null literal needs IS [NOT] NULL, because "col = NULL" is never true in SQL.
            if (binary.NodeType == ExpressionType.Equal || binary.NodeType == ExpressionType.NotEqual)
            {
                string nullOperator = binary.NodeType == ExpressionType.Equal ? "IS" : "IS NOT";
                if (IsNullConstant(binary.Right))
                {
                    return WherePart.Concat(Recurse(ref i, binary.Left), nullOperator, WherePart.IsSql("NULL"));
                }
                if (IsNullConstant(binary.Left))
                {
                    return WherePart.Concat(Recurse(ref i, binary.Right), nullOperator, WherePart.IsSql("NULL"));
                }
            }
            // Each operand of AND/OR is a predicate in its own right. Translating it with isUnary turns a bare
            // boolean property (m.IsActive && ...) into "([IsActive] = @n)" instead of an invalid "[IsActive]".
            bool isLogical = binary.NodeType == ExpressionType.AndAlso || binary.NodeType == ExpressionType.OrElse;
            return WherePart.Concat(Recurse(ref i, binary.Left, isLogical), NodeTypeToString(binary.NodeType), Recurse(ref i, binary.Right, isLogical));
        }

        // Constants, e.g. the right-hand side of m.ID == 123.
        if (expression is ConstantExpression constant)
        {
            var value = constant.Value;
            if (value is int)
            {
                // Integers are inlined. The invariant culture guarantees an ASCII '-' for negative values.
                return WherePart.IsSql(((int)value).ToString(System.Globalization.CultureInfo.InvariantCulture));
            }
            value = Wrap(value, prefix, postfix);
            if (value is bool && isUnary)
            {
                return WherePart.Concat(WherePart.IsParameter(i++, value), "=", WherePart.IsSql("1"));
            }
            return WherePart.IsParameter(i++, value);
        }

        // Member access: an entity property becomes a column; anything else is evaluated to a value.
        if (expression is MemberExpression member)
        {
            bool isContainsParameterExpress = false;
            this.IsContainsParameterExpress(member, ref isContainsParameterExpress);
            if (member.Member is PropertyInfo propertyInfo && isContainsParameterExpress)
            {
                var colName = propertyInfo.Name;
                if (isUnary && member.Type == typeof(bool))
                {
                    return WherePart.Concat(Recurse(ref i, expression), "=", WherePart.IsParameter(i++, true));
                }
                return WherePart.IsSql(string.Format("{0}{1}{2}", this._columnBeginChar, colName, this._columnEndChar));
            }
            if (member.Member is FieldInfo || !isContainsParameterExpress)
            {
                return WherePart.IsParameter(i++, Wrap(GetValue(member), prefix, postfix));
            }
            throw new NotSupportedException($"Expression does not refer to a property or field: {expression}");
        }

        // Method calls.
        if (expression is MethodCallExpression methodCall)
        {
            // Does the call involve the lambda parameter, through its target object or its arguments?
            bool isContainsParameterExpress = false;
            this.IsContainsParameterExpress(methodCall, ref isContainsParameterExpress);
            if (!isContainsParameterExpress)
            {
                // Independent of the entity: evaluate now and bind the result as a parameter.
                return WherePart.IsParameter(i++, Wrap(GetValue(expression), prefix, postfix));
            }

            // LIKE queries:
            if (methodCall.Method == StringContainsMethod)
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%", postfix: "%"));
            }
            if (methodCall.Method == StringStartsWithMethod)
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], postfix: "%"));
            }
            if (methodCall.Method == StringEndsWithMethod)
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%"));
            }

            // IN queries: ids.Contains(m.ID) as an extension method (Enumerable.Contains) or an instance method (List<T>.Contains).
            if (methodCall.Method.Name == "Contains")
            {
                Expression collection;
                Expression item;
                bool isExtension = methodCall.Method.IsDefined(typeof(ExtensionAttribute));
                if (isExtension && methodCall.Arguments.Count == 2)
                {
                    collection = methodCall.Arguments[0];
                    item = methodCall.Arguments[1];
                }
                else if (!isExtension && methodCall.Arguments.Count == 1)
                {
                    collection = methodCall.Object;
                    item = methodCall.Arguments[0];
                }
                else
                {
                    throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
                }
                var values = (IEnumerable)GetValue(collection);
                return WherePart.Concat(Recurse(ref i, item), "IN", WherePart.IsCollection(ref i, values));
            }

            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        // Object construction, e.g. new DateTime(2018, 10, 31): evaluate and bind.
        if (expression is NewExpression newExpression)
        {
            return WherePart.IsParameter(i++, Wrap(GetValue(newExpression), prefix, postfix));
        }

        throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
    }

    /// <summary>
    /// Applies the LIKE wildcard prefix/postfix to string values; other values are returned unchanged.
    /// </summary>
    private static object Wrap(object value, string prefix, string postfix)
    {
        return value is string text ? prefix + text + postfix : value;
    }

    /// <summary>
    /// True when <paramref name="expression"/> is a null constant, possibly wrapped in conversions.
    /// </summary>
    private static bool IsNullConstant(Expression expression)
    {
        while (expression is UnaryExpression conversion && conversion.NodeType == ExpressionType.Convert)
        {
            expression = conversion.Operand;
        }
        return expression is ConstantExpression constant && constant.Value == null;
    }

    /// <summary>
    /// Sets <paramref name="result"/> to true when <paramref name="expression"/> refers to one of the
    /// lambda parameters of the expression being translated (the entity variable "m").
    /// </summary>
    /// <remarks>
    /// Inspects a member accessed directly on the parameter, a method call whose target is such a member
    /// or another such call, and the call's arguments. <paramref name="result"/> is only ever set to true,
    /// never reset.
    /// </remarks>
    private void IsContainsParameterExpress(Expression expression, ref bool result)
    {
        if (this.expressParameterNameCollection != null && this.expressParameterNameCollection.Count > 0 && expression != null)
        {
            if (expression is MemberExpression)
            {
                if (this.expressParameterNameCollection.Contains(((MemberExpression)expression).Expression))
                {
                    result = true;
                }
            }
            else if (expression is MethodCallExpression)
            {
                MethodCallExpression methodCallExpression = (MethodCallExpression)expression;

                if (methodCallExpression.Object != null)
                {
                    if (methodCallExpression.Object is MethodCallExpression)
                    {
                        // Example 1: m.ID.ToString().Contains("123")
                        this.IsContainsParameterExpress(methodCallExpression.Object, ref result);
                    }
                    else if (methodCallExpression.Object is MemberExpression)
                    {
                        // Example 2: m.ID.Contains(123)
                        MemberExpression objectMember = (MemberExpression)methodCallExpression.Object;
                        if (objectMember.Expression != null && this.expressParameterNameCollection.Contains(objectMember.Expression))
                        {
                            result = true;
                        }
                    }
                }
                // Example 3: int[] ids = new int[] { 1, 2, 3 }; ids.Contains(m.ID)
                if (result == false && methodCallExpression.Arguments != null && methodCallExpression.Arguments.Count > 0)
                {
                    foreach (Expression express in methodCallExpression.Arguments)
                    {
                        if (express is MemberExpression || express is MethodCallExpression)
                        {
                            this.IsContainsParameterExpress(express, ref result);
                        }
                        else if (this.expressParameterNameCollection.Contains(express))
                        {
                            result = true;
                            break;
                        }
                    }
                }
            }
        }
    }

    /// <summary>
    /// Evaluates a sub-expression that does not reference the lambda parameter by compiling it into a Func&lt;object&gt;.
    /// </summary>
    private static object GetValue(Expression member)
    {
        // source: http://*.com/a/2616980/291955 (host mangled in the original text)
        var objectMember = Expression.Convert(member, typeof(object));
        var getterLambda = Expression.Lambda<Func<object>>(objectMember);
        var getter = getterLambda.Compile();
        return getter();
    }

    /// <summary>
    /// Maps an expression node type to its SQL operator text.
    /// </summary>
    /// <exception cref="NotSupportedException">The node type has no SQL equivalent here.</exception>
    private static string NodeTypeToString(ExpressionType nodeType)
    {
        switch (nodeType)
        {
            case ExpressionType.Add:
                return "+";
            case ExpressionType.And:
                return "&";
            case ExpressionType.AndAlso:
                return "AND";
            case ExpressionType.Divide:
                return "/";
            case ExpressionType.Equal:
                return "=";
            case ExpressionType.ExclusiveOr:
                return "^";
            case ExpressionType.GreaterThan:
                return ">";
            case ExpressionType.GreaterThanOrEqual:
                return ">=";
            case ExpressionType.LessThan:
                return "<";
            case ExpressionType.LessThanOrEqual:
                return "<=";
            case ExpressionType.Modulo:
                return "%";
            case ExpressionType.Multiply:
                return "*";
            case ExpressionType.Negate:
                return "-";
            case ExpressionType.Not:
                return "NOT";
            case ExpressionType.NotEqual:
                return "<>";
            case ExpressionType.Or:
                return "|";
            case ExpressionType.OrElse:
                return "OR";
            case ExpressionType.Subtract:
                return "-";
        }
        throw new NotSupportedException($"Unsupported node type: {nodeType}");
    }
}

/// <summary>
/// One fragment of a WHERE clause: SQL text that may reference numbered placeholders (@1, @2, ...),
/// together with the values bound to those placeholders.
/// </summary>
public class WherePart
{
    /// <summary>
    /// SQL text of the fragment; parameters appear as @n placeholders.
    /// </summary>
    public string Sql { get; set; }

    /// <summary>
    /// Values referenced by <see cref="Sql"/>, keyed by placeholder number without the '@' ("1", "2", ...).
    /// </summary>
    public Dictionary<string, object> Parameters { get; set; } = new Dictionary<string, object>();

    /// <summary>
    /// Wraps literal SQL that binds no parameters.
    /// </summary>
    public static WherePart IsSql(string sql) => new WherePart { Sql = sql };

    /// <summary>
    /// Creates the single placeholder @<paramref name="count"/> bound to <paramref name="value"/>.
    /// </summary>
    public static WherePart IsParameter(int count, object value)
    {
        string key = count.ToString();
        var part = new WherePart { Sql = "@" + key };
        part.Parameters.Add(key, value);
        return part;
    }

    /// <summary>
    /// Builds a parenthesized placeholder list "(@n,@n+1,...)" for an IN clause, one placeholder per
    /// element of <paramref name="values"/>, and advances <paramref name="countStart"/> past them.
    /// </summary>
    public static WherePart IsCollection(ref int countStart, IEnumerable values)
    {
        var bound = new Dictionary<string, object>();
        var placeholders = new List<string>();
        foreach (var item in values)
        {
            string key = countStart.ToString();
            bound.Add(key, item);
            placeholders.Add("@" + key);
            countStart++;
        }
        // An empty collection becomes "(null)": "()" is not valid SQL, and IN (null) matches no row.
        string inner = placeholders.Count == 0 ? "null" : string.Join(",", placeholders);
        return new WherePart
        {
            Parameters = bound,
            Sql = "(" + inner + ")"
        };
    }

    /// <summary>
    /// Prefixes a unary operator: "(op operand)". The operand's parameter dictionary is reused as-is.
    /// </summary>
    public static WherePart Concat(string @operator, WherePart operand)
    {
        return new WherePart
        {
            Parameters = operand.Parameters,
            Sql = "(" + @operator + " " + operand.Sql + ")"
        };
    }

    /// <summary>
    /// Joins two fragments with a binary operator, "(left op right)", merging both parameter sets.
    /// </summary>
    public static WherePart Concat(WherePart left, string @operator, WherePart right)
    {
        var merged = Enumerable
            .Union(left.Parameters, right.Parameters)
            .ToDictionary(pair => pair.Key, pair => pair.Value);
        return new WherePart
        {
            Parameters = merged,
            Sql = "(" + left.Sql + " " + @operator + " " + right.Sql + ")"
        };
    }
}