The Immutable Collections package is awesome.  If you’re writing concurrent code, immutable collections should be your new best friends (along with immutable classes in general).  Explicit locking is bad, bad, bad, with an extra helping of bad (repeat this until it sticks).

A typical pattern you’ll see when modifying a collection looks like this:

myCollection = myCollection.Add( "Banana" );

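However, if “myCollection” above is a field you’re sharing between threads, you still need to protect it.  Interlocked.Exchange on its own isn’t quite enough: another thread could swap in its own collection between our read of myCollection and the exchange, and its update would be silently lost.  The System.Threading.Interlocked helpers still make this easy if we use CompareExchange and retry until nobody else got in first:

ImmutableList<string> original, updated; // assuming myCollection is an ImmutableList<string>
do
{
    original = myCollection;
    updated = original.Add( "Banana" );
} while ( Interlocked.CompareExchange( ref myCollection, updated, original ) != original );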
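Packaged as a reusable helper, a compare-and-swap retry looks something like the sketch below. It reads the field, computes the new value, and publishes it with Interlocked.CompareExchange only if the field hasn't changed in the meantime. AtomicUpdate.Update is a hypothetical name, not from the original post. System.Collections.Immutable ships ImmutableInterlocked.Update for the same job, so prefer that where it's available.

```csharp
using System;
using System.Threading;

// Hypothetical helper, not part of the original post or of System.Collections.Immutable.
public static class AtomicUpdate
{
    // Applies 'transform' to the value in 'location' atomically, retrying if another
    // thread replaced the value between our read and our write.
    public static T Update<T>( ref T location, Func<T, T> transform ) where T : class
    {
        while ( true )
        {
            T original = Volatile.Read( ref location );
            T updated = transform( original );

            // CompareExchange only writes 'updated' if 'location' still holds 'original',
            // and always returns whatever value it found there.
            if ( Interlocked.CompareExchange( ref location, updated, original ) == original )
            {
                return updated;
            }
        }
    }
}
```

When threads collide, the transform can run more than once, so it should be free of side effects. Building a new immutable collection, as in AtomicUpdate.Update( ref myCollection, c => c.Add( "Banana" ) ), qualifies.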

But what if you’re updating an ImmutableDictionary and need an atomic update, so it’ll only add an item if it doesn’t exist?  Here’s where the ImmutableInterlocked helpers come in:

private ImmutableDictionary<string, Fruit> myDictionary
    = ImmutableDictionary<string, Fruit>.Empty;
…
var fruit = ImmutableInterlocked.GetOrAdd(
    ref myDictionary,
    "banana",
    new Banana() );

Now things could get interesting.  You’ll notice that in the call above we create a new Banana instance up front.  If “banana” already exists in the dictionary, the nice fresh Banana we created will just be discarded.  In many cases this isn’t a problem (maybe a slip hazard); it’s just a redundant object creation.

But what if it’s something we only want to create once, and only if it doesn’t exist?  ImmutableInterlocked has a GetOrAdd overload that takes a delegate:

var fruit = ImmutableInterlocked.GetOrAdd(
    ref myDictionary,
    "banana",
    _ => new Banana() );

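It sure looks promising.  Presumably it only calls the delegate if the item isn’t in the dictionary?  Nope!  Apparently it always calls the delegate, checks whether the item exists, and discards the result if it does (the source code isn’t currently available, but we can get a rough idea of how it might be implemented from this unofficial port of Immutable Collections).  And even an implementation that checked first couldn’t guarantee a single call: two threads racing to add the same key can both run the delegate, and only one result survives the atomic swap.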

So it seems we need another solution.  We really don’t want to explicitly lock anything (bad, bad, bad).  It turns out we can get this for “free” if we use Lazy<T>:

private ImmutableDictionary<string, Lazy<Fruit>> myDictionary
    = ImmutableDictionary<string, Lazy<Fruit>>.Empty;
…
var fruit = ImmutableInterlocked.GetOrAdd(
    ref myDictionary,
    "banana",
    new Lazy<Fruit>( () => new Banana(), true ) ).Value;

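This ensures there’s a Lazy<Fruit> in the dictionary that knows how to create a Banana on demand.  If two threads race, each may create its own Lazy<Fruit>, but only one ends up in the dictionary; the losing Lazy is discarded before anyone reads its Value, so its Banana is never built.  Passing true for isThreadSafe (equivalent to LazyThreadSafetyMode.ExecutionAndPublication) makes Lazy<T> ensure only one thread actually runs the factory.  It does some internal locking of its own, but only while the value is first being created, so we can happily ignore it and go on our way.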
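Generalizing the pattern, here's a sketch of a small cache built on the same idea. LazyCache and GetOrCreate are hypothetical names, not part of System.Collections.Immutable. The point is that the factory only ever runs inside the Lazy<TValue> that wins the GetOrAdd race.

```csharp
using System;
using System.Collections.Immutable;
using System.Threading;

// Hypothetical cache (illustrative names, not a library API) that runs each
// factory at most once per key without any explicit locks of its own.
public class LazyCache<TKey, TValue>
{
    private ImmutableDictionary<TKey, Lazy<TValue>> entries =
        ImmutableDictionary<TKey, Lazy<TValue>>.Empty;

    public TValue GetOrCreate( TKey key, Func<TKey, TValue> factory )
    {
        // Wrapping the factory is cheap. If another thread's Lazy is already in the
        // dictionary (or wins the race), this one is discarded unread, so its
        // factory never runs.
        var candidate = new Lazy<TValue>(
            () => factory( key ),
            LazyThreadSafetyMode.ExecutionAndPublication );

        // Every caller gets the same Lazy for this key, and ExecutionAndPublication
        // guarantees its factory runs on only one thread.
        return ImmutableInterlocked.GetOrAdd( ref entries, key, candidate ).Value;
    }
}
```

One caveat with ExecutionAndPublication: if the factory throws, Lazy<T> caches the exception, so later callers for that key see the same exception rather than a retry.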

Hope this helps!

Check this out – http://referencesource-beta.microsoft.com/

It’s using the cool new Roslyn “Compiler as a Service” as [a kind of] online Reflector / dotPeek / ILSpy :)

Here’s a simple view engine for ASP.NET MVC that lets you use plain HTML for your views, even if it’s badly formed!  It supports a very simple attribute syntax for embedding other partial views in the page; those views can use whichever view engine you’d like (WebForms, Razor, NHaml etc.).

Let’s start with the composite view engine; its job is to find a container view based on the master page name, and also to find a primary partial view based on the current MVC controller and action:

public interface ICompositeView
{
    IView PrimaryView { set; }
}

public abstract class CompositeViewEngine : VirtualPathProviderViewEngine
{
    private ViewEngineCollection otherViewEngines;

    public CompositeViewEngine()
    {
    }

    public override ViewEngineResult FindView( ControllerContext controllerContext, string viewName, string masterName, bool useCache )
    {
        if ( !controllerContext.IsChildAction )
        {
            var result = base.FindView( controllerContext, GetMasterName( controllerContext.RouteData.Values, masterName ), null, useCache );
            var compositeView = result.View as ICompositeView;

            if ( compositeView != null )
            {
                compositeView.PrimaryView = OtherViewEngines.FindPartialView( controllerContext, viewName ).View;

                return result;
            }
        }
        else
        {
            return OtherViewEngines.FindView( controllerContext, viewName, null );
        }

        return new ViewEngineResult( Enumerable.Empty<string>() );
    }

    public override ViewEngineResult FindPartialView( ControllerContext controllerContext, string partialViewName, bool useCache )
    {
        return new ViewEngineResult( Enumerable.Empty<string>() );
    }

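    private ViewEngineCollection OtherViewEngines
    {
        get
        {
            // Build the list of non-composite engines on first use without locking on 'this'.
            // Under a race the list may be built more than once, but only one instance is kept.
            return System.Threading.LazyInitializer.EnsureInitialized(
                ref otherViewEngines,
                () => new ViewEngineCollection( ViewEngines.Engines.Where( e => !( e is CompositeViewEngine ) ).ToList() ) );
        }
    }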

    protected virtual string GetMasterName( RouteValueDictionary routeValues, string defaultName )
    {
        return defaultName;
    }
}

After finding its own container view, it gives all other view engines the opportunity to find the contained partial.  Here’s the HTML view engine that derives from it:

public class HtmlViewEngine : CompositeViewEngine
{
    public IHtmlViewHelper Helper { get; set; }

    public HtmlViewEngine()
    {
        this.AreaViewLocationFormats = new string[]
        {
            "~/Areas/{2}/Views/{1}/{0}.html",
            "~/Areas/{2}/Views/{1}/{0}.htm",
            "~/Areas/{2}/Views/Shared/{0}.html",
            "~/Areas/{2}/Views/Shared/{0}.htm"
        };

        this.ViewLocationFormats = new string[]
        {
            "~/Views/{1}/{0}.html",
            "~/Views/{1}/{0}.htm",
            "~/Views/Shared/{0}.html",
            "~/Views/Shared/{0}.htm"
        };

        this.FileExtensions = new string[]
        {
            "html",
            "htm"
        };
    }

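    protected override string GetMasterName( RouteValueDictionary routeValues, string defaultName )
    {
        // "path" may be absent or UrlParameter.Optional, so avoid a hard cast to string
        var path = routeValues.ContainsKey( "path" ) ? routeValues[ "path" ] as string : null;

        return !string.IsNullOrEmpty( path )
            ? path.Split( '.' ).First()
            : !string.IsNullOrEmpty( defaultName ) ? defaultName : "Index";
    }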

    protected override IView CreateView( ControllerContext controllerContext, string viewPath, string masterPath )
    {
        return new HtmlView( viewPath, Helper );
    }

    protected override IView CreatePartialView( ControllerContext controllerContext, string partialPath )
    {
        return new HtmlView( partialPath, Helper );
    }
}

Once the container view is found, it creates an HtmlView that knows how to render it.  Here’s how that looks (starting with a CompositeView):

public abstract class CompositeView : IView, ICompositeView
{
    protected string filename;

    public IView PrimaryView { get; set; }

    public CompositeView( string filename )
    {
        this.filename = filename;
    }

    public abstract void Render( ViewContext viewContext, TextWriter writer );
}

public interface IHtmlViewHelper
{
    void RenderContent( HtmlDocument document, ViewRenderer renderer );
}

public class HtmlView : CompositeView
{
    protected IHtmlViewHelper helper;
    protected HtmlDocument source;

    public HtmlView( string filename, IHtmlViewHelper helper )
        : base( filename )
    {
        this.helper = helper;
    }

    public override void Render( ViewContext viewContext, TextWriter writer )
    {
        var document = GetSource( viewContext );

        if ( helper != null )
        {
            var viewDataContainer = new ViewDataContainer( viewContext.ViewData.Model );
            var htmlHelper = new HtmlHelper( viewContext, viewDataContainer );

            helper.RenderContent( document, new ViewRenderer( viewContext, htmlHelper, PrimaryView ) );
        }

        document.Save( writer );
    }

    private HtmlDocument GetSource( ControllerContext controllerContext )
    {
        return source ?? ( source = GetSource( controllerContext.HttpContext, filename ) );
    }

    private HtmlDocument GetSource( HttpContextBase httpContext, string filename )
    {
        return httpContext.RequestCache().Cache( filename, () => LoadSource( httpContext, filename ) );
    }

    private HtmlDocument LoadSource( HttpContextBase httpContext, string filename )
    {
        var doc = new HtmlDocument();

        doc.Load( httpContext.Server.MapPath( filename ) );

        return doc;
    }
}
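HtmlView leans on two helpers that aren't shown here. The first is ViewDataContainer: any IViewDataContainer whose ViewData property holds a ViewDataDictionary built from the model would do. The second is the RequestCache() extension on HttpContextBase. Below is a guess at the latter, inferred only from how HtmlView calls it. It is written against a plain IDictionary so it can be tried outside ASP.NET. RequestCache and its Cache method are assumptions, not a known library API.

```csharp
using System;
using System.Collections;

// Hypothetical per-request cache, inferred only from how HtmlView calls
// httpContext.RequestCache().Cache( filename, () => LoadSource( ... ) ).
// In ASP.NET the extension method could be as simple as:
//     public static RequestCache RequestCache( this HttpContextBase httpContext )
//     {
//         return new RequestCache( httpContext.Items );
//     }
public class RequestCache
{
    private readonly IDictionary items;

    public RequestCache( IDictionary items )
    {
        this.items = items;
    }

    // Returns the cached value for 'key', calling 'factory' only the first time
    public T Cache<T>( string key, Func<T> factory )
    {
        // Prefix keys so entries don't collide with anything else stored in Items
        var itemKey = "RequestCache:" + key;

        if ( items.Contains( itemKey ) )
        {
            return (T)items[ itemKey ];
        }

        var value = factory();
        items[ itemKey ] = value;
        return value;
    }
}
```

Backing it with HttpContext.Items means nothing cached this way outlives a single request. That presumably matters here, because the view helper modifies the loaded HtmlDocument in place.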

HtmlView uses HtmlAgilityPack to parse the HTML (with a little caching), injects new content into the DOM, then renders the result with some help from the ViewRenderer class (which also catches and pretty-prints any rendering errors):

public class ViewRenderer
{
    private ViewContext viewContext;
    private HtmlHelper htmlHelper;
    private IView primaryView;
    private ViewEngineCollection otherViewEngines;

    public ViewRenderer( ViewContext viewContext, HtmlHelper htmlHelper, IView primaryView )
    {
        this.viewContext = viewContext;
        this.htmlHelper = htmlHelper;
        this.primaryView = primaryView;
        this.otherViewEngines = new ViewEngineCollection( ViewEngines.Engines.Where( e => !( e is CompositeViewEngine ) ).ToList() );
    }

    public MvcHtmlString RenderContent( bool usePrimaryView, string actionName = null, string controllerName = null, string viewName = null )
    {
        var rendered = ( viewName != null )
            ? RenderView( viewName )
            : null;

        if ( rendered == null && usePrimaryView && ( controllerName == null || controllerName == (string)viewContext.RouteData.Values[ "controller" ] ) )
        {
            rendered = RenderView( primaryView );
        }

        if ( rendered == null ) rendered = RenderAction( actionName ?? "Index", controllerName );

        return rendered ?? MvcHtmlString.Empty;
    }

    public MvcHtmlString RenderView( string viewName )
    {
        return RenderView( FindView( viewName ) );
    }

    public MvcHtmlString RenderAction( string actionName, string controllerName = null )
    {
        MvcHtmlString result = null;

        try
        {
            result = htmlHelper.Action( actionName, controllerName );
        }
        catch ( HttpException ex )
        {
            result = MvcHtmlString.Create( ex.GetHtmlErrorMessage() ?? new HttpUnhandledException( ex.Message, ex.InnerException ).GetHtmlErrorMessage() );
        }
        catch ( Exception ex )
        {
            // Pass the original exception as the inner exception so the error page can show its details
            result = MvcHtmlString.Create( new HttpUnhandledException( ex.Message, ex ).GetHtmlErrorMessage() );
        }

        return result;
    }

    private IView FindView( string viewName )
    {
        var result = otherViewEngines.FindPartialView( viewContext, viewName );

        return result.View;
    }

    private MvcHtmlString RenderView( IView view )
    {
        if ( view == null ) return null;

        using ( var writer = new StringWriter() )
        {
            var renderViewContext = new ViewContext( viewContext, view, viewContext.ViewData, viewContext.TempData, writer );

            try
            {
                view.Render( renderViewContext, writer );
            }
            catch ( HttpException ex )
            {
                writer.Write( ex.GetHtmlErrorMessage() ?? new HttpUnhandledException( ex.Message, ex.InnerException ).GetHtmlErrorMessage() );
            }
            catch ( Exception ex )
            {
                // Pass the original exception as the inner exception so the error page can show its details
                writer.Write( new HttpUnhandledException( ex.Message, ex ).GetHtmlErrorMessage() );
            }

            return MvcHtmlString.Create( writer.ToString() );
        }
    }
}

Finally we need to tell MVC about the view engine. Similar to the RouteConfig class you’ll see in a new MVC 4 project, here’s ViewEngineConfig:

public class ViewEngineConfig
{
    public static void RegisterEngines( ViewEngineCollection viewEngines )
    {
        viewEngines.Insert( 0, new HtmlViewEngine()
        {
            Helper = new HtmlViewHelper()
        } );
    }

    private class HtmlViewHelper : IHtmlViewHelper
    {
        public void RenderContent( HtmlDocument document, ViewRenderer renderer )
        {
            foreach ( var node in SelectNodes( document.DocumentNode, "//*[@html-primary or @html-controller or @html-action]" ) )
            {
                var isPrimary = node.GetAttributeValue( "html-primary", false );
                var controllerName = node.GetAttributeValue( "html-controller", null );
                var actionName = node.GetAttributeValue( "html-action", null );

                node.InnerHtml = renderer.RenderContent( isPrimary, actionName, controllerName ).ToHtmlString();
            }

            foreach ( var node in SelectNodes( document.DocumentNode, "//*[@html-partial]" ) )
            {
                node.InnerHtml = ( renderer.RenderView( node.Attributes[ "html-partial" ].Value ) ?? MvcHtmlString.Empty ).ToHtmlString();
            }
        }

        public string GetControllerName( HtmlDocument document )
        {
            var controllerNode = document.DocumentNode.SelectSingleNode( "//*[@html-controller]" );

            return ( controllerNode != null ) ? controllerNode.GetAttributeValue( "html-controller", null ) : null;
        }

        private static IEnumerable<HtmlNode> SelectNodes( HtmlNode node, string xpath )
        {
            return node.SelectNodes( xpath ) ?? Enumerable.Empty<HtmlNode>();
        }
    }
}

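This is doing most of the content substitution.  It looks for a few pre-defined attributes in the HTML (html-primary, html-partial, html-controller and html-action) and replaces the content of each matching element.  Give html-primary an explicit value such as html-primary="true"; GetAttributeValue falls back to false when the attribute’s value isn’t a valid Boolean.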

Call RegisterEngines in Application_Start (in Global.asax.cs) and you’re done:

ViewEngineConfig.RegisterEngines( ViewEngines.Engines );

This works great when you’re just using regular MVC controller / action routes, but what if you want to handle direct requests for the HTML views (for example, if you’re hosting an entire static site within your MVC project… not as odd as it might sound)?  We can do this by adding some route definitions:

// HTML requests are handed to the regular MVC pipeline
var handler = new MvcRouteHandler();

// Controller prefixed resources
routes.Add( "ControllerStaticResource", new Route( @"{controller}/{*path}", new StaticFileRouteHandler() )
{
    Constraints = new RouteValueDictionary( new { path = @".*\.(css|js|png|jpg|gif)" } ),
    Defaults = new RouteValueDictionary( new { rootFolder = "~/Views", folder = "Shared" } ),
} );

// Resources
routes.Add( "StaticResource", new Route( @"{*path}", new StaticFileRouteHandler() )
{
    Constraints = new RouteValueDictionary( new { path = @".*\.(css|js|png|jpg|gif)" } ),
    Defaults = new RouteValueDictionary( new { rootFolder = "~/Views", folder = "Shared" } )
} );

// Static HTML path with controller and action prefix
routes.Add( "ControllerActionStaticHtml", new PlaceholderRoute( @"{controller}/{action}/{*path}", handler )
{
    Constraints = new RouteValueDictionary( new { path = @".*\.(html|htm)" } ),
    Excludes = new[] { "path" }
} );

// Static HTML path with controller prefix
routes.Add( "ControllerStaticHtml", new PlaceholderRoute( @"{controller}/{*path}", handler )
{
    Constraints = new RouteValueDictionary( new { path = @".*\.(html|htm)" } ),
    Defaults = new RouteValueDictionary( new { controller = "Home", action = "Index", path = UrlParameter.Optional } ),
    Excludes = new[] { "path" }
} );

// Static HTML path with no prefix
routes.Add( "StaticHtml", new PlaceholderRoute( @"{*path}", handler )
{
    Constraints = new RouteValueDictionary( new { path = @".*\.(html|htm)" } ),
    Defaults = new RouteValueDictionary( new { controller = "Home", action = "Index", path = UrlParameter.Optional } ),
    Excludes = new[] { "path" }
} );

// Static HTML path with controller and action prefix, used mainly when generating URLs
routes.Add( "ControllerActionStaticHtmlGenerate", new PlaceholderRoute( @"{controller}/{action}/{path}", handler )
{
    Defaults = new RouteValueDictionary( new { controller = "Home", action = UrlParameter.Optional, path = UrlParameter.Optional } ),
    Placeholders = new RouteValueDictionary( new { action = "Index" } ),
    Excludes = new[] { "path" }
} );

This could be more complex than you need, so don’t freak out just yet :)  The first couple intercept CSS, JS, PNG etc. files and point them at a new StaticFileRouteHandler class.  Next we’re looking for .html and .htm files, but letting the regular MvcRouteHandler take care of those.

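Browsers expect non-absolute resource paths to be relative to the current request URL.  In this example we have CSS files, images etc. in subfolders of the Views/Shared folder, along with the composite HTML files.  However, the URL the browser sees might be just an MVC path: a page served from /Home/About that references css/site.css makes the browser request /Home/css/site.css, which doesn’t exist on disk.  The StaticFileRouteHandler lets us intercept those resource requests and grab the files from the appropriate place (here ~/Views/Home/css/site.css, falling back to ~/Views/Shared/css/site.css).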

This isn’t a requirement (resources could be in the typical ~/Content folder if you prefer), but it can be pretty convenient.  Keeping the embedded site files together means they can be modified with any HTML editor.  If a third party is responsible for those, you can just drop in the entire site when they make changes.

Here’s StaticFileRouteHandler:

public class StaticFileRouteHandler : IRouteHandler
{
    public IHttpHandler GetHttpHandler( RequestContext requestContext )
    {
        return new StaticFileHttpHandler( requestContext );
    }

    public class StaticFileHttpHandler : IHttpAsyncHandler, IHttpHandler //, IRequiresSessionState
    {
        private delegate void AsyncProcessorDelegate( HttpContext httpContext );

        protected RequestContext requestContext;
        private AsyncProcessorDelegate asyncDelegate;

        public StaticFileHttpHandler( RequestContext requestContext )
        {
            this.requestContext = requestContext;
        }

        public void ProcessRequest( HttpContext context )
        {
            var routeValues = requestContext.RouteData.Values;

            var controllerName = (string)routeValues[ "controller" ];
            var folderName = (string)routeValues[ "folder" ];
            var path = GetFilePath( context );

            var filePath = ( controllerName != null )
                ? FindFilePath( controllerName, path ) ?? FindFilePath( folderName, path )
                : FindFilePath( folderName, path );

            if ( filePath != null )
            {
                var response = context.Response;

                response.ContentType = GetContentType( filePath );
                response.AddFileDependency( filePath );
                response.Cache.SetETagFromFileDependencies();
                response.Cache.SetLastModifiedFromFileDependencies();
                response.Cache.SetCacheability( HttpCacheability.Public );

                context.Response.TransmitFile( filePath );
            }
            else
            {
                System.Diagnostics.Trace.WriteLine( string.Format( "ERROR: StaticRouteHandler couldn't find {0}", context.Request.Url ) );
                context.Response.StatusCode = 404;
            }
        }

        private string GetFilePath( HttpContext context )
        {
            var routeValues = requestContext.RouteData.Values;

            if ( context.Request.UrlReferrer == null ) return (string)routeValues[ "path" ];

            var urlBase = "http://" + context.Request.Url.GetComponents( UriComponents.Host | UriComponents.Path, UriFormat.Unescaped );
            var referrerBase = "http://" + context.Request.UrlReferrer.GetComponents( UriComponents.Host | UriComponents.Path, UriFormat.Unescaped );

            var url = new Uri( urlBase, UriKind.Absolute );
            var referrer = new Uri( referrerBase, UriKind.Absolute );

            return referrer.MakeRelativeUri( url ).OriginalString;
        }

        private string FindFilePath( string folderName, string path )
        {
            var httpContext = requestContext.HttpContext;
            var routeValues = requestContext.RouteData.Values;

            var filePath = string.Format( "{0}/{1}/{2}",
                routeValues[ "rootFolder" ],
                folderName,
                path );

            var absolutePath = httpContext.Server.MapPath( filePath );

            System.Diagnostics.Trace.WriteLine( string.Format( "Looking for file in {0}", absolutePath ) );

            return File.Exists( absolutePath ) ? absolutePath : null;
        }

        private string GetContentType( string filePath )
        {
            // Normalize case so ".PNG" and ".png" map to the same content type
            var extension = System.IO.Path.GetExtension( filePath ).ToLowerInvariant();

            switch ( extension )
            {
                case ".htm":
                case ".html": return "text/html";
                case ".css": return "text/css";
                case ".js": return "application/javascript";
                case ".png": return "image/png";
                case ".jpg": return "image/jpeg";
                case ".gif": return "image/gif";
            }

            return "text/plain";
        }

        public IAsyncResult BeginProcessRequest( HttpContext context, AsyncCallback cb, object extraData )
        {
            asyncDelegate = ProcessRequest;

            return asyncDelegate.BeginInvoke( context, cb, extraData );
        }

        public void EndProcessRequest( IAsyncResult result )
        {
            asyncDelegate.EndInvoke( result );
        }

        public bool IsReusable
        {
            get { return false; } // the handler holds per-request state (requestContext), so don't reuse it
        }
    }
}
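The referrer trick in GetFilePath can be tried without a web server. This sketch pulls the idea into a hypothetical ResourcePathResolver.GetRelativePath helper. Unlike the original, it trims query strings with Uri.GetLeftPart( UriPartial.Path ) rather than rebuilding the URL from its host and path. That keeps the scheme and port, and should give the same relative path for requests within one site.

```csharp
using System;

// Hypothetical helper mirroring StaticFileHttpHandler.GetFilePath, minus HttpContext.
public static class ResourcePathResolver
{
    // Returns the resource path relative to the page that referenced it,
    // e.g. page /Home/About + request /Home/css/site.css gives "css/site.css".
    public static string GetRelativePath( Uri requestUrl, Uri referrerUrl )
    {
        // Drop query strings and fragments, keeping scheme, host, port and path
        var url = new Uri( requestUrl.GetLeftPart( UriPartial.Path ), UriKind.Absolute );
        var referrer = new Uri( referrerUrl.GetLeftPart( UriPartial.Path ), UriKind.Absolute );

        // MakeRelativeUri works from the referrer's "directory" (/Home/ in the example above)
        return referrer.MakeRelativeUri( url ).OriginalString;
    }
}
```

The result is what FindFilePath appends to ~/Views/{controller}/ or ~/Views/Shared/.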

Full source and sample project coming soon! :)

Anyone in or around Rochester, MN on November 14th, 2013 with even the slightest interest in Xamarin development should attend this…

http://rochmndotnetug201311-es2.eventbrite.com/?rank=1&sid=5e940311459b11e3a5871231391ec9bf


Looks like there’s still time to register for TCCC15! – http://tccc15.eventbrite.com

It’s sure to be another awesome day :)

ASP.NET MVC’s built-in Route class can handle just about anything you want.  However, if you need a bit more control it’s easy to derive your own.  Here’s a simple class called DirectionalRoute that adds a few features:

  • CanGetRouteData: true if this route can be used to parse an incoming URL (default true)
  • CanGetVirtualPath: true if this route can be used to generate a URL (default true)
  • Placeholders: Dictionary of placeholder values
  • Excludes: Array of keys to exclude from generated URLs

The Placeholders dictionary needs a little explanation.  Setting Default values on a regular route works well, but if your RouteData contains a value that matches a default (or if the default is set to UrlParameter.Optional), it could be excluded completely from generated URLs.  In situations where you really need a value to be present, set a Placeholder: when this route parses an incoming URL, any route value that comes back null or UrlParameter.Optional and has a matching placeholder is replaced with the placeholder value.

Here’s the code:

public class DirectionalRoute : Route
{
    public bool CanGetRouteData { get; set; }
    public bool CanGetVirtualPath { get; set; }
    public RouteValueDictionary Placeholders { get; set; }
    public string[] Excludes { get; set; }

    public DirectionalRoute(string url, IRouteHandler routeHandler)
        : this( url, null, null, null, routeHandler )
    {
    }

    public DirectionalRoute(string url, RouteValueDictionary defaults, IRouteHandler routeHandler)
        : this( url, defaults, null, null, routeHandler )
    {
    }

    public DirectionalRoute(string url, RouteValueDictionary defaults, RouteValueDictionary constraints, IRouteHandler routeHandler)
        : this( url, defaults, constraints, null, routeHandler )
    {
    }

    public DirectionalRoute( string url, RouteValueDictionary defaults, RouteValueDictionary constraints, RouteValueDictionary dataTokens, IRouteHandler routeHandler )
        : base( url, defaults, constraints, dataTokens, routeHandler )
    {
        this.CanGetRouteData = true;
        this.CanGetVirtualPath = true;
    }

    public override RouteData GetRouteData( System.Web.HttpContextBase httpContext )
    {
        if ( !CanGetRouteData ) return null;

        var routeData = base.GetRouteData( httpContext );

        if ( routeData != null && Placeholders != null )
        {
            var missing = routeData.Values
                .Where( rv => ( rv.Value == null || rv.Value == UrlParameter.Optional ) && Placeholders.ContainsKey( rv.Key ) )
                .ToArray();

            foreach ( var m in missing ) routeData.Values[ m.Key ] = Placeholders[ m.Key ];
        }

        return routeData;
    }

    public override VirtualPathData GetVirtualPath( RequestContext requestContext, RouteValueDictionary values )
    {
        return CanGetVirtualPath
            ? base.GetVirtualPath( GetRequestContext( requestContext ), GetRouteValues( values ) )
            : null;
    }

    private RequestContext GetRequestContext( RequestContext requestContext )
    {
        if ( Excludes == null || Excludes.Length == 0 ) return requestContext;

        var newRouteData = new RouteData( requestContext.RouteData.Route, requestContext.RouteData.RouteHandler );

        foreach ( var v in requestContext.RouteData.Values.Where( v => !Excludes.Contains( v.Key ) ) ) newRouteData.Values[ v.Key ] = v.Value;
        foreach ( var v in requestContext.RouteData.DataTokens ) newRouteData.DataTokens[ v.Key ] = v.Value;

        return new RequestContext( requestContext.HttpContext, newRouteData );
    }

    private RouteValueDictionary GetRouteValues( RouteValueDictionary values )
    {
        if ( values == null || Excludes == null || Excludes.Length == 0 ) return values;

        return new RouteValueDictionary( values.Where( v => !Excludes.Contains( v.Key ) ).ToDictionary( v => v.Key, v => v.Value ) );
    }
}
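The two dictionary manipulations at the heart of DirectionalRoute can be exercised without System.Web. This sketch restates them over plain dictionaries. RouteValueRules is a hypothetical name, and its Optional field is an assumed stand-in for UrlParameter.Optional.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical, System.Web-free restatement of DirectionalRoute's dictionary logic.
public static class RouteValueRules
{
    // Stand-in for UrlParameter.Optional (an assumption for this sketch)
    public static readonly object Optional = new object();

    // Mirrors GetRouteData: replace null/optional values that have a placeholder
    public static void ApplyPlaceholders( IDictionary<string, object> values, IDictionary<string, object> placeholders )
    {
        var missing = values
            .Where( v => ( v.Value == null || v.Value == Optional ) && placeholders.ContainsKey( v.Key ) )
            .Select( v => v.Key )
            .ToArray(); // materialize before mutating the dictionary

        foreach ( var key in missing ) values[ key ] = placeholders[ key ];
    }

    // Mirrors GetRouteValues: copy everything except the excluded keys
    public static Dictionary<string, object> WithoutExcluded( IDictionary<string, object> values, string[] excludes )
    {
        return values
            .Where( v => !excludes.Contains( v.Key ) )
            .ToDictionary( v => v.Key, v => v.Value );
    }
}
```

With Placeholders = { action = "Index" }, an incoming URL that leaves action optional still gives MVC a concrete action name. Excluding "path" keeps the current request's path out of URLs generated from it.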

TCCC14 (off-by-one error?) takes place on April 27th.  Go register!  Even if you just want the free breakfast (donuts, coffee… what more does a geek need?), the campus atmosphere and a chance to win prizes, do it! :)  It doesn’t even matter if some of the presentations aren’t in your usual fields; if you can’t decide, go to a random or fun-looking one!

http://www.twincitiescodecamp.com/TCCC/Default.aspx

I’ve been to the last two, and they’re great!

Update: Source code is now available on GitHub, and also as an npm package!

What happens when you put node.js and socket.io on a Raspberry Pi, then use them as goop between Box2D (physics engine) demos?

All kinds of awesome :)  Try it here (not sure how long I’ll keep this running, so get it while it’s hot).

Open it in two browsers, and see how long you can keep them in sync.


You can find the source on GitHub, and also as an npm package.


Following my post last year about modifying LINQ to SQL command text (evil, as it calls private methods through reflection), here’s an equally evil but faster version that pre-compiles most of its work using expression trees:

public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

public class DataContextInterceptor
{
    private DataContext dc;
    private object oldProvider;
    private Type providerType;
    private ModifyCommandDelegate modifyCommand;

    private static Func<object, ModifyCommandDelegate, DataContext, Func<Expression, object>> CompileFactory = CreateCompileFactory().Compile();
    private static Func<object, ModifyCommandDelegate, Func<Expression, IExecuteResult>> ExecuteFactory = CreateExecuteFactory().Compile();

    public static TDataContext Intercept<TDataContext>( TDataContext dc, ModifyCommandDelegate modifyCommand )
        where TDataContext : DataContext
    {
        new DataContextInterceptor( dc, modifyCommand );

        return dc;
    }

    public DataContextInterceptor( DataContext dc, ModifyCommandDelegate modifyCommand )
    {
        this.dc = dc;
        this.modifyCommand = modifyCommand;

        FieldInfo providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
        var existingProvider = providerField.GetValue( dc );

        if ( existingProvider is IProviderProxy )
        {
            // System.Diagnostics.Trace.WriteLine( string.Format( "DataContext {0} already intercepted", dc.GetHashCode() ) );
        }
        else
        {
            oldProvider = existingProvider;

            var proxy = new ProviderProxy( this, oldProvider ).GetTransparentProxy();

            providerField.SetValue( dc, proxy );

            // System.Diagnostics.Trace.WriteLine( string.Format( "DataContext {0} intercepted", dc.GetHashCode() ) );
        }
    }

    public static MethodCallExpression MakeMethodCall( Type type, string methodName, params Expression[] arguments )
    {
        var methodInfo = type.GetMethod(
            methodName,
            BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
            null,
            arguments.Select( a => a.Type ).ToArray(),
            null );

        if ( methodInfo == null ) throw new ArgumentException( string.Format( "Unable to find method {0}.{1}", type.Name, methodName ), "methodName" );

        return Expression.Call( methodInfo, arguments );
    }

    public static MethodCallExpression MakeMethodCall( Expression instance, string methodName, params Expression[] arguments )
    {
        var methodInfo = instance.Type.GetMethod(
            methodName,
            BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
            null,
            arguments.Select( a => a.Type ).ToArray(),
            null );

        if ( methodInfo == null ) throw new ArgumentException( string.Format( "Unable to find method {0}.{1}", instance.Type.Name, methodName ), "methodName" );

        return Expression.Call( instance, methodInfo, arguments );
    }

    protected internal object Compile( Expression query )
    {
        return DataContextInterceptor.CompileFactory( oldProvider, modifyCommand, dc )( query );
    }

    protected internal virtual IExecuteResult Execute( Expression query )
    {
        return DataContextInterceptor.ExecuteFactory( oldProvider, modifyCommand )( query );
    }

    public static Expression<Func<object, ModifyCommandDelegate, DataContext, Func<Expression, object>>> CreateCompileFactory()
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var providerType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" );
        var providerParam = Expression.Parameter( typeof( object ), "provider" );
        var modifyCommandParam = Expression.Parameter( typeof( ModifyCommandDelegate ), "modifyCommand" );
        var dataContextParam = Expression.Parameter( typeof( DataContext ), "dataContext" );

        return Expression.Lambda<Func<object, ModifyCommandDelegate, DataContext, Func<Expression, object>>>(
            CreateCompileMethod( Expression.Convert( providerParam, providerType ), modifyCommandParam, dataContextParam ),
            providerParam,
            modifyCommandParam,
            dataContextParam );
    }

    public static Expression<Func<Expression, object>> CreateCompileMethod( Expression oldProvider, Expression modifyCommand, Expression dataContextParam )
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var providerType = assembly.GetType( "System.Data.Linq.Provider.IProvider" );
        var funcletizerType = assembly.GetType( "System.Data.Linq.SqlClient.Funcletizer" );
        var annotationsType = assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" );
        var queryInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "QueryInfo", BindingFlags.NonPublic );
        var readerFactoryType = assembly.GetType( "System.Data.Linq.SqlClient.IObjectReaderFactory" );
        var resultShapeType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "ResultShape", BindingFlags.NonPublic );
        var compiledSubQueryType = assembly.GetType( "System.Data.Linq.SqlClient.ICompiledSubQuery" );
        var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
        var compiledQueryType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider+CompiledQuery" );

        var queryParam = Expression.Variable( typeof( Expression ), "query" );

        var providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );

        var annotationsVar = Expression.Variable( annotationsType, "annotations" );
        var queriesVar = Expression.Variable( queryInfoType.MakeArrayType(), "queries" );
        var infoVar = Expression.Variable( queryInfoType, "info" );
        var lambdaVar = Expression.Variable( typeof( LambdaExpression ), "lambda" );
        var readerFactoryVar = Expression.Variable( readerFactoryType, "readerFactory" );
        var subQueriesVar = Expression.Variable( compiledSubQueryType.MakeArrayType(), "subQueries" );
        var resultShapeVar = Expression.Variable( resultShapeType, "resultShape" );
        var returnTarget = Expression.Label( compiledQueryType );

        var getQuery = MakeMethodCall( infoVar, "get_Query" );
        var getResultType = MakeMethodCall( infoVar, "get_ResultType" );
        var getResultElementType = MakeMethodCall( typeSystemType, "GetElementType", getResultType );
        var modifySubQueries = CreateModifySubQueriesMethod();
        var compiledQueryConstructor = compiledQueryType.GetConstructors( BindingFlags.NonPublic | BindingFlags.Instance ).First();

        return Expression.Lambda<Func<Expression, object>>(
            Expression.Block(
                new[] { annotationsVar, queriesVar, infoVar, lambdaVar, readerFactoryVar, subQueriesVar, resultShapeVar },

                // this.InitializeProviderMode();
                MakeMethodCall( oldProvider, "InitializeProviderMode" ),

                // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
                Expression.Assign( annotationsVar, Expression.New( annotationsType ) ),

                // QueryInfo[] queries = this.BuildQuery( query, annotations );
                Expression.Assign( queriesVar, MakeMethodCall( oldProvider, "BuildQuery", queryParam, annotationsVar ) ),

                // var info = ModifyQueries( queries, modifyCommand );   // returns the last QueryInfo
                Expression.Assign( infoVar, Expression.Invoke( CreateModifyQueriesMethod(), queriesVar, modifyCommand ) ),

                // this.CheckSqlCompatibility(queries, annotations);
                MakeMethodCall( oldProvider, "CheckSqlCompatibility", queriesVar, annotationsVar ),

                // var lambda = query as LambdaExpression;
                Expression.Assign( lambdaVar, Expression.TypeAs( queryParam, lambdaVar.Type ) ),

                // if ( lambda != null )
                Expression.IfThen(
                    Expression.NotEqual( lambdaVar, Expression.Constant( null ) ),
                    // query = lambda.Body;
                    Expression.Assign( queryParam, Expression.Property( lambdaVar, "Body" ) ) ),

                // IObjectReaderFactory readerFactory = null;
                // ICompiledSubQuery[] subQueries = null;

                // QueryInfo info = queries[queries.Length - 1];
                // (info was already assigned above: ModifyQueries returns the last QueryInfo)
                // var resultShape = info.ResultShape;
                Expression.Assign( resultShapeVar, MakeMethodCall( infoVar, "get_ResultShape" ) ),

                Expression.Switch(
                    Expression.Convert( resultShapeVar, typeof( int ) ),
                    // if ( resultShape == 1 /* ResultShape.Singleton */ )
                    Expression.SwitchCase(
                        Expression.Block(
                            // subQueries = this.CompileSubQueries( info.Query );
                            Expression.Assign( subQueriesVar, MakeMethodCall( oldProvider, "CompileSubQueries", getQuery ) ),
                            // ModifySubQueries( (IEnumerable)subQueries );
                            Expression.Invoke( modifySubQueries, subQueriesVar, modifyCommand ),
                            // readerFactory = this.GetReaderFactory( info.Query, info.ResultType );
                            Expression.Assign( readerFactoryVar, MakeMethodCall( oldProvider, "GetReaderFactory", getQuery, getResultType ) ),
                            Expression.Empty() ),
                        Expression.Constant( 1 ) ),
                    // else if ( resultShape == 2 /* ResultShape.Sequence */ )
                    Expression.SwitchCase(
                        Expression.Block(
                            // subQueries = this.CompileSubQueries( info.Query );
                            Expression.Assign( subQueriesVar, MakeMethodCall( oldProvider, "CompileSubQueries", getQuery ) ),
                            // ModifySubQueries( (IEnumerable)subQueries );
                            Expression.Invoke( modifySubQueries, subQueriesVar, modifyCommand ),
                            // readerFactory = this.GetReaderFactory( info.Query, TypeSystem.GetElementType( info.ResultType ) );
                            Expression.Assign( readerFactoryVar, MakeMethodCall( oldProvider, "GetReaderFactory", getQuery, getResultElementType ) ),
                            Expression.Empty() ),
                        Expression.Constant( 2 ) ) ),

                // dc.provider = oldProvider;    // (Unfortunately needed to ensure compilation runs)
                Expression.Assign( Expression.MakeMemberAccess( dataContextParam, providerField ), oldProvider ),

                // return new CompiledQuery( oldProvider, query, queries, readerFactory, subQueries );
                Expression.Return( returnTarget, Expression.New( compiledQueryConstructor, oldProvider, queryParam, queriesVar, readerFactoryVar, subQueriesVar ) ),
                Expression.Label( returnTarget, Expression.Constant( null, compiledQueryType ) ) ),
            queryParam );
    }

    public static Expression<Func<object, ModifyCommandDelegate, Func<Expression, IExecuteResult>>> CreateExecuteFactory()
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var providerType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" );
        var providerParam = Expression.Parameter( typeof( object ), "provider" );
        var modifyCommandParam = Expression.Parameter( typeof( ModifyCommandDelegate ), "modifyCommand" );

        return Expression.Lambda<Func<object, ModifyCommandDelegate, Func<Expression, IExecuteResult>>>(
            CreateExecuteMethod( Expression.Convert( providerParam, providerType ), modifyCommandParam ),
            providerParam,
            modifyCommandParam );
    }

    public static Expression<Func<Expression, IExecuteResult>> CreateExecuteMethod( Expression oldProvider, Expression modifyCommand )
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var providerType = assembly.GetType( "System.Data.Linq.Provider.IProvider" );
        var funcletizerType = assembly.GetType( "System.Data.Linq.SqlClient.Funcletizer" );
        var annotationsType = assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" );
        var queryInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "QueryInfo", BindingFlags.NonPublic );
        var readerFactoryType = assembly.GetType( "System.Data.Linq.SqlClient.IObjectReaderFactory" );
        var resultShapeType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "ResultShape", BindingFlags.NonPublic );
        var compiledSubQueryType = assembly.GetType( "System.Data.Linq.SqlClient.ICompiledSubQuery" );
        var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );

        var queryParam = Expression.Variable( typeof( Expression ), "query" );

        var cachedResultVar = Expression.Variable( typeof( IExecuteResult ), "cachedResult" );
        var annotationsVar = Expression.Variable( annotationsType, "annotations" );
        var queriesVar = Expression.Variable( queryInfoType.MakeArrayType(), "queries" );
        var infoVar = Expression.Variable( queryInfoType, "info" );
        var lambdaVar = Expression.Variable( typeof( LambdaExpression ), "lambda" );
        var readerFactoryVar = Expression.Variable( readerFactoryType, "readerFactory" );
        var subQueriesVar = Expression.Variable( compiledSubQueryType.MakeArrayType(), "subQueries" );
        var resultShapeVar = Expression.Variable( resultShapeType, "resultShape" );
        var returnTarget = Expression.Label( typeof( IExecuteResult ) );

        var getQuery = MakeMethodCall( infoVar, "get_Query" );
        var getResultType = MakeMethodCall( infoVar, "get_ResultType" );
        var getResultElementType = MakeMethodCall( typeSystemType, "GetElementType", getResultType );
        var modifySubQueriesMethod = CreateModifySubQueriesMethod();

        return Expression.Lambda<Func<Expression, IExecuteResult>>(
            Expression.Block(
                new[] { annotationsVar, queriesVar, infoVar, lambdaVar, readerFactoryVar, subQueriesVar, resultShapeVar },

                // this.InitializeProviderMode();
                MakeMethodCall( oldProvider, "InitializeProviderMode" ),

                // query = Funcletizer.Funcletize(query);
                Expression.Assign( queryParam, MakeMethodCall( funcletizerType, "Funcletize", queryParam ) ),

                // if ( this.EnableCacheLookup )
                Expression.IfThen(
                    MakeMethodCall( oldProvider, "get_EnableCacheLookup" ),
                    Expression.Block(
                        new[] { cachedResultVar },
                        // IExecuteResult cachedResult = this.GetCachedResult(query);
                        Expression.Assign( cachedResultVar, MakeMethodCall( oldProvider, "GetCachedResult", queryParam ) ),
                        // if ( cachedResult != null )
                        Expression.IfThen(
                            Expression.NotEqual( cachedResultVar, Expression.Constant( null ) ),
                            // return cachedResult;
                            Expression.Return( returnTarget, cachedResultVar ) ) ) ),

                // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
                Expression.Assign( annotationsVar, Expression.New( annotationsType ) ),

                // QueryInfo[] queries = this.BuildQuery(query, annotations);
                Expression.Assign( queriesVar, MakeMethodCall( oldProvider, "BuildQuery", queryParam, annotationsVar ) ),

                // var info = ModifyQueries( queries, modifyCommand );   // returns the last QueryInfo
                Expression.Assign( infoVar, Expression.Invoke( CreateModifyQueriesMethod(), queriesVar, modifyCommand ) ),

                // this.CheckSqlCompatibility(queries, annotations);
                MakeMethodCall( oldProvider, "CheckSqlCompatibility", queriesVar, annotationsVar ),

                // var lambda = query as LambdaExpression;
                Expression.Assign( lambdaVar, Expression.TypeAs( queryParam, lambdaVar.Type ) ),

                // if ( lambda != null )
                Expression.IfThen(
                    Expression.NotEqual( lambdaVar, Expression.Constant( null ) ),
                    // query = lambda.Body;
                    Expression.Assign( queryParam, Expression.Property( lambdaVar, "Body" ) ) ),

                // IObjectReaderFactory readerFactory = null;
                // ICompiledSubQuery[] subQueries = null;

                // QueryInfo info = queries[queries.Length - 1];
                // (info was already assigned above: ModifyQueries returns the last QueryInfo)
                // var resultShape = info.ResultShape;
                Expression.Assign( resultShapeVar, MakeMethodCall( infoVar, "get_ResultShape" ) ),

                Expression.Switch(
                    Expression.Convert( resultShapeVar, typeof( int ) ),
                    // if ( resultShape == 1 /* ResultShape.Singleton */ )
                    Expression.SwitchCase(
                        Expression.Block(
                            // subQueries = this.CompileSubQueries( info.Query );
                            Expression.Assign( subQueriesVar, MakeMethodCall( oldProvider, "CompileSubQueries", getQuery ) ),
                            // ModifySubQueries( (IEnumerable)subQueries );
                            Expression.Invoke( modifySubQueriesMethod, subQueriesVar, modifyCommand ),
                            // readerFactory = this.GetReaderFactory( info.Query, info.ResultType );
                            Expression.Assign( readerFactoryVar, MakeMethodCall( oldProvider, "GetReaderFactory", getQuery, getResultType ) ),
                            Expression.Empty() ),
                        Expression.Constant( 1 ) ),
                    // else if ( resultShape == 2 /* ResultShape.Sequence */ )
                    Expression.SwitchCase(
                        Expression.Block(
                            // subQueries = this.CompileSubQueries( info.Query );
                            Expression.Assign( subQueriesVar, MakeMethodCall( oldProvider, "CompileSubQueries", getQuery ) ),
                            // ModifySubQueries( (IEnumerable)subQueries );
                            Expression.Invoke( modifySubQueriesMethod, subQueriesVar, modifyCommand ),
                            // readerFactory = this.GetReaderFactory( info.Query, TypeSystem.GetElementType( info.ResultType ) );
                            Expression.Assign( readerFactoryVar, MakeMethodCall( oldProvider, "GetReaderFactory", getQuery, getResultElementType ) ),
                            Expression.Empty() ),
                        Expression.Constant( 2 ) ) ),

                // return this.ExecuteAll(query, queries, readerFactory, null, subQueries);
                Expression.Return( returnTarget, MakeMethodCall( oldProvider, "ExecuteAll", queryParam, queriesVar, readerFactoryVar, Expression.Constant( null, typeof( object[] ) ), subQueriesVar ) ),
                Expression.Label( returnTarget, Expression.Constant( null, typeof( IExecuteResult ) ) ) ),
            queryParam );
    }

    private static LambdaExpression CreateModifyQueriesMethod()
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var queryInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "QueryInfo", BindingFlags.NonPublic );
        var queriesParam = Expression.Parameter( queryInfoType.MakeArrayType(), "queries" );
        var modifyCommandParam = Expression.Parameter( typeof( ModifyCommandDelegate ), "modifyCommand" );

        var returnLabel = Expression.Label();
        var indexVar = Expression.Variable( typeof( int ), "index" );
        var queryVar = Expression.Variable( queryInfoType, "query" );

        return Expression.Lambda(
            Expression.Block(
                new[] { indexVar, queryVar },
                // var index = 0;
                Expression.Assign( indexVar, Expression.Constant( 0 ) ),
                Expression.Loop(
                    Expression.Block(
                        // if ( index >= queries.Length ) break;
                        Expression.IfThen(
                            Expression.GreaterThanOrEqual( indexVar, Expression.ArrayLength( queriesParam ) ),
                            Expression.Break( returnLabel ) ),
                        // query = queries[ index ];
                        Expression.Assign( queryVar, Expression.ArrayIndex( queriesParam, indexVar ) ),
                        // ModifyQuery( query );
                        Expression.Invoke( CreateModifyQueryMethod(), queryVar, modifyCommandParam ),
                        // ++ index;
                        Expression.PreIncrementAssign( indexVar ) ),
                    returnLabel ),
                queryVar ),
            queriesParam,
            modifyCommandParam );
    }

    private static LambdaExpression CreateModifySubQueriesMethod()
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var iCompiledSubQueryType = assembly.GetType( "System.Data.Linq.SqlClient.ICompiledSubQuery" );
        var compiledSubQueryType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "CompiledSubQuery", BindingFlags.NonPublic );
        var queryInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "QueryInfo", BindingFlags.NonPublic );
        var subQueriesType = iCompiledSubQueryType.MakeArrayType();

        var queryInfoField = compiledSubQueryType.GetField( "queryInfo", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
        var subQueriesField = compiledSubQueryType.GetField( "subQueries", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );

        var subQueriesParam = Expression.Parameter( subQueriesType, "subQueries" );
        var modifyCommandParam = Expression.Parameter( typeof( ModifyCommandDelegate ), "modifyCommand" );

        var indexVar = Expression.Variable( typeof( int ), "index" );
        var subQueryVar = Expression.Variable( compiledSubQueryType, "subQuery" );
        var queryInfoVar = Expression.Variable( queryInfoType, "queryInfo" );
        var nestedSubQueriesVar = Expression.Variable( iCompiledSubQueryType.MakeArrayType(), "nestedSubQueries" );
        var modifySubQueriesDelegateType = typeof( Action<> ).MakeGenericType( iCompiledSubQueryType.MakeArrayType() );
        var modifySubQueriesDelegateVar = Expression.Variable( modifySubQueriesDelegateType, "modifySubQueries" );
        var loopTarget = Expression.Label();

        var innerLambda = Expression.Lambda(
            modifySubQueriesDelegateType,
            Expression.Block(
                new[] { indexVar },
                // var index = 0;
                Expression.Assign( indexVar, Expression.Constant( 0 ) ),
                Expression.Loop(
                    Expression.Block(
                        new[] { subQueryVar },
                        // if ( index >= subQueries.Length ) break;
                        Expression.IfThen(
                            Expression.GreaterThanOrEqual( indexVar, Expression.ArrayLength( subQueriesParam ) ),
                            Expression.Break( loopTarget ) ),
                        // var subQuery = subQueries[ index ];
                        Expression.Assign( subQueryVar, Expression.TypeAs( Expression.ArrayIndex( subQueriesParam, indexVar ), compiledSubQueryType ) ),
                        // if ( subQuery != null )
                        Expression.IfThen(
                            Expression.NotEqual( subQueryVar, Expression.Constant( null ) ),
                            Expression.Block(
                                new[] { queryInfoVar, nestedSubQueriesVar },
                                // var queryInfo = subQuery.queryInfo;
                                Expression.Assign( queryInfoVar, Expression.MakeMemberAccess( subQueryVar, queryInfoField ) ),
                                // ModifyQuery( queryInfo, modifyCommand );
                                Expression.Invoke( CreateModifyQueryMethod(), queryInfoVar, modifyCommandParam ),
                                // var nestedSubQueries = subQuery.subQueries;
                                Expression.Assign( nestedSubQueriesVar, Expression.MakeMemberAccess( subQueryVar, subQueriesField ) ),
                                // if ( nestedSubQueries != null )
                                Expression.IfThen(
                                    Expression.NotEqual( nestedSubQueriesVar, Expression.Constant( null ) ),
                                    // ModifySubQueries( nestedSubQueries, modifyCommand );
                                    Expression.Invoke( modifySubQueriesDelegateVar, nestedSubQueriesVar ) ) ) ),
                        // ++ index;
                        Expression.PreIncrementAssign( indexVar ) ),
                    loopTarget ) ),
                subQueriesParam );

        return Expression.Lambda(
            Expression.Block(
                new[] { modifySubQueriesDelegateVar },
                Expression.Assign( modifySubQueriesDelegateVar, innerLambda ),
                Expression.Invoke( modifySubQueriesDelegateVar, subQueriesParam ) ),
            subQueriesParam,
            modifyCommandParam );
    }

    private static LambdaExpression CreateModifyQueryMethod()
    {
        var assembly = typeof( SqlProvider ).Assembly;

        var queryInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider" ).GetNestedType( "QueryInfo", BindingFlags.NonPublic );
        var sqlParameterInfoType = assembly.GetType( "System.Data.Linq.SqlClient.SqlParameterInfo" );
        var sqlParameterType = assembly.GetType( "System.Data.Linq.SqlClient.SqlParameter" );

        var queryInfoParam = Expression.Parameter( queryInfoType, "query" );
        var modifyCommandParam = Expression.Parameter( typeof( ModifyCommandDelegate ), "modifyCommand" );

        var commandTextField = queryInfoType.GetField( "commandText", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
        var parametersField = queryInfoType.GetField( "parameters", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );

        var commandTextVar = Expression.Variable( typeof( string ), "commandText" );
        var parameterInfosVar = Expression.Variable( typeof( ReadOnlyCollection<> ).MakeGenericType( sqlParameterInfoType ), "parameterInfos" );
        var parametersVar = Expression.Variable( typeof( Dictionary<string, object> ), "parameters" );
        var indexVar = Expression.Variable( typeof( int ), "index" );
        var paramInfoVar = Expression.Variable( sqlParameterInfoType, "paramInfo" );
        var paramVar = Expression.Variable( sqlParameterType, "param" );
        var loopTarget = Expression.Label();

        return Expression.Lambda(
            Expression.IfThen(
                Expression.NotEqual( modifyCommandParam, Expression.Constant( null ) ),
                Expression.Block(
                    new[] { commandTextVar, parameterInfosVar, parametersVar, indexVar },
                    // var commandText = queryInfo.CommandText;
                    Expression.Assign( commandTextVar, Expression.MakeMemberAccess( queryInfoParam, commandTextField ) ),
                    // var parameterInfos = queryInfo.Parameters;
                    Expression.Assign( parameterInfosVar, Expression.MakeMemberAccess( queryInfoParam, parametersField ) ),
                    // var parameters = new Dictionary<string, object>();
                    Expression.Assign( parametersVar, Expression.New( typeof( Dictionary<string, object> ) ) ),
                    // var index = 0;
                    Expression.Assign( indexVar, Expression.Constant( 0 ) ),
                    Expression.Loop(
                        Expression.Block(
                            new[] { paramInfoVar, paramVar },
                            // if ( index >= parameterInfos.Count ) break;
                            Expression.IfThen(
                                Expression.GreaterThanOrEqual( indexVar, Expression.Property( parameterInfosVar, "Count" ) ),
                                Expression.Break( loopTarget ) ),
                            // var paramInfo = parameterInfos[ index ];
                            Expression.Assign( paramInfoVar, MakeMethodCall( parameterInfosVar, "get_Item", indexVar ) ),
                            // var param = paramInfo.Parameter;
                            Expression.Assign( paramVar, MakeMethodCall( paramInfoVar, "get_Parameter" ) ),
                            // parameters[ param.Name ] = paramInfo.Value;
                            MakeMethodCall( parametersVar, "set_Item", MakeMethodCall( paramVar, "get_Name" ), MakeMethodCall( paramInfoVar, "get_Value" ) ),
                            // ++ index;
                            Expression.PreIncrementAssign( indexVar ) ),
                        loopTarget ),
                    // queryInfo.CommandText = modifyCommand( commandText, parameters );
                    Expression.Assign( Expression.MakeMemberAccess( queryInfoParam, commandTextField ), Expression.Invoke( modifyCommandParam, commandTextVar, parametersVar ) ) ) ),
            queryInfoParam,
            modifyCommandParam );
    }

    internal interface IProviderProxy
    {
        DataContextInterceptor Interceptor { get; }
        object OldProvider { get; }
    }

    public class ProviderProxy : RealProxy, IRemotingTypeInfo, IProviderProxy
    {
        public DataContextInterceptor Interceptor { get; private set; }
        public object OldProvider { get; private set; }

        internal ProviderProxy( DataContextInterceptor extender, object oldProvider )
            : base( typeof( ContextBoundObject ) )
        {
            this.Interceptor = extender;
            this.OldProvider = oldProvider;
        }

        public override IMessage Invoke( IMessage msg )
        {
            var call = msg as IMethodCallMessage;

            if ( call != null && OldProvider != null )
            {
                try
                {
                    if ( call.MethodBase.DeclaringType.Name == "IProvider" && call.MethodBase.DeclaringType.IsInterface )
                    {
                        Interceptor.providerType = call.MethodBase.DeclaringType;

                        switch ( call.MethodName )
                        {
                            case "Compile": return new ReturnMessage( Interceptor.Compile( call.Args.Cast<Expression>().First() ), null, 0, null, call );
                            case "Execute": return new ReturnMessage( Interceptor.Execute( call.Args.Cast<Expression>().First() ), null, 0, null, call );
                        }
                    }

                    return new ReturnMessage( call.MethodBase.Invoke( OldProvider, call.Args ), null, 0, null, call );
                }
                catch ( TargetInvocationException e )
                {
                    return new ReturnMessage( e.InnerException, call );
                }
            }

            throw new NotImplementedException();
        }

        public bool CanCastTo( Type fromType, object o )
        {
            return true;
        }

        public string TypeName
        {
            get { return this.GetType().Name; }
            set { }
        }
    }
}
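CreateModifyQueriesMethod, CreateModifySubQueriesMethod and CreateModifyQueryMethod above all express an ordinary C# for loop with the same three pieces: an index variable declared in a Block, an Expression.Loop, and an IfThen that jumps to the loop's break label. Here's a minimal standalone sketch of just that pattern. It is a simplified illustration over a string[] rather than L2S's internal QueryInfo[], the names are made up for the example, and it's written as C# 9 top-level statements so it runs as a console program. Like CreateModifyQueriesMethod, it returns the last element it visited:

```csharp
using System;
using System.Linq.Expressions;

// Builds the equivalent of:
//   ( string[] items ) => { int index = 0; string last = null;
//       while ( true ) { if ( index >= items.Length ) break; last = items[ index ]; ++index; }
//       return last; }
var itemsParam = Expression.Parameter( typeof( string[] ), "items" );
var indexVar = Expression.Variable( typeof( int ), "index" );
var lastVar = Expression.Variable( typeof( string ), "last" );   // starts out null
var breakLabel = Expression.Label( "break" );                    // void label: the loop yields no value

var lastItemExpr = Expression.Lambda<Func<string[], string>>(
    Expression.Block(
        new[] { indexVar, lastVar },
        Expression.Assign( indexVar, Expression.Constant( 0 ) ),
        Expression.Loop(
            Expression.Block(
                // if ( index >= items.Length ) break;
                Expression.IfThen(
                    Expression.GreaterThanOrEqual( indexVar, Expression.ArrayLength( itemsParam ) ),
                    Expression.Break( breakLabel ) ),
                // last = items[ index ];
                Expression.Assign( lastVar, Expression.ArrayIndex( itemsParam, indexVar ) ),
                // ++index;
                Expression.PreIncrementAssign( indexVar ) ),
            breakLabel ),
        // A Block evaluates to its final expression, so this becomes the return value.
        lastVar ),
    itemsParam );

var lastItem = lastItemExpr.Compile();
Console.WriteLine( lastItem( new[] { "a", "b", "c" } ) );   // c
```

Ending the outer block with lastVar is what makes the lambda return it. CreateModifyQueriesMethod ends its block with queryVar in the same way, which is why ModifyQueries hands back the last QueryInfo.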

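To use it, call DataContextInterceptor.Intercept on your DataContext before running any queries, passing in a delegate that modifies the command text as needed. The example below is written for LINQPad, where this is the DataContext and Dump() displays each modified command: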

DataContextInterceptor.Intercept( this, ( c, p ) => c.Replace( "[t0].[cName]", "NULL" ).Dump( "test") );
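The delegate receives the generated SQL and a dictionary mapping each parameter name to its value, and returns the text to execute. In the expression-tree version above, CreateModifyQueryMethod builds a fresh dictionary for every query and writes back only the returned string, so changing the dictionary has no effect on the command. As a slightly larger sketch, here's a hypothetical modifier (AddRecompileHint is not part of the interceptor) that appends a SQL Server OPTION (RECOMPILE) hint to parameterised SELECT statements. Skipping parameterless queries is an assumption made for the illustration, and the code uses C# 9 top-level statements so it runs without a database:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical modifier with the same signature as ModifyCommandDelegate:
// it receives the generated SQL plus a parameter-name -> value map and
// returns the SQL that should actually be sent to the server.
string AddRecompileHint( string commandText, IDictionary<string, object> parameters )
{
    // Assumption for illustration: only parameterised queries get the hint
    // (OPTION (RECOMPILE) is a common workaround for parameter sniffing).
    if ( parameters.Count == 0 ) return commandText;

    // Leave anything that isn't a SELECT (INSERT, UPDATE, DELETE) unchanged.
    if ( !commandText.TrimStart().StartsWith( "SELECT", StringComparison.OrdinalIgnoreCase ) )
        return commandText;

    // In the expression-tree interceptor only this returned string is written
    // back; edits to 'parameters' would be discarded.
    return commandText + Environment.NewLine + "OPTION (RECOMPILE)";
}

// Simulate what the interceptor would pass in for a simple query.
var sql = "SELECT [t0].[cName] FROM [Customers] AS [t0] WHERE [t0].[id] = @p0";
var modified = AddRecompileHint( sql, new Dictionary<string, object> { { "@p0", 42 } } );

Console.WriteLine( modified );
```

It would be registered the same way as the lambda above, e.g. DataContextInterceptor.Intercept( db, AddRecompileHint ), where db stands for your DataContext.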


As it’s still calling private methods, you’ll need to be running in a full-trust environment.
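One detail of CreateModifySubQueriesMethod above is easy to miss. Sub-queries can contain sub-queries, so the generated code has to recurse, yet a lambda inside an expression tree has no name it can call itself by. The code works around this by declaring a delegate-typed variable (modifySubQueriesDelegateVar), assigning the inner lambda to it, and having the inner lambda invoke that variable. Because the compiled lambda captures the variable in a closure, the variable already holds the inner lambda by the time the lambda calls itself. Here's a minimal standalone sketch of the same trick. The factorial is only a stand-in for the sub-query walk, the names are illustrative, and it's written as C# 9 top-level statements:

```csharp
using System;
using System.Linq.Expressions;

// A delegate-typed variable the inner lambda calls through, playing the role of
// modifySubQueriesDelegateVar in CreateModifySubQueriesMethod.
var selfVar = Expression.Variable( typeof( Func<int, int> ), "factorial" );
var n = Expression.Parameter( typeof( int ), "n" );

// Inner lambda: n => n <= 1 ? 1 : n * factorial( n - 1 )
// It refers to selfVar, so the compiler hoists selfVar into a closure.
var inner = Expression.Lambda<Func<int, int>>(
    Expression.Condition(
        Expression.LessThanOrEqual( n, Expression.Constant( 1 ) ),
        Expression.Constant( 1 ),
        Expression.Multiply(
            n,
            Expression.Invoke( selfVar, Expression.Subtract( n, Expression.Constant( 1 ) ) ) ) ),
    n );

// Outer lambda: x => { Func<int, int> factorial; factorial = <inner>; return factorial( x ); }
var x = Expression.Parameter( typeof( int ), "x" );
var outer = Expression.Lambda<Func<int, int>>(
    Expression.Block(
        new[] { selfVar },
        Expression.Assign( selfVar, inner ),
        Expression.Invoke( selfVar, x ) ),
    x );

var factorial = outer.Compile();
Console.WriteLine( factorial( 5 ) );   // 120
```

In CreateModifySubQueriesMethod the inner lambda also closes over modifyCommandParam, which is why the recursive call only passes the nested sub-query array.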

Last year I posted this article about modifying LINQ to SQL (L2S) command text.  It was slightly evil in that it called private methods inside L2S, and it did it through reflection.

I have an alternative version that does the same thing through a pre-compiled expression tree, which I’ll post soon.  Until then, here’s an update to the original reflection-based interceptor that properly handles sub-queries:

    public delegate string ModifyCommandDelegate( string commandText, IDictionary<string, object> parameters );

    public class ReflectionDataContextInterceptor
    {
        private DataContext dc;
        private object oldProvider;
        private Type providerType;
        private ModifyCommandDelegate modifyCommand;

        public static TDataContext Intercept<TDataContext>( TDataContext dc, ModifyCommandDelegate modifyCommand )
            where TDataContext : DataContext
        {
            new ReflectionDataContextInterceptor( dc, modifyCommand );

            return dc;
        }

        public ReflectionDataContextInterceptor( DataContext dc, ModifyCommandDelegate modifyCommand )
        {
            this.dc = dc;
            this.modifyCommand = modifyCommand;

            FieldInfo providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
            var existingProvider = providerField.GetValue( dc );

            if ( existingProvider is IProviderProxy )
            {
//                System.Diagnostics.Trace.WriteLine( string.Format( "DataContext {0} already intercepted", dc.GetHashCode() ) );
            }
            else
            {
                oldProvider = existingProvider;

                var proxy = new ProviderProxy( this, oldProvider ).GetTransparentProxy();

                providerField.SetValue( dc, proxy );

//                System.Diagnostics.Trace.WriteLine( string.Format( "DataContext {0} intercepted", dc.GetHashCode() ) );
            }
        }

        public static MethodInfo GetMethod( Type type, string methodName, params object[] args )
        {
            var hasNullArgs = args.Any( a => a == null );

            var method = hasNullArgs
                ? type.GetMethods( BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic )
                    .FirstOrDefault( m => m.Name == methodName )
                : null;

            if ( method == null && !hasNullArgs )
            {
                method = type.GetMethod(
                    methodName,
                    BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
                    null,
                    args.Select( a => a.GetType() ).ToArray(),
                    null );
            }

            return method;
        }

        private object Invoke( object instance, string methodName, params object[] args )
        {
            var type = instance.GetType();
            var method = GetMethod( type, methodName, args );

            return ( method != null )
                ? method.Invoke( instance, args )
                : null;
        }

        private object Invoke( Type type, string methodName, params object[] args )
        {
            var method = type.GetMethod( methodName, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic );

            return method.Invoke( null, args );
        }

        protected object CompileImpl( Expression query )
        {
            try
            {
                var assembly = typeof( SqlProvider ).Assembly;

                /*
                this.CheckDispose();
                this.CheckInitialized();
                if ( query == null )
                {
                    throw Error.ArgumentNull( "query" );
                }
                */

                // this.InitializeProviderMode();
                Invoke( oldProvider, "InitializeProviderMode" );

                // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
                var annotations = Activator.CreateInstance( assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" ) );

                // QueryInfo[] queries = this.BuildQuery( query, annotations );
                var queries = Invoke( oldProvider, "BuildQuery", query, annotations );

                var info = ModifyQueries( (IEnumerable)queries );

                // this.CheckSqlCompatibility( queries, annotations );
                Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

                LambdaExpression expression = query as LambdaExpression;

                if ( expression != null )
                {
                    query = expression.Body;
                }

                // IObjectReaderFactory readerFactory = null;
                object readerFactory = null;

                // ICompiledSubQuery[] subQueries = null;
                object subQueries = null;

                // QueryInfo info = queries[ queries.Length - 1 ];
                // info defined above

                var resultShape = (int)Invoke( info, "get_ResultShape" );

                // if ( info.ResultShape == ResultShape.Singleton )
                if ( resultShape == 1 /* Singleton */ )
                {
                    // subQueries = this.CompileSubQueries( info.Query );
                    subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                    ModifySubQueries( (IEnumerable)subQueries );

                    // readerFactory = this.GetReaderFactory( info.Query, info.ResultType );
                    readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), Invoke( info, "get_ResultType" ) );
                }
                // else if ( info.ResultShape == ResultShape.Sequence )
                else if ( resultShape == 2 /* Sequence */ )
                {
                    // subQueries = this.CompileSubQueries( info.Query );
                    subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                    ModifySubQueries( (IEnumerable)subQueries );

                    // readerFactory = this.GetReaderFactory( info.Query, TypeSystem.GetElementType( info.ResultType ) );
                    var resultType = Invoke( info, "get_ResultType" );
                    var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
                    var elementType = Invoke( typeSystemType, "GetElementType", resultType );

                    readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), elementType );
                }

                FieldInfo providerField = typeof( DataContext ).GetField( "provider", BindingFlags.Instance | BindingFlags.NonPublic );
                providerField.SetValue( dc, oldProvider );

                //            System.Diagnostics.Trace.WriteLine( string.Format( "DataContext {0} interceptor released (compiled query)", dc.GetHashCode() ) );

                // return new CompiledQuery( this, query, queries, readerFactory, subQueries );
                var compiledQueryType = assembly.GetType( "System.Data.Linq.SqlClient.SqlProvider+CompiledQuery" );

                return Activator.CreateInstance(
                    compiledQueryType,
                    BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
                    null,
                    new object[] { oldProvider, query, queries, readerFactory, subQueries },
                    null );
            }
            catch ( TargetInvocationException ex )
            {
                throw ex.InnerException;
            }
        }

        protected internal virtual IExecuteResult ExecuteImpl( Expression query )
        {
            try
            {
                var assembly = typeof( SqlProvider ).Assembly;

                /*
                this.CheckDispose();
                this.CheckInitialized();
                this.CheckNotDeleted();
                if (query == null)
                {
                    throw Error.ArgumentNull("query");
                }
                */

                // this.InitializeProviderMode();
                Invoke( oldProvider, "InitializeProviderMode" );

                // query = Funcletizer.Funcletize(query);
                var funcletizerType = assembly.GetType( "System.Data.Linq.SqlClient.Funcletizer" );
                query = (Expression)Invoke( funcletizerType, "Funcletize", query );

                // if ( this.EnableCacheLookup )
                if ( (bool)Invoke( oldProvider, "get_EnableCacheLookup" ) )
                {
                    // IExecuteResult cachedResult = this.GetCachedResult(query);
                    object cachedResult = Invoke( oldProvider, "GetCachedResult", query );

                    if ( cachedResult != null )
                    {
                        // return cachedResult;
                        return (IExecuteResult)cachedResult;
                    }
                }

                // SqlNodeAnnotations annotations = new SqlNodeAnnotations();
                var annotations = Activator.CreateInstance( assembly.GetType( "System.Data.Linq.SqlClient.SqlNodeAnnotations" ) );

                // QueryInfo[] queries = this.BuildQuery(query, annotations);
                var queries = Invoke( oldProvider, "BuildQuery", query, annotations );

                var info = ModifyQueries( (IEnumerable)queries );

                // this.CheckSqlCompatibility(queries, annotations);
                Invoke( oldProvider, "CheckSqlCompatibility", queries, annotations );

                LambdaExpression expression = query as LambdaExpression;

                if ( expression != null )
                {
                    query = expression.Body;
                }

                // IObjectReaderFactory readerFactory = null;
                object readerFactory = null;

                // ICompiledSubQuery[] subQueries = null;
                object subQueries = null;

                // QueryInfo info = queries[queries.Length - 1];
                // info defined above

                var resultShape = (int)Invoke( info, "get_ResultShape" );

                // if (info.ResultShape == ResultShape.Singleton)
                if ( resultShape == 1 /* Singleton */ )
                {
                    // subQueries = this.CompileSubQueries(info.Query);
                    subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                    ModifySubQueries( (IEnumerable)subQueries );

                    // readerFactory = this.GetReaderFactory(info.Query, info.ResultType);
                    readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), Invoke( info, "get_ResultType" ) );
                }
                // else if (info.ResultShape == ResultShape.Sequence)
                else if ( resultShape == 2 /* Sequence */ )
                {
                    // subQueries = this.CompileSubQueries(info.Query);
                    subQueries = Invoke( oldProvider, "CompileSubQueries", Invoke( info, "get_Query" ) );

                    ModifySubQueries( (IEnumerable)subQueries );

                    // readerFactory = this.GetReaderFactory(info.Query, TypeSystem.GetElementType(info.ResultType));
                    var resultType = Invoke( info, "get_ResultType" );
                    var typeSystemType = assembly.GetType( "System.Data.Linq.SqlClient.TypeSystem" );
                    var elementType = Invoke( typeSystemType, "GetElementType", resultType );

                    readerFactory = Invoke( oldProvider, "GetReaderFactory", Invoke( info, "get_Query" ), elementType );
                }

                // return this.ExecuteAll(query, queries, readerFactory, null, subQueries);
                return (IExecuteResult)Invoke( oldProvider, "ExecuteAll", query, queries, readerFactory, null, subQueries );
            }
            catch ( TargetInvocationException ex )
            {
                throw ex.InnerException;
            }
        }

        private object ModifyQueries( IEnumerable queries )
        {
            object lastQuery = null;

            foreach ( var q in queries )
            {
                lastQuery = q;

                ModifyQuery( q );
            }

            return lastQuery;
        }

        private void ModifySubQueries( IEnumerable subQueries )
        {
            if ( subQueries == null ) return;

            foreach ( var sq in subQueries )
            {
                var queryInfoField = sq.GetType().GetField( "queryInfo", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
                var queryInfo = queryInfoField.GetValue( sq );
                ModifyQuery( queryInfo );

                var subQueriesField = sq.GetType().GetField( "subQueries", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
                var nestedSubQueries = subQueriesField.GetValue( sq );

                if ( nestedSubQueries != null ) ModifySubQueries( (IEnumerable)nestedSubQueries );
            }
        }

        private void ModifyQuery( object q /* QueryInfo */ )
        {
            var commandTextField = q.GetType().GetField( "commandText", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
            var parametersField = q.GetType().GetField( "parameters", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );

            var commandText = (string)commandTextField.GetValue( q );
            var parameterInfos = parametersField.GetValue( q );
            var parameters = new Dictionary<string, object>();

            foreach ( var p in (IEnumerable)parameterInfos )
            {
                var param = Invoke( p, "get_Parameter" );
                var name = (string)Invoke( param, "get_Name" );

                parameters[ name ] = Invoke( p, "get_Value" );
            }

            var modifiedCommandText = modifyCommand( commandText, parameters );

//            System.Diagnostics.Trace.WriteLine( modifiedCommandText + "\n---" );

            commandTextField.SetValue( q, modifiedCommandText );
        }
/*
        protected internal DbCommand GetCommandImpl( Expression query )
        {
            return (DbCommand)Invoke( oldProvider, "GetCommand", query );
        }

        protected internal string GetQueryTextImpl( Expression query )
        {
            return (string)Invoke( oldProvider, "GetQueryText", query );
        }
*/
        internal interface IProviderProxy
        {
            ReflectionDataContextInterceptor Interceptor { get; }
            object OldProvider { get; }
        }

        public class ProviderProxy : RealProxy, IRemotingTypeInfo, IProviderProxy
        {
            public ReflectionDataContextInterceptor Interceptor { get; private set; }
            public object OldProvider { get; private set; }

            internal ProviderProxy( ReflectionDataContextInterceptor extender, object oldProvider )
                : base( typeof( ContextBoundObject ) )
            {
                this.Interceptor = extender;
                this.OldProvider = oldProvider;
            }

            public override IMessage Invoke( IMessage msg )
            {
                if ( msg is IMethodCallMessage )
                {
                    IMethodCallMessage call = (IMethodCallMessage)msg;
                    MethodInfo mi = null;

                    if ( call.MethodBase.DeclaringType.Name == "IProvider" && call.MethodBase.DeclaringType.IsInterface )
                    {
                        Interceptor.providerType = call.MethodBase.DeclaringType;

                        mi = ReflectionDataContextInterceptor.GetMethod( typeof( ReflectionDataContextInterceptor ), call.MethodBase.Name + "Impl", call.Args );

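                        if ( mi == null && OldProvider != null )
                        {
                            mi = ReflectionDataContextInterceptor.GetMethod( call.MethodBase.DeclaringType, call.MethodBase.Name, call.Args );

                            // Only invoke if the fallback lookup succeeded; otherwise fall through to the
                            // "Method not found" exception below instead of a NullReferenceException
                            if ( mi != null )
                            {
                                try
                                {
                                    return new ReturnMessage( mi.Invoke( OldProvider, call.Args ), null, 0, null, call );
                                }
                                catch ( TargetInvocationException e )
                                {
                                    return new ReturnMessage( e.InnerException, call );
                                }
                            }
                        }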

                        if ( mi != null )
                        {
                            try
                            {
                                return new ReturnMessage( mi.Invoke( this.Interceptor, call.Args ), null, 0, null, call );
                            }
                            catch ( TargetInvocationException e )
                            {
                                return new ReturnMessage( e.InnerException, call );
                            }
                        }
                    }
//                    else mi = typeof( ReflectionDataContextInterceptor ).GetMethod( call.MethodBase.Name, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic );
                    else
                    {
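                        // Guard against a missing provider so we report "Method not found" rather than an NRE
                        mi = ( OldProvider != null )
                            ? ReflectionDataContextInterceptor.GetMethod( OldProvider.GetType(), call.MethodBase.Name, call.Args )
                            : null;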

                        if ( mi != null )
                        {
                            try
                            {
                                return new ReturnMessage( mi.Invoke( OldProvider, call.Args ), null, 0, null, call );
                            }
                            catch ( TargetInvocationException e )
                            {
                                return new ReturnMessage( e.InnerException, call );
                            }
                        }
                    }

                    throw new NotImplementedException(
                        string.Format( "Method not found: {0}( {1} )",
                            call.MethodBase.Name,
                            string.Join( ", ", call.Args.Select( a => Convert.ToString( a ) ) ) ) );
                }

                throw new NotImplementedException();
            }

            public bool CanCastTo( Type fromType, object o )
            {
                return true;
            }

            public string TypeName
            {
                get { return this.GetType().Name; }
                set { }
            }
        }
    }

It’s still evil, but works great :)
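The whole interceptor leans on one trick: look up a non-public member by name with BindingFlags.NonPublic, call it through MethodInfo.Invoke, and unwrap the TargetInvocationException that reflection wraps around whatever the target throws. Here’s a minimal, self-contained sketch of just that pattern. HiddenProvider, ReflectionHelpers and InvokeStatic are made-up names for illustration (HiddenProvider stands in for LINQ to SQL’s internal SqlProvider); they are not part of System.Data.Linq. It also assumes .NET 4.5 or later for ExceptionDispatchInfo, which rethrows the inner exception with its original stack trace, whereas a plain `throw ex.InnerException;` resets it.

```csharp
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.ExceptionServices;

// Hypothetical stand-in for an internal framework type we can't call directly.
public class HiddenProvider
{
    private string BuildQuery( string table, int top )
    {
        return "SELECT TOP " + top + " * FROM " + table;
    }

    private void Fail()
    {
        throw new InvalidOperationException( "provider failed" );
    }

    private static string Describe( string name )
    {
        return "static:" + name;
    }
}

public static class ReflectionHelpers
{
    private const BindingFlags AnyInstance =
        BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

    private const BindingFlags AnyStatic =
        BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;

    // Calls a (possibly private) instance method by name, picking the overload from the
    // runtime types of the arguments. Assumes no argument is null (the listing above
    // handles nulls separately via its hasNullArgs path).
    public static object Invoke( object instance, string methodName, params object[] args )
    {
        var method = instance.GetType().GetMethod(
            methodName,
            AnyInstance,
            null,
            args.Select( a => a.GetType() ).ToArray(),
            null );

        if ( method == null )
            throw new MissingMethodException( instance.GetType().FullName, methodName );

        try
        {
            return method.Invoke( instance, args );
        }
        catch ( TargetInvocationException ex )
        {
            // Rethrow the real exception without losing its original stack trace
            ExceptionDispatchInfo.Capture( ex.InnerException ).Throw();
            throw; // unreachable, but keeps the compiler happy
        }
    }

    // Calls a (possibly private) static method by name.
    public static object InvokeStatic( Type type, string methodName, params object[] args )
    {
        var method = type.GetMethod( methodName, AnyStatic );

        if ( method == null )
            throw new MissingMethodException( type.FullName, methodName );

        try
        {
            return method.Invoke( null, args );
        }
        catch ( TargetInvocationException ex )
        {
            ExceptionDispatchInfo.Capture( ex.InnerException ).Throw();
            throw; // unreachable
        }
    }
}
```

With this, `ReflectionHelpers.Invoke( new HiddenProvider(), "BuildQuery", "Fruit", 3 )` returns "SELECT TOP 3 * FROM Fruit", and calling "Fail" surfaces the original InvalidOperationException rather than a TargetInvocationException. Giving the static helper its own name is a deliberate choice: with two `Invoke` overloads taking `object` and `Type`, the compiler picks the static one whenever the first argument’s compile-time type is `Type`, which is easy to trip over.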
