Modifying LINQ To SQL command text

UPDATE 2 – An improved version that correctly handles sub-queries can be found here.

UPDATE – For an example use of this code, see this post, where I discuss supporting native XQuery using LINQ to SQL.

HACK WARNING – This example calls private framework methods and overwrites private fields through Reflection.  It’s bad practice in almost every way, and it’s unlikely to work in a partial trust environment.  Use at your own risk :)

Here’s a bit of code that’ll let you modify any command text LINQ To SQL generates before it hits the database.  It’s always been possible to call DataContext.GetCommand, edit the SQL, and then either execute the command yourself and materialize the reader with DataContext.Translate, or pass the edited SQL to DataContext.ExecuteQuery.  That approach is pretty restrictive in the complexity of query you can use, though.  Also, if LINQ To SQL decides it needs to batch multiple SELECT statements to produce the desired results, GetCommand only gives you the first one.

The first thing to do is replace the IProvider implementation used by your DataContext.  Unfortunately, DataContext’s Provider property is both internal and read-only, and the IProvider interface itself is internal too, so neither is usable outside System.Data.Linq.  This is bad in every way, and we absolutely shouldn’t try to hack our way through that… but we will, by using reflection to overwrite DataContext’s private provider field directly.
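
Stripped of the LINQ To SQL specifics, the provider swap is plain reflection: look up a non-public instance field by name, read its current value, and write a wrapper back in its place. The sketch below shows only that move, against hypothetical stand-in types (Context for DataContext, IWorker for IProvider, RealWorker for SqlProvider, WrappingWorker for the proxy); only the field name provider is taken from the interceptor code.

```csharp
using System;
using System.Reflection;

// Hypothetical stand-in for the internal IProvider interface.
internal interface IWorker
{
    string Run( string input );
}

// Hypothetical stand-in for SqlProvider.
internal sealed class RealWorker : IWorker
{
    public string Run( string input ) { return "real:" + input; }
}

// Hypothetical stand-in for DataContext: the implementation lives in a private
// field and is exposed only through a non-public, get-only property.
public class Context
{
    private IWorker provider = new RealWorker();

    internal IWorker Provider { get { return provider; } }

    public string Run( string input ) { return Provider.Run( input ); }
}

// Hypothetical wrapper playing the role of the proxy: it delegates to the
// original implementation and decorates the result.
internal sealed class WrappingWorker : IWorker
{
    private readonly IWorker inner;

    public WrappingWorker( IWorker inner ) { this.inner = inner; }

    public string Run( string input ) { return "[wrapped] " + inner.Run( input ); }
}

public static class FieldSwap
{
    // Overwrites Context's private "provider" field with a wrapper around the
    // current value and returns the original, so it can be restored later.
    public static object ReplaceProvider( Context context )
    {
        FieldInfo field = typeof( Context ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
        object original = field.GetValue( context );

        field.SetValue( context, new WrappingWorker( (IWorker)original ) );

        return original;
    }
}
```

Keeping the original value around matters: the wrapper still needs it to do the real work, just as the interceptor keeps oldProvider.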

Inspired by a 2008 post by Matt Warren, I made the following DataContextInterceptor helper.  It swaps a RealProxy-based transparent proxy in for the provider and intercepts calls to the IProvider’s Compile and Execute methods.  Rather than rewrite most of the SqlProvider code, it calls the same private SqlProvider methods that Compile and Execute normally would, albeit through reflection.  The only other thing it does is call your ModifyCommandDelegate with each generated command’s text and parameters, and use the string it returns as the new command text.
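
The ProviderProxy class below relies on RealProxy, which exists only on the .NET Framework. If you want to experiment with the same intercept-and-forward pattern on current .NET, System.Reflection.DispatchProxy is the nearest equivalent; note that this is a different API from the one the interceptor uses, and it can only proxy interfaces. The sketch uses hypothetical types (IQueryRunner standing in for IProvider, RealRunner for SqlProvider) and rewrites the argument of Execute before forwarding the call.

```csharp
using System;
using System.Reflection;
using System.Runtime.ExceptionServices;

// Hypothetical interface standing in for IProvider.
public interface IQueryRunner
{
    string Execute( string commandText );
    int Version();
}

// Hypothetical implementation standing in for SqlProvider.
public class RealRunner : IQueryRunner
{
    public string Execute( string commandText ) { return "ran: " + commandText; }
    public int Version() { return 1; }
}

// Forwards every interface call to the wrapped instance, rewriting the
// command text passed to Execute on the way through.
public class InterceptingProxy : DispatchProxy
{
    private IQueryRunner target;
    private Func<string, string> modifyCommand;

    public static IQueryRunner Wrap( IQueryRunner target, Func<string, string> modifyCommand )
    {
        IQueryRunner proxy = DispatchProxy.Create<IQueryRunner, InterceptingProxy>();
        var self = (InterceptingProxy)(object)proxy;
        self.target = target;
        self.modifyCommand = modifyCommand;
        return proxy;
    }

    protected override object Invoke( MethodInfo targetMethod, object[] args )
    {
        // Only Execute is intercepted; everything else passes through unchanged.
        if ( targetMethod.Name == "Execute" )
        {
            args[ 0 ] = modifyCommand( (string)args[ 0 ] );
        }

        try
        {
            return targetMethod.Invoke( target, args );
        }
        catch ( TargetInvocationException e )
        {
            // Surface the real exception rather than the reflection wrapper.
            ExceptionDispatchInfo.Capture( e.InnerException ).Throw();
            throw;
        }
    }
}
```

As in ProviderProxy, calls that aren’t intercepted are forwarded as-is, and a TargetInvocationException is unwrapped so callers see the original exception.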

Although it’s not the cleanest solution, it’s pretty much transparent to LINQ To SQL.  You can use complex queries (even compiled ones via CompiledQuery.Compile) and they work just fine:

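    // Targets the .NET Framework: LINQ to SQL (System.Data.Linq.dll) and RealProxy
    // are not available on .NET Core / .NET 5+.
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Data.Linq;
    using System.Data.Linq.SqlClient;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    using System.Runtime.Remoting;
    using System.Runtime.Remoting.Messaging;
    using System.Runtime.Remoting.Proxies;

    public class DataContextInterceptor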
    {
        public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

        private DataContext dc;
        private object oldProvider;
        private Type providerType;
        private ModifyCommandDelegate modifyCommand;

        public static TDataContext Intercept<TDataContext>( TDataContext dc, ModifyCommandDelegate modifyCommand )
            where TDataContext : DataContext
        {
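            // The constructor installs a proxy that holds a reference to this instance,
            // so the new object doesn't need to be stored anywhere else.
            new DataContextInterceptor( dc, modifyCommand );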

            return dc;
        }

        public DataContextInterceptor( DataContext dc, ModifyCommandDelegate modifyCommand )
        {
            this.dc = dc;
            this.modifyCommand = modifyCommand;

            FieldInfo providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
            var existingProvider = providerField.GetValue( dc );

            if ( existingProvider is SqlProvider )
            {
                oldProvider = existingProvider;

                var proxy = new ProviderProxy( this, oldProvider ).GetTransparentProxy();

                providerField.SetValue( dc, proxy );
            }
        }

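        // Looks up an instance method (public or not) by name. If any argument is null its
        // type can't be inferred, so the first method with a matching name is used;
        // otherwise the overload is chosen from the arguments' runtime types.
        public static MethodInfo GetMethod( Type type, string methodName, params object[] args )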
        {
            var hasNullArgs = args.Any( a => a == null );

            var method = hasNullArgs
                ? type.GetMethods( BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic )
                    .FirstOrDefault( m => m.Name == methodName )
                : null;

            if ( method == null && !hasNullArgs )
            {
                method = type.GetMethod(
                    methodName,
                    BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
                    null,
                    args.Select( a => a.GetType() ).ToArray(),
                    null );
            }

            return method;
        }

        private object Invoke( object instance, string methodName, params object[] args )
        {
            var type = instance.GetType();
            var method = GetMethod( type, methodName, args );

            return ( method != null )
                ? method.Invoke( instance, args )
                : null;
        }

        private object Invoke( Type type, string methodName, params object[] args )
        {
            var method = type.GetMethod( methodName, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic );

            return method.Invoke( null, args );
        }

        protected object CompileImpl( Expression query )
        {
            var assembly = typeof( SqlProvider ).Assembly;

            /*
            this.CheckDispose();
            this.CheckInitialized();
            if ( query == null )
            {
                throw Error.ArgumentNull( "query" );
            }
            */

            // this.InitializeProviderMode();
            Invoke( oldProvider, "InitializeProviderMode" );

            // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
            var annotations = Activator.CreateInstance( assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" ) );

            // QueryInfo[] queries = this.BuildQuery( query, annotations );
            var queries = Invoke( oldProvider, "BuildQuery", query, annotations );

            var info = ModifyQueries( (IEnumerable)queries );

            // this.CheckSqlCompatibility( queries, annotations );
            Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

            LambdaExpression expression = query as LambdaExpression;

            if ( expression != null )
            {
                query = expression.Body;
            }

            // IObjectReaderFactory readerFactory = null;
            object readerFactory = null;

            // ICompiledSubQuery[] subQueries = null;
            object subQueries = null;

            // QueryInfo info = queries[ queries.Length - 1 ];
            // info defined above

            var resultShape = (int)Invoke( info, "get_ResultShape" );

            // if ( info.ResultShape == ResultShape.Singleton )
            if ( resultShape == 1 /* Singleton */ )
            {
                // subQueries = this.CompileSubQueries( info.Query );
                subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                // readerFactory = this.GetReaderFactory( info.Query, info.ResultType );
                readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), Invoke( info, "get_ResultType" ) );
            }
            // else if ( info.ResultShape == ResultShape.Sequence )
            else if ( resultShape == 2 /* Sequence */ )
            {
                // subQueries = this.CompileSubQueries( info.Query );
                subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                // readerFactory = this.GetReaderFactory( info.Query, TypeSystem.GetElementType( info.ResultType ) );
                var resultType = Invoke( info, "get_ResultType" );
                var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
                var elementType = Invoke( typeSystemType, "GetElementType", resultType );

                readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), elementType );
            }

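            // Put the original SqlProvider back on the DataContext, presumably so the compiled
            // query is later executed against the real provider rather than the proxy.
            // Note that later queries on this DataContext are then no longer intercepted.
            FieldInfo providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
            providerField.SetValue( dc, oldProvider );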

            // return new CompiledQuery( this, query, queries, readerFactory, subQueries );
            var compiledQueryType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider+CompiledQuery" );

            return Activator.CreateInstance(
                compiledQueryType,
                BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
                null,
                new object[] { oldProvider, query, queries, readerFactory, subQueries },
                null );
        }

        protected internal virtual IExecuteResult ExecuteImpl( Expression query )
        {
            var assembly = typeof( SqlProvider ).Assembly;

            /*
            this.CheckDispose();
            this.CheckInitialized();
            this.CheckNotDeleted();
            if (query == null)
            {
                throw Error.ArgumentNull("query");
            }
            */

            // this.InitializeProviderMode();
            Invoke( oldProvider, "InitializeProviderMode" );

            // query = Funcletizer.Funcletize(query);
            var funcletizerType = assembly.GetType( "System.Data.Linq.SqlClient.Funcletizer" );
            query = (Expression)Invoke( funcletizerType, "Funcletize", query );

            // if ( this.EnableCacheLookup )
            if ( (bool)Invoke( oldProvider, "get_EnableCacheLookup" ) )
            {
                // IExecuteResult cachedResult = this.GetCachedResult(query);
                object cachedResult = Invoke( oldProvider, "GetCachedResult", query );

                if ( cachedResult != null )
                {
                    // return cachedResult;
                    return (IExecuteResult)cachedResult;
                }
            }

            // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
            var annotations = Activator.CreateInstance( assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" ) );

            // QueryInfo[] queries = this.BuildQuery(query, annotations);
            var queries = Invoke( oldProvider, "BuildQuery", query, annotations );

            var info = ModifyQueries( (IEnumerable)queries );

            // this.CheckSqlCompatibility(queries, annotations);
            Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

            LambdaExpression expression = query as LambdaExpression;

            if ( expression != null )
            {
                query = expression.Body;
            }

            // IObjectReaderFactory readerFactory = null;
            object readerFactory = null;

            // ICompiledSubQuery[] subQueries = null;
            object subQueries = null;

            // QueryInfo info = queries[queries.Length - 1];
            // info defined above

            var resultShape = (int)Invoke( info, "get_ResultShape" );

            // if (info.ResultShape == ResultShape.Singleton)
            if ( resultShape == 1 /* Singleton */ )
            {
                // subQueries = this.CompileSubQueries(info.Query);
                subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                // readerFactory = this.GetReaderFactory(info.Query, info.ResultType);
                readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), Invoke( info, "get_ResultType" ) );
            }
            // else if (info.ResultShape == ResultShape.Sequence)
            else if ( resultShape == 2 /* Sequence */ )
            {
                // subQueries = this.CompileSubQueries(info.Query);
                subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                // readerFactory = this.GetReaderFactory(info.Query, TypeSystem.GetElementType(info.ResultType));
                var resultType = Invoke( info, "get_ResultType" );
                var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
                var elementType = Invoke( typeSystemType, "GetElementType", resultType );

                readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), elementType );
            }

            // return this.ExecuteAll(query, queries, readerFactory, null, subQueries);
            return (IExecuteResult)Invoke( oldProvider, "ExecuteAll", query, queries, readerFactory, null, subQueries );
        }

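        // Hands each generated SQL statement and its parameter values to modifyCommand and
        // writes the returned text back into the QueryInfo. The parameters dictionary is a
        // copy, so changes the delegate makes to it are not written back. Returns the last
        // QueryInfo, which is the one used to shape the results.
        private object ModifyQueries( IEnumerable queries )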
        {
            object lastQuery = null;

            foreach ( var q in queries )
            {
                lastQuery = q;

                var commandTextField = q.GetType().GetField( "commandText", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
                var parametersField = q.GetType().GetField( "parameters", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );

                var commandText = (string)commandTextField.GetValue( q );
                var parameters = new Dictionary<string, object>();

                foreach ( var p in (IEnumerable)parametersField.GetValue( q ) )
                {
                    var name = (string)Invoke( Invoke( p, "get_Parameter" ), "get_Name" );
                    parameters[ name ] = Invoke( p, "get_Value" );
                }

                var modifiedCommandText = modifyCommand( commandText, parameters );

                commandTextField.SetValue( q, modifiedCommandText );
            }

            return lastQuery;
        }

        public class ProviderProxy : RealProxy, IRemotingTypeInfo
        {
            DataContextInterceptor extender;
            object oldProvider;

            internal ProviderProxy( DataContextInterceptor extender, object oldProvider )
                : base( typeof( ContextBoundObject ) )
            {
                this.extender = extender;
                this.oldProvider = oldProvider;
            }

            public override IMessage Invoke( IMessage msg )
            {
                if ( msg is IMethodCallMessage )
                {
                    IMethodCallMessage call = (IMethodCallMessage)msg;
                    MethodInfo mi = null;

                    if ( call.MethodBase.DeclaringType.Name == "IProvider" && call.MethodBase.DeclaringType.IsInterface )
                    {
                        extender.providerType = call.MethodBase.DeclaringType;

                        mi = DataContextInterceptor.GetMethod( typeof( DataContextInterceptor ), call.MethodBase.Name + "Impl", call.Args );

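                        if ( mi == null && oldProvider != null )
                        {
                            // Not one of the intercepted methods: forward the call to the original provider.
                            // Invoking call.MethodBase directly keeps the correct overload even when some
                            // arguments are null (a name-only lookup could pick the wrong one).
                            try
                            {
                                return new ReturnMessage( call.MethodBase.Invoke( oldProvider, call.Args ), null, 0, null, call );
                            }
                            catch ( TargetInvocationException e )
                            {
                                return new ReturnMessage( e.InnerException, call );
                            }
                        }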

                        if ( mi != null )
                        {
                            try
                            {
                                return new ReturnMessage( mi.Invoke( this.extender, call.Args ), null, 0, null, call );
                            }
                            catch ( TargetInvocationException e )
                            {
                                return new ReturnMessage( e.InnerException, call );
                            }
                        }
                    }
                    else
                    {
                        mi = DataContextInterceptor.GetMethod( oldProvider.GetType(), call.MethodBase.Name, call.Args );

                        if ( mi != null )
                        {
                            try
                            {
                                return new ReturnMessage( mi.Invoke( oldProvider, call.Args ), null, 0, null, call );
                            }
                            catch ( TargetInvocationException e )
                            {
                                return new ReturnMessage( e.InnerException, call );
                            }
                        }
                    }

                    throw new NotImplementedException(
                        string.Format( "Method not found: {0}( {1} )",
                            call.MethodBase.Name,
                            string.Join( ", ", call.Args.Select( a => Convert.ToString( a ) ) ) ) );
                }

                throw new NotImplementedException();
            }

            public bool CanCastTo( Type fromType, object o )
            {
                return true;
            }

            public string TypeName
            {
                get { return this.GetType().Name; }
                set { }
            }
        }
    }
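
A ModifyCommandDelegate is just a function from command text (plus a copy of the parameter values) to new command text, and the interceptor calls it once per generated statement. Here are two hypothetical examples: one appends SQL Server’s OPTION (RECOMPILE) hint to SELECT statements, assuming each command is a single SELECT with no OPTION clause of its own, and one only records what LINQ To SQL would send. The delegate type is repeated so the block compiles without System.Data.Linq.

```csharp
using System;
using System.Collections.Generic;

public static class CommandRewriters
{
    // Same shape as DataContextInterceptor.ModifyCommandDelegate, repeated here
    // so this example compiles on its own.
    public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

    // Appends OPTION (RECOMPILE) to SELECT statements that don't already
    // carry an OPTION clause; any other command is returned unchanged.
    public static string AddRecompileHint( string commandText, IDictionary<string, object> parameters )
    {
        bool isSelect = commandText.TrimStart().StartsWith( "SELECT", StringComparison.OrdinalIgnoreCase );
        bool hasOption = commandText.IndexOf( "OPTION (", StringComparison.OrdinalIgnoreCase ) >= 0;

        return ( isSelect && !hasOption )
            ? commandText + " OPTION (RECOMPILE)"
            : commandText;
    }

    // Builds a delegate that records each parameter and the SQL text in the
    // given log, then returns the command text unmodified.
    public static ModifyCommandDelegate Logging( IList<string> log )
    {
        return ( commandText, parameters ) =>
        {
            foreach ( var p in parameters )
            {
                log.Add( p.Key + " = " + Convert.ToString( p.Value ) );
            }

            log.Add( commandText );
            return commandText;
        };
    }
}
```

To hook the first one up, pass the method to the interceptor, e.g. DataContextInterceptor.Intercept( new MyDataContext( connectionString ), CommandRewriters.AddRecompileHint ), where MyDataContext and connectionString are placeholders for your own context type and connection string.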

Hope this helps!
