Skip to content

Commit

Permalink
Fix to #7476 - Query : client-side result operators may cause extensi…
Browse files Browse the repository at this point in the history
…ve client-side evaluation for queries with GroupJoins (and/or optional navigations) in a subquery

When we translate GroupJoin-SelectMany-DefaultIfEmpty we modify query model - we remove additional from clause containing DefaultIfEmpty subquery.
However, if we later discover that the subquery (or a subquery higher in the stack) needs client evaluation - e.g. by having client result operator, or calling a client-side method - we need to re-visit this subquery with client evaluation in mind.

Problem was that now the query model has been modified - in this specific case SelectMany-DefaultIfEmpty clause has been removed, so we can no longer recognize that this pattern can be simplified and we introduce client GroupJoin.

Fix is to save the structure of query model (and all subquery models that it contains) before we visit it - and then, if client evaluation is needed, restore the query model to it's original shape.
  • Loading branch information
maumar committed Jan 27, 2017
1 parent b359086 commit 098ea79
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.ExpressionTranslators;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
Expand Down Expand Up @@ -914,6 +916,9 @@ var queriesProjectionCountMapping
= _queryModelVisitor.Queries
.ToDictionary(k => k, s => s.Projection.Count);

var queryModelMapping = new Dictionary<QueryModel, QueryModel>();
subQueryModel.PopulateQueryModelMapping(queryModelMapping);

queryModelVisitor.VisitSubQueryModel(subQueryModel);

if (queryModelVisitor.Queries.Count == 1
Expand All @@ -932,6 +937,8 @@ var queriesProjectionCountMapping

return selectExpression;
}

subQueryModel.RecreateQueryModelFromMapping(queryModelMapping);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,12 @@ var subQueryModelVisitor
= (RelationalQueryModelVisitor)QueryCompilationContext
.CreateQueryModelVisitor(this);

subQueryModelVisitor.VisitSubQueryModel(subQueryExpression.QueryModel);
var subQueryModel = subQueryExpression.QueryModel;

var queryModelMapping = new Dictionary<QueryModel, QueryModel>();
subQueryModel.PopulateQueryModelMapping(queryModelMapping);

subQueryModelVisitor.VisitSubQueryModel(subQueryModel);

if (subQueryModelVisitor.Queries.Count == 1
&& !subQueryModelVisitor.RequiresClientEval
Expand Down Expand Up @@ -976,6 +981,8 @@ var newExpression
}
}

subQueryModel.RecreateQueryModelFromMapping(queryModelMapping);

return expression;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3722,5 +3722,174 @@ public virtual void Where_predicate_on_optional_reference_navigation()
}
}
}

[ConditionalFact]
public virtual void GroupJoin_in_subquery_with_client_result_operator()
{
List<string> expected;
using (var context = CreateContext())
{
expected = (from l1 in context.LevelOne.ToList()
where (from l1_inner in context.LevelOne.ToList()
join l2_inner in context.LevelTwo.ToList() on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping
from l2_inner in grouping.DefaultIfEmpty()
select l1_inner).Distinct().Count() > 7
where l1.Id < 3
select l1.Name).ToList();
}

ClearLog();

using (var context = CreateContext())
{
var query = from l1 in context.LevelOne
where (from l1_inner in context.LevelOne
join l2_inner in context.LevelTwo on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping
from l2_inner in grouping.DefaultIfEmpty()
select l1_inner).Distinct().Count() > 7
where l1.Id < 3
select l1.Name;

var result = query.ToList();

Assert.Equal(expected.Count, result.Count);
foreach (var resultItem in result)
{
Assert.True(expected.Contains(resultItem));
}
}
}

[ConditionalFact]
public virtual void GroupJoin_in_subquery_with_client_projection()
{
List<string> expected;
using (var context = CreateContext())
{
expected = (from l1 in context.LevelOne.ToList()
where (from l1_inner in context.LevelOne.ToList()
join l2_inner in context.LevelTwo.ToList() on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping
from l2_inner in grouping.DefaultIfEmpty()
select l1_inner).Distinct().Count() > 7
where l1.Id < 3
select l1.Name).ToList();
}

ClearLog();

using (var context = CreateContext())
{
var query = from l1 in context.LevelOne
where (from l1_inner in context.LevelOne
join l2_inner in context.LevelTwo on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping
from l2_inner in grouping.DefaultIfEmpty()
select ClientStringMethod(l1_inner.Name)).Count() > 7
where l1.Id < 3
select l1.Name;

var result = query.ToList();

Assert.Equal(expected.Count, result.Count);
foreach (var resultItem in result)
{
Assert.True(expected.Contains(resultItem));
}
}
}

[ConditionalFact]
public virtual void GroupJoin_in_subquery_with_client_projection_nested1()
{
List<string> expected;
using (var context = CreateContext())
{
expected = (from l1_outer in context.LevelOne.ToList()
where (from l1_middle in context.LevelOne.ToList()
join l2_middle in context.LevelTwo.ToList() on l1_middle.Id equals l2_middle.Level1_Optional_Id into grouping_middle
from l2_middle in grouping_middle.DefaultIfEmpty()
where (from l1_inner in context.LevelOne.ToList()
join l2_inner in context.LevelTwo.ToList() on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping_inner
from l2_inner in grouping_inner.DefaultIfEmpty()
select ClientStringMethod(l1_inner.Name)).Count() > 7
select l1_middle).Take(10).Count() > 4
where l1_outer.Id < 2
select l1_outer.Name).ToList();
}

ClearLog();

using (var context = CreateContext())
{
var query = from l1_outer in context.LevelOne
where (from l1_middle in context.LevelOne
join l2_middle in context.LevelTwo on l1_middle.Id equals l2_middle.Level1_Optional_Id into grouping_middle
from l2_middle in grouping_middle.DefaultIfEmpty()
where (from l1_inner in context.LevelOne
join l2_inner in context.LevelTwo on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping_inner
from l2_inner in grouping_inner.DefaultIfEmpty()
select ClientStringMethod(l1_inner.Name)).Count() > 7
select l1_middle).Take(10).Count() > 4
where l1_outer.Id < 2
select l1_outer.Name;

var result = query.ToList();

Assert.Equal(expected.Count, result.Count);
foreach (var resultItem in result)
{
Assert.True(expected.Contains(resultItem));
}
}
}

[ConditionalFact]
public virtual void GroupJoin_in_subquery_with_client_projection_nested2()
{
List<string> expected;
using (var context = CreateContext())
{
expected = (from l1_outer in context.LevelOne.ToList()
where (from l1_middle in context.LevelOne.ToList()
join l2_middle in context.LevelTwo.ToList() on l1_middle.Id equals l2_middle.Level1_Optional_Id into grouping_middle
from l2_middle in grouping_middle.DefaultIfEmpty()
where (from l1_inner in context.LevelOne.ToList()
join l2_inner in context.LevelTwo.ToList() on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping_inner
from l2_inner in grouping_inner.DefaultIfEmpty()
select l1_inner.Name).Count() > 7
select ClientStringMethod(l1_middle.Name)).Count() > 4
where l1_outer.Id < 2
select l1_outer.Name).ToList();
}

ClearLog();

using (var context = CreateContext())
{
var query = from l1_outer in context.LevelOne
where (from l1_middle in context.LevelOne
join l2_middle in context.LevelTwo on l1_middle.Id equals l2_middle.Level1_Optional_Id into grouping_middle
from l2_middle in grouping_middle.DefaultIfEmpty()
where (from l1_inner in context.LevelOne
join l2_inner in context.LevelTwo on l1_inner.Id equals l2_inner.Level1_Optional_Id into grouping_inner
from l2_inner in grouping_inner.DefaultIfEmpty()
select l1_inner.Name).Count() > 7
select ClientStringMethod(l1_middle.Name)).Count() > 4
where l1_outer.Id < 2
select l1_outer.Name;

var result = query.ToList();

Assert.Equal(expected.Count, result.Count);
foreach (var resultItem in result)
{
Assert.True(expected.Contains(resultItem));
}
}
}

private static string ClientStringMethod(string argument)
{
return argument;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors;
Expand All @@ -23,6 +24,97 @@ public static class QueryModelExtensions
public static string Print([NotNull] this QueryModel queryModel)
=> new QueryModelPrinter().Print(queryModel);

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public static Dictionary<QueryModel, QueryModel> PopulateQueryModelMapping(
[NotNull] this QueryModel queryModel,
[NotNull] Dictionary<QueryModel, QueryModel> mapping)
{
var mappingPopulatingVisitor = new QueryModelMappingPopulatingVisitor(mapping);
mappingPopulatingVisitor.Visit(new SubQueryExpression(queryModel));

return mapping;
}

private class QueryModelMappingPopulatingVisitor : ExpressionVisitorBase
{
private readonly Dictionary<QueryModel, QueryModel> _mapping;

public QueryModelMappingPopulatingVisitor(Dictionary<QueryModel, QueryModel> mapping)
{
_mapping = mapping;
}

protected override Expression VisitSubQuery(SubQueryExpression expression)
{
var queryModel = expression.QueryModel;

var newQueryModel = new QueryModel(queryModel.MainFromClause, queryModel.SelectClause);
ShallowCopy(queryModel, newQueryModel);

_mapping.Add(queryModel, newQueryModel);

queryModel.TransformExpressions(Visit);

return expression;
}
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public static QueryModel RecreateQueryModelFromMapping(
[NotNull] this QueryModel queryModel,
[NotNull] Dictionary<QueryModel, QueryModel> mapping)
{
var recreatingVisitor = new QueryModelRecreatingVisitor(mapping);
var resultExpression = recreatingVisitor.Visit(new SubQueryExpression(queryModel));

return ((SubQueryExpression)resultExpression).QueryModel;
}

private class QueryModelRecreatingVisitor : ExpressionVisitorBase
{
private readonly Dictionary<QueryModel, QueryModel> _mapping;

public QueryModelRecreatingVisitor(Dictionary<QueryModel, QueryModel> mapping)
{
_mapping = mapping;
}

protected override Expression VisitSubQuery(SubQueryExpression expression)
{
var queryModel = expression.QueryModel;

var originalQueryModel = _mapping[queryModel];
ShallowCopy(originalQueryModel, queryModel);

queryModel.TransformExpressions(Visit);

return expression;
}
}

private static void ShallowCopy(QueryModel sourceQueryModel, QueryModel targetQueryModel)
{
targetQueryModel.BodyClauses.Clear();
foreach (var bodyClause in sourceQueryModel.BodyClauses)
{
targetQueryModel.BodyClauses.Add(bodyClause);
}

targetQueryModel.ResultOperators.Clear();
foreach (var resultOperator in sourceQueryModel.ResultOperators)
{
targetQueryModel.ResultOperators.Add(resultOperator);
}

targetQueryModel.ResultTypeOverride = sourceQueryModel.ResultTypeOverride;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down
Loading

0 comments on commit 098ea79

Please sign in to comment.