Last year I posted an article about modifying the SQL command text that LINQ to SQL (L2S) generates. It was slightly evil: it called private methods inside L2S, and it did so through reflection.

I have an alternative version that does the same thing with a pre-compiled expression tree instead, which I’ll post soon. Until then, here’s an update to the original reflection-based interceptor that now handles sub-queries properly:

    /// <summary>
    /// Rewrites one SQL command generated by LINQ to SQL before it is executed.
    /// </summary>
    /// <param name="commandText">The generated SQL command text.</param>
    /// <param name="parameters">
    /// The command's parameters, keyed by parameter name, with their current values.
    /// This is a read-only snapshot for inspection: changes made to it are not applied to the command.
    /// </param>
    /// <returns>The SQL text to execute in place of <paramref name="commandText"/>.</returns>
    public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

    /// <summary>
    /// Hooks a LINQ to SQL <see cref="DataContext"/> so that the SQL command text it generates,
    /// including the text of compiled sub-queries, is passed through a
    /// <see cref="ModifyCommandDelegate"/> before it is used.
    /// </summary>
    /// <remarks>
    /// The context's private <c>provider</c> field is replaced with a remoting transparent proxy.
    /// <c>Execute</c> and <c>Compile</c> calls are then re-implemented here on top of the provider's
    /// private members, reached by name through reflection. Everything else is forwarded unchanged.
    /// This relies on internal member names of System.Data.Linq and may break if they change.
    /// </remarks>
    public class ReflectionDataContextInterceptor
    {
        private const BindingFlags AnyInstance = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
        private const BindingFlags AnyStatic = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;

        // Numeric values of the internal ResultShape enum that this class handles.
        private const int ResultShapeSingleton = 1;
        private const int ResultShapeSequence = 2;

        // The System.Data.Linq assembly, which hosts the internal types resolved by name below.
        private static readonly Assembly SqlAssembly = typeof( SqlProvider ).Assembly;

        // DataContext's private provider field; swapped for the proxy and later restored.
        private static readonly FieldInfo ProviderField =
            typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );

        private readonly DataContext dc;
        private readonly object oldProvider;
        private readonly ModifyCommandDelegate modifyCommand;

        // Set by ProviderProxy to the intercepted provider interface; not read elsewhere in this class.
        private Type providerType;

        /// <summary>
        /// Installs an interceptor on <paramref name="dc"/> and returns the same context, for fluent use.
        /// </summary>
        /// <param name="dc">The context to intercept.</param>
        /// <param name="modifyCommand">Callback that rewrites each generated SQL command.</param>
        /// <returns><paramref name="dc"/>.</returns>
        public static TDataContext Intercept<TDataContext>( TDataContext dc, ModifyCommandDelegate modifyCommand )
            where TDataContext : DataContext
        {
            // The constructor installs the proxy on dc; the proxy keeps this instance reachable.
            new ReflectionDataContextInterceptor( dc, modifyCommand );

            return dc;
        }

        /// <summary>
        /// Replaces the provider of <paramref name="dc"/> with an intercepting proxy.
        /// If the context is already intercepted, it is left unchanged and this instance stays inactive.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        public ReflectionDataContextInterceptor( DataContext dc, ModifyCommandDelegate modifyCommand )
        {
            if ( dc == null ) throw new ArgumentNullException( "dc" );
            if ( modifyCommand == null ) throw new ArgumentNullException( "modifyCommand" );

            this.dc = dc;
            this.modifyCommand = modifyCommand;

            var existingProvider = ProviderField.GetValue( dc );

            // Avoid wrapping a proxy in another proxy.
            if ( existingProvider is IProviderProxy )
            {
                return;
            }

            oldProvider = existingProvider;

            var proxy = new ProviderProxy( this, oldProvider ).GetTransparentProxy();

            ProviderField.SetValue( dc, proxy );
        }

        /// <summary>
        /// Finds an instance method (public or not) named <paramref name="methodName"/> on
        /// <paramref name="type"/> that can be called with <paramref name="args"/>.
        /// </summary>
        /// <returns>The matching method, or <c>null</c> if none was found.</returns>
        /// <remarks>
        /// When no argument is null, the arguments' runtime types select the overload.
        /// A null argument has no runtime type. In that case the overloads with that name are filtered
        /// by parameter count and argument compatibility. If none is compatible, the first method with
        /// that name is returned.
        /// </remarks>
        public static MethodInfo GetMethod( Type type, string methodName, params object[] args )
        {
            if ( !args.Any( a => a == null ) )
            {
                return type.GetMethod(
                    methodName,
                    AnyInstance,
                    null,
                    args.Select( a => a.GetType() ).ToArray(),
                    null );
            }

            var candidates = type.GetMethods( AnyInstance )
                .Where( m => m.Name == methodName )
                .ToList();

            return candidates.FirstOrDefault( m => IsCallableWith( m, args ) )
                ?? candidates.FirstOrDefault();
        }

        // True when the arity matches and every argument could be passed to its parameter
        // (null only to reference or Nullable<T> parameters).
        private static bool IsCallableWith( MethodInfo method, object[] args )
        {
            var parameters = method.GetParameters();

            if ( parameters.Length != args.Length ) return false;

            for ( int i = 0; i < args.Length; i++ )
            {
                var parameterType = parameters[ i ].ParameterType;

                if ( parameterType.IsByRef ) parameterType = parameterType.GetElementType();

                var accepts = ( args[ i ] == null )
                    ? !parameterType.IsValueType || Nullable.GetUnderlyingType( parameterType ) != null
                    : parameterType.IsInstanceOfType( args[ i ] );

                if ( !accepts ) return false;
            }

            return true;
        }

        // Calls an instance method on instance by name. Returns null, without calling anything,
        // when no such method is found.
        private object Invoke( object instance, string methodName, params object[] args )
        {
            var method = GetMethod( instance.GetType(), methodName, args );

            return ( method != null )
                ? method.Invoke( instance, args )
                : null;
        }

        // Calls a static method (public or not) of type by name.
        private object Invoke( Type type, string methodName, params object[] args )
        {
            var method = type.GetMethod( methodName, AnyStatic );

            if ( method == null )
            {
                throw new MissingMethodException( type.FullName, methodName );
            }

            return method.Invoke( null, args );
        }

        /// <summary>
        /// Replacement for the provider's <c>Compile</c>. It follows <c>SqlProvider.Compile</c>
        /// (minus its dispose, initialization and null-argument checks), rewriting the SQL of the built
        /// queries and their sub-queries, and returns a new internal <c>SqlProvider+CompiledQuery</c>.
        /// </summary>
        /// <remarks>
        /// Afterwards the original provider is put back on the DataContext, so that context is no
        /// longer intercepted.
        /// </remarks>
        protected object CompileImpl( Expression query )
        {
            try
            {
                Invoke( oldProvider, "InitializeProviderMode" );

                object info;
                var queries = BuildModifiedQueries( query, out info );

                query = UnwrapLambda( query );

                object subQueries;
                var readerFactory = CreateReaderFactory( info, out subQueries );

                ProviderField.SetValue( dc, oldProvider );

                var compiledQueryType = SqlAssembly.GetType( "System.Data.Linq.SqlClient.SqlProvider+CompiledQuery", true );

                return Activator.CreateInstance(
                    compiledQueryType,
                    AnyInstance,
                    null,
                    new object[] { oldProvider, query, queries, readerFactory, subQueries },
                    null );
            }
            catch ( TargetInvocationException ex )
            {
                // Surface the provider's own exception rather than the reflection wrapper.
                if ( ex.InnerException == null ) throw;
                throw ex.InnerException;
            }
        }

        /// <summary>
        /// Replacement for the provider's <c>Execute</c>. It follows <c>SqlProvider.Execute</c>
        /// (minus its dispose, initialization, deletion and null-argument checks), rewriting the SQL
        /// of the built queries and their sub-queries before running them with the provider's
        /// <c>ExecuteAll</c>.
        /// </summary>
        protected internal virtual IExecuteResult ExecuteImpl( Expression query )
        {
            try
            {
                Invoke( oldProvider, "InitializeProviderMode" );

                var funcletizerType = SqlAssembly.GetType( "System.Data.Linq.SqlClient.Funcletizer", true );
                query = (Expression)Invoke( funcletizerType, "Funcletize", query );

                // A result from the provider's cache is returned as-is; no SQL is built or modified for it.
                if ( (bool)Invoke( oldProvider, "get_EnableCacheLookup" ) )
                {
                    var cachedResult = (IExecuteResult)Invoke( oldProvider, "GetCachedResult", query );

                    if ( cachedResult != null )
                    {
                        return cachedResult;
                    }
                }

                object info;
                var queries = BuildModifiedQueries( query, out info );

                query = UnwrapLambda( query );

                object subQueries;
                var readerFactory = CreateReaderFactory( info, out subQueries );

                return (IExecuteResult)Invoke( oldProvider, "ExecuteAll", query, queries, readerFactory, null, subQueries );
            }
            catch ( TargetInvocationException ex )
            {
                // Surface the provider's own exception rather than the reflection wrapper.
                if ( ex.InnerException == null ) throw;
                throw ex.InnerException;
            }
        }

        // Builds the provider's query list for query and rewrites each command's text.
        // Then runs the provider's CheckSqlCompatibility on the result.
        // Returns the query list; info receives its last element.
        private object BuildModifiedQueries( Expression query, out object info )
        {
            var annotationsType = SqlAssembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations", true );
            var annotations = Activator.CreateInstance( annotationsType );

            var queries = Invoke( oldProvider, "BuildQuery", query, annotations );

            info = ModifyQueries( (IEnumerable)queries );

            Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

            return queries;
        }

        // Returns the lambda's body when query is a LambdaExpression, otherwise query itself.
        private static Expression UnwrapLambda( Expression query )
        {
            var lambda = query as LambdaExpression;

            return ( lambda != null ) ? lambda.Body : query;
        }

        // For a Singleton or Sequence result shape, compiles and rewrites the sub-queries of
        // info.Query and returns the provider's reader factory. For any other shape, both the
        // return value and subQueries are null.
        private object CreateReaderFactory( object info, out object subQueries )
        {
            subQueries = null;

            // ResultShape is an int-backed enum, so the boxed value unboxes to int.
            var resultShape = (int)Invoke( info, "get_ResultShape" );

            if ( resultShape != ResultShapeSingleton && resultShape != ResultShapeSequence )
            {
                return null;
            }

            var sqlQuery = Invoke( info, "get_Query" );

            subQueries = Invoke( oldProvider, "CompileSubQueries", sqlQuery );

            ModifySubQueries( (IEnumerable)subQueries );

            var readerType = Invoke( info, "get_ResultType" );

            if ( resultShape == ResultShapeSequence )
            {
                // For a sequence, the factory is built for the element type of the result type.
                var typeSystemType = SqlAssembly.GetType( "System.Data.Linq.SqlClient.TypeSystem", true );
                readerType = Invoke( typeSystemType, "GetElementType", readerType );
            }

            return Invoke( oldProvider, "GetReaderFactory", sqlQuery, readerType );
        }

        // Rewrites every query in queries and returns the last one (null when empty).
        private object ModifyQueries( IEnumerable queries )
        {
            object lastQuery = null;

            foreach ( var q in queries )
            {
                lastQuery = q;

                ModifyQuery( q );
            }

            return lastQuery;
        }

        // Recursively rewrites the query of each compiled sub-query and of its nested sub-queries.
        private void ModifySubQueries( IEnumerable subQueries )
        {
            if ( subQueries == null ) return;

            foreach ( var sq in subQueries )
            {
                var subQueryType = sq.GetType();

                var queryInfo = subQueryType.GetField( "queryInfo", AnyInstance ).GetValue( sq );
                ModifyQuery( queryInfo );

                var nestedSubQueries = subQueryType.GetField( "subQueries", AnyInstance ).GetValue( sq );
                ModifySubQueries( (IEnumerable)nestedSubQueries );
            }
        }

        // Passes one QueryInfo's command text and parameter values to modifyCommand and stores
        // the returned text back into it.
        private void ModifyQuery( object q /* QueryInfo */ )
        {
            var queryType = q.GetType();
            var commandTextField = queryType.GetField( "commandText", AnyInstance );
            var parametersField = queryType.GetField( "parameters", AnyInstance );

            var commandText = (string)commandTextField.GetValue( q );
            var parameters = new Dictionary<string, object>();

            var parameterInfos = (IEnumerable)parametersField.GetValue( q );

            if ( parameterInfos != null )
            {
                foreach ( var p in parameterInfos )
                {
                    var param = Invoke( p, "get_Parameter" );
                    var name = (string)Invoke( param, "get_Name" );

                    parameters[ name ] = Invoke( p, "get_Value" );
                }
            }

            // Only the command text is written back; changes the delegate makes to the dictionary are discarded.
            commandTextField.SetValue( q, modifyCommand( commandText, parameters ) );
        }

        // Not enabled: without a matching "<Name>Impl" method, the proxy forwards GetCommand and
        // GetQueryText to the original provider unchanged.
/*
        protected internal DbCommand GetCommandImpl( Expression query )
        {
            return (DbCommand)Invoke( oldProvider, "GetCommand", query );
        }

        protected internal string GetQueryTextImpl( Expression query )
        {
            return (string)Invoke( oldProvider, "GetQueryText", query );
        }
*/
        // Marker used to detect a DataContext whose provider is already this proxy.
        internal interface IProviderProxy
        {
            ReflectionDataContextInterceptor Interceptor { get; }
            object OldProvider { get; }
        }

        /// <summary>
        /// Remoting proxy that stands in for the DataContext's provider. An <c>IProvider</c> call
        /// with a matching <c>&lt;Name&gt;Impl</c> method on the interceptor is handled there.
        /// All other calls are forwarded to the original provider.
        /// </summary>
        public class ProviderProxy : RealProxy, IRemotingTypeInfo, IProviderProxy
        {
            public ReflectionDataContextInterceptor Interceptor { get; private set; }
            public object OldProvider { get; private set; }

            internal ProviderProxy( ReflectionDataContextInterceptor extender, object oldProvider )
                : base( typeof( ContextBoundObject ) )
            {
                this.Interceptor = extender;
                this.OldProvider = oldProvider;
            }

            /// <summary>Dispatches a call made through the transparent proxy.</summary>
            /// <exception cref="NotImplementedException">No target method could be found for the call.</exception>
            public override IMessage Invoke( IMessage msg )
            {
                var call = msg as IMethodCallMessage;

                if ( call == null )
                {
                    throw new NotImplementedException();
                }

                var declaringType = call.MethodBase.DeclaringType;
                var calledMethod = call.MethodBase as MethodInfo;

                // The marker interface is implemented by this RealProxy, not by the provider.
                if ( declaringType == typeof( IProviderProxy ) && calledMethod != null )
                {
                    return InvokeOn( this, calledMethod, call );
                }

                if ( declaringType.Name == "IProvider" && declaringType.IsInterface )
                {
                    Interceptor.providerType = declaringType;

                    var impl = ReflectionDataContextInterceptor.GetMethod( typeof( ReflectionDataContextInterceptor ), call.MethodBase.Name + "Impl", call.Args );

                    if ( impl != null )
                    {
                        return InvokeOn( Interceptor, impl, call );
                    }

                    if ( OldProvider != null )
                    {
                        var forward = ReflectionDataContextInterceptor.GetMethod( declaringType, call.MethodBase.Name, call.Args )
                            ?? calledMethod;

                        if ( forward != null )
                        {
                            return InvokeOn( OldProvider, forward, call );
                        }
                    }
                }
                else if ( OldProvider != null )
                {
                    var forward = ReflectionDataContextInterceptor.GetMethod( OldProvider.GetType(), call.MethodBase.Name, call.Args );

                    // A by-name lookup misses explicit interface implementations. The called method
                    // itself can be invoked on any instance of its declaring type.
                    if ( forward == null && declaringType.IsInstanceOfType( OldProvider ) )
                    {
                        forward = calledMethod;
                    }

                    if ( forward != null )
                    {
                        return InvokeOn( OldProvider, forward, call );
                    }
                }

                throw new NotImplementedException(
                    string.Format( "Method not found: {0}( {1} )",
                        call.MethodBase.Name,
                        string.Join( ", ", call.Args.Select( a => Convert.ToString( a ) ) ) ) );
            }

            // Invokes method on target with the call's arguments. The result, or the exception the
            // method threw, is packaged as the remoting return message.
            private static IMessage InvokeOn( object target, MethodInfo method, IMethodCallMessage call )
            {
                try
                {
                    return new ReturnMessage( method.Invoke( target, call.Args ), null, 0, null, call );
                }
                catch ( TargetInvocationException e )
                {
                    return new ReturnMessage( e.InnerException, call );
                }
            }

            // Reports every cast as valid, so the transparent proxy can be assigned to
            // DataContext.provider and cast to whatever interface the caller asks for.
            public bool CanCastTo( Type fromType, object o )
            {
                return true;
            }

            // Required by IRemotingTypeInfo; assignments are ignored.
            public string TypeName
            {
                get { return this.GetType().Name; }
                set { }
            }
        }
    }

It’s still evil, but it works great :)

About these ads