Last year I posted this article about modifying LINQ to SQL (L2S) command text. It was slightly evil in that it called private methods inside L2S, and it did it through reflection.
I have an alternative version that does the same thing through a pre-compiled expression tree, which I’ll post soon. Until then, here’s an update to the original reflection-based interceptor that properly handles sub-queries:
/// <summary>
/// Rewrites a SQL command generated by LINQ to SQL before it is executed.
/// </summary>
/// <param name="commandText">The SQL text LINQ to SQL generated for one query.</param>
/// <param name="parameters">
/// The command's parameter values keyed by parameter name. This is a copy: changes made to it are
/// not applied to the command.
/// </param>
/// <returns>The command text to store back into the query and execute instead.</returns>
public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

/// <summary>
/// Hooks a LINQ to SQL <see cref="DataContext"/> so that every SQL command it generates, including
/// compiled sub-queries, is passed through a <see cref="ModifyCommandDelegate"/> before execution.
/// </summary>
/// <remarks>
/// The DataContext's private <c>provider</c> field is replaced with a remoting transparent proxy
/// (<see cref="ProviderProxy"/>). IProvider calls that have a matching "<c>&lt;Name&gt;Impl</c>" method on
/// this class (currently <c>Execute</c> and <c>Compile</c>) are handled here by re-creating the private
/// SqlProvider pipeline through reflection; every other call is forwarded to the original provider.
/// All of this depends on non-public LINQ to SQL member names and can break if those internals change.
/// </remarks>
public class ReflectionDataContextInterceptor
{
    // Binding flags used for the reflection lookups into LINQ to SQL internals.
    private const BindingFlags InstanceMembers = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
    private const BindingFlags StaticMembers = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;

    // Integer values of the internal ResultShape enum that this class distinguishes.
    private const int ResultShapeSingleton = 1;
    private const int ResultShapeSequence = 2;

    // DataContext's private provider field: replaced by the proxy on interception, restored by CompileImpl.
    private static readonly FieldInfo ProviderField =
        typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );

    private readonly DataContext dc;
    private readonly ModifyCommandDelegate modifyCommand;

    // The provider that was in dc before interception; stays null if dc was already intercepted.
    private readonly object oldProvider;

    // Set by ProviderProxy to the IProvider interface type it sees; not read anywhere in this class.
    private Type providerType;

    /// <summary>
    /// Installs an interceptor on <paramref name="dc"/> and returns the same DataContext, for fluent use.
    /// </summary>
    public static TDataContext Intercept<TDataContext>( TDataContext dc, ModifyCommandDelegate modifyCommand ) where TDataContext : DataContext
    {
        // The constructor installs the proxy, and the proxy holds a reference to the new interceptor.
        new ReflectionDataContextInterceptor( dc, modifyCommand );
        return dc;
    }

    /// <summary>
    /// Replaces <paramref name="dc"/>'s provider with a <see cref="ProviderProxy"/> unless it already holds one.
    /// </summary>
    /// <remarks>
    /// If <paramref name="dc"/> is already intercepted, nothing is changed: the existing proxy (and the
    /// delegate it was created with) stays in place, and <paramref name="modifyCommand"/> is not used.
    /// </remarks>
    public ReflectionDataContextInterceptor( DataContext dc, ModifyCommandDelegate modifyCommand )
    {
        this.dc = dc;
        this.modifyCommand = modifyCommand;

        var existingProvider = ProviderField.GetValue( dc );
        if ( existingProvider is IProviderProxy )
        {
            return;
        }

        oldProvider = existingProvider;
        var proxy = new ProviderProxy( this, oldProvider ).GetTransparentProxy();
        ProviderField.SetValue( dc, proxy );
    }

    /// <summary>
    /// Finds a public or non-public instance method on <paramref name="type"/> that can accept <paramref name="args"/>.
    /// </summary>
    /// <remarks>
    /// The type of a null argument cannot be inferred. If any argument is null, the first method named
    /// <paramref name="methodName"/> is returned without checking overloads.
    /// </remarks>
    /// <returns>The matching method, or null if none is found.</returns>
    public static MethodInfo GetMethod( Type type, string methodName, params object[] args )
    {
        if ( args.Any( a => a == null ) )
        {
            return type.GetMethods( InstanceMembers ).FirstOrDefault( m => m.Name == methodName );
        }

        var argTypes = args.Select( a => a.GetType() ).ToArray();
        return type.GetMethod( methodName, InstanceMembers, null, argTypes, null );
    }

    // Calls an instance method on an internal LINQ to SQL object. Property getters are called as "get_X".
    // If no matching method exists, nothing is called and null is returned.
    private object Invoke( object instance, string methodName, params object[] args )
    {
        var method = GetMethod( instance.GetType(), methodName, args );
        return ( method != null ) ? method.Invoke( instance, args ) : null;
    }

    // Calls a static method on an internal LINQ to SQL type.
    private object Invoke( Type type, string methodName, params object[] args )
    {
        var method = type.GetMethod( methodName, StaticMembers );
        return method.Invoke( null, args );
    }

    /// <summary>
    /// Replacement for IProvider.Compile: builds and rewrites the SQL, then creates the provider's
    /// internal <c>SqlProvider+CompiledQuery</c> for the original provider.
    /// </summary>
    protected object CompileImpl( Expression query )
    {
        try
        {
            var assembly = typeof( SqlProvider ).Assembly;

            // SqlProvider.Compile's CheckDispose/CheckInitialized/null-argument checks are not reproduced here.
            // this.InitializeProviderMode();
            Invoke( oldProvider, "InitializeProviderMode" );

            object queries, readerFactory, subQueries;
            BuildModifiedQuery( assembly, ref query, out queries, out readerFactory, out subQueries );

            // Put the original provider back, so dc is no longer intercepted after a Compile call.
            // NOTE(review): the reason for this is not documented; confirm before removing it.
            ProviderField.SetValue( dc, oldProvider );

            // return new CompiledQuery( this, query, queries, readerFactory, subQueries );
            var compiledQueryType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider+CompiledQuery" );
            return Activator.CreateInstance(
                compiledQueryType,
                InstanceMembers,
                null,
                new object[] { oldProvider, query, queries, readerFactory, subQueries },
                null );
        }
        catch ( TargetInvocationException ex )
        {
            // NOTE: rethrowing this way resets the inner exception's stack trace. On .NET 4.5+,
            // ExceptionDispatchInfo.Capture( ex.InnerException ).Throw() would preserve it.
            throw ex.InnerException;
        }
    }

    /// <summary>
    /// Replacement for IProvider.Execute: builds and rewrites the SQL, then runs it through the
    /// original provider's <c>ExecuteAll</c>.
    /// </summary>
    protected internal virtual IExecuteResult ExecuteImpl( Expression query )
    {
        try
        {
            var assembly = typeof( SqlProvider ).Assembly;

            // SqlProvider.Execute's CheckDispose/CheckInitialized/CheckNotDeleted/null-argument checks are not reproduced here.
            // this.InitializeProviderMode();
            Invoke( oldProvider, "InitializeProviderMode" );

            // query = Funcletizer.Funcletize( query );
            Type funcletizerType = assembly.GetType( "System.Data.Linq.SqlClient.Funcletizer" );
            query = (Expression)Invoke( funcletizerType, "Funcletize", query );

            // if ( this.EnableCacheLookup ) { ... return cachedResult; }
            // A cache hit returns before any SQL is built, so modifyCommand is not called for it.
            if ( (bool)Invoke( oldProvider, "get_EnableCacheLookup" ) )
            {
                object cachedResult = Invoke( oldProvider, "GetCachedResult", query );
                if ( cachedResult != null )
                {
                    return (IExecuteResult)cachedResult;
                }
            }

            object queries, readerFactory, subQueries;
            BuildModifiedQuery( assembly, ref query, out queries, out readerFactory, out subQueries );

            // return this.ExecuteAll( query, queries, readerFactory, null, subQueries );
            return (IExecuteResult)Invoke( oldProvider, "ExecuteAll", query, queries, readerFactory, null, subQueries );
        }
        catch ( TargetInvocationException ex )
        {
            // NOTE: rethrowing this way resets the inner exception's stack trace. On .NET 4.5+,
            // ExceptionDispatchInfo.Capture( ex.InnerException ).Throw() would preserve it.
            throw ex.InnerException;
        }
    }

    // Shared middle of Compile/Execute (kept in the same order as the SqlProvider calls quoted below):
    // - builds the QueryInfo[] for `query`;
    // - passes every generated command through modifyCommand;
    // - runs the compatibility check;
    // - prepares the reader factory and compiled sub-queries for the last QueryInfo.
    // On return, `query` has any outer lambda unwrapped to its body.
    // Deliberately not named "...Impl", so ProviderProxy never dispatches an IProvider call to it.
    private void BuildModifiedQuery( Assembly assembly, ref Expression query, out object queries, out object readerFactory, out object subQueries )
    {
        // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
        var annotations = Activator.CreateInstance( assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" ) );

        // QueryInfo[] queries = this.BuildQuery( query, annotations );
        queries = Invoke( oldProvider, "BuildQuery", query, annotations );

        // Rewrite every generated command. The returned QueryInfo is the last one, which stands in for
        // SqlProvider's `QueryInfo info = queries[ queries.Length - 1 ];`.
        object info = ModifyQueries( (IEnumerable)queries );

        // this.CheckSqlCompatibility( queries, annotations );
        Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

        LambdaExpression lambda = query as LambdaExpression;
        if ( lambda != null )
        {
            query = lambda.Body;
        }

        readerFactory = null;   // IObjectReaderFactory
        subQueries = null;      // ICompiledSubQuery[]

        int resultShape = (int)Invoke( info, "get_ResultShape" );
        if ( resultShape == ResultShapeSingleton || resultShape == ResultShapeSequence )
        {
            object infoQuery = Invoke( info, "get_Query" );

            // subQueries = this.CompileSubQueries( info.Query );
            subQueries = Invoke( oldProvider, "CompileSubQueries", infoQuery );
            ModifySubQueries( (IEnumerable)subQueries );

            // Singleton: readerFactory = this.GetReaderFactory( info.Query, info.ResultType );
            // Sequence:  readerFactory = this.GetReaderFactory( info.Query, TypeSystem.GetElementType( info.ResultType ) );
            object readerType = Invoke( info, "get_ResultType" );
            if ( resultShape == ResultShapeSequence )
            {
                Type typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
                readerType = Invoke( typeSystemType, "GetElementType", readerType );
            }
            readerFactory = Invoke( oldProvider, "GetReaderFactory", infoQuery, readerType );
        }
    }

    // Rewrites each QueryInfo in `queries` and returns the last one (null if there are none).
    private object ModifyQueries( IEnumerable queries )
    {
        object lastQuery = null;
        foreach ( var q in queries )
        {
            lastQuery = q;
            ModifyQuery( q );
        }
        return lastQuery;
    }

    // Rewrites the QueryInfo of each compiled sub-query, recursing into its private "subQueries" field.
    private void ModifySubQueries( IEnumerable subQueries )
    {
        if ( subQueries == null )
        {
            return;
        }

        foreach ( var subQuery in subQueries )
        {
            var subQueryType = subQuery.GetType();

            var queryInfo = subQueryType.GetField( "queryInfo", InstanceMembers ).GetValue( subQuery );
            ModifyQuery( queryInfo );

            // A null nested collection is handled by the guard at the top of this method.
            var nestedSubQueries = subQueryType.GetField( "subQueries", InstanceMembers ).GetValue( subQuery );
            ModifySubQueries( (IEnumerable)nestedSubQueries );
        }
    }

    // Passes one QueryInfo's private command text and parameter values to modifyCommand, then
    // stores the returned text back into the QueryInfo's "commandText" field.
    private void ModifyQuery( object q /* QueryInfo */ )
    {
        var queryType = q.GetType();
        var commandTextField = queryType.GetField( "commandText", InstanceMembers );
        var parametersField = queryType.GetField( "parameters", InstanceMembers );

        var commandText = (string)commandTextField.GetValue( q );
        var parameterInfos = parametersField.GetValue( q );

        var parameters = new Dictionary<string, object>();
        foreach ( var p in (IEnumerable)parameterInfos )
        {
            var param = Invoke( p, "get_Parameter" );
            var name = (string)Invoke( param, "get_Name" );
            parameters[ name ] = Invoke( p, "get_Value" );
        }

        commandTextField.SetValue( q, modifyCommand( commandText, parameters ) );
    }

    // IProvider.GetCommand and IProvider.GetQueryText have no "...Impl" method here. ProviderProxy
    // therefore forwards them to the original provider, and the text they report is not modified.
    // Adding GetCommandImpl/GetQueryTextImpl instance methods would route those calls to this class instead.

    /// <summary>Implemented by the provider proxy so that a DataContext is not intercepted twice.</summary>
    internal interface IProviderProxy
    {
        ReflectionDataContextInterceptor Interceptor { get; }
        object OldProvider { get; }
    }

    /// <summary>
    /// Remoting proxy installed in place of the DataContext's provider. IProvider calls are routed to a
    /// matching "<c>&lt;Name&gt;Impl</c>" method on the interceptor if one exists; otherwise they, and all
    /// other calls, are forwarded to the original provider.
    /// </summary>
    public class ProviderProxy : RealProxy, IRemotingTypeInfo, IProviderProxy
    {
        public ReflectionDataContextInterceptor Interceptor { get; private set; }
        public object OldProvider { get; private set; }

        internal ProviderProxy( ReflectionDataContextInterceptor extender, object oldProvider )
            : base( typeof( ContextBoundObject ) )
        {
            this.Interceptor = extender;
            this.OldProvider = oldProvider;
        }

        /// <summary>
        /// Dispatches a call made on the transparent proxy.
        /// </summary>
        /// <exception cref="NotImplementedException">
        /// No handler or forwarding target exists for the call, or the message is not a method call.
        /// </exception>
        public override IMessage Invoke( IMessage msg )
        {
            IMethodCallMessage call = msg as IMethodCallMessage;
            if ( call == null )
            {
                throw new NotImplementedException();
            }

            MethodBase target = call.MethodBase;
            Type declaringType = target.DeclaringType;

            if ( declaringType.Name == "IProvider" && declaringType.IsInterface )
            {
                Interceptor.providerType = declaringType;

                // Prefer the interceptor's "<Name>Impl" override (ExecuteImpl, CompileImpl).
                MethodInfo impl = ReflectionDataContextInterceptor.GetMethod( typeof( ReflectionDataContextInterceptor ), target.Name + "Impl", call.Args );
                if ( impl != null )
                {
                    return InvokeAndWrap( impl, Interceptor, call );
                }

                // Otherwise forward to the original provider via the interface method. A missing
                // method falls through to the NotImplementedException below instead of a null dereference.
                if ( OldProvider != null )
                {
                    MethodInfo forward = ReflectionDataContextInterceptor.GetMethod( declaringType, target.Name, call.Args );
                    if ( forward != null )
                    {
                        return InvokeAndWrap( forward, OldProvider, call );
                    }
                }
            }
            else if ( OldProvider != null )
            {
                // Calls that are not on IProvider go straight to the original provider's own methods.
                MethodInfo forward = ReflectionDataContextInterceptor.GetMethod( OldProvider.GetType(), target.Name, call.Args );
                if ( forward != null )
                {
                    return InvokeAndWrap( forward, OldProvider, call );
                }
            }

            throw new NotImplementedException( string.Format( "Method not found: {0}( {1} )", target.Name, string.Join( ", ", call.Args.Select( a => Convert.ToString( a ) ) ) ) );
        }

        // Invokes `method` on `target` and wraps the result in a ReturnMessage. An exception thrown by
        // the target is unwrapped from TargetInvocationException and returned to the caller.
        private static IMessage InvokeAndWrap( MethodInfo method, object target, IMethodCallMessage call )
        {
            try
            {
                return new ReturnMessage( method.Invoke( target, call.Args ), null, 0, null, call );
            }
            catch ( TargetInvocationException e )
            {
                return new ReturnMessage( e.InnerException, call );
            }
        }

        /// <summary>
        /// Always reports that a cast is possible. This lets the transparent proxy pass any cast,
        /// including to the IProvider interface handled in <see cref="Invoke"/> and to IProviderProxy.
        /// </summary>
        public bool CanCastTo( Type fromType, object o )
        {
            return true;
        }

        /// <summary>Reports this proxy class's name; assignments are ignored.</summary>
        public string TypeName
        {
            get { return this.GetType().Name; }
            set { }
        }
    }
}
It’s still evil, but it works great.
4 comments
Comments feed for this article
July 6, 2012 at 8:46 am
Modifying LINQ To SQL command text « Chris Cavanagh's Blog
[…] UPDATE 2 – Improved version that correctly handles sub-queries can be found here. […]
October 26, 2012 at 3:13 pm
Mark
Nice work! Has anyone figured out how to modify the parameters at this point? I’m trying to work around the issue of LINQ assuming everything is nvarchar, and would like to change the parameter’s SQL type in the ModifyQuery method.
September 6, 2018 at 3:27 am
Kris Bennett
I know this is old, but why do you set the provider back?? … I know it fails when you don’t … but why should it?
September 6, 2018 at 11:54 pm
Chris Cavanagh
Kris – It’s been a while since I looked at this, but honestly I’m not sure why it sets it back at that point! What happens when you remove it?