From 6becd7ec5b6ee096f3b0e60a3851dab395b423c9 Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Wed, 27 Apr 2016 18:09:22 -0700 Subject: [PATCH] Fix to #5179 - Query: Select Where Navigation returns wrong results on relational providers Problem was that when creating groups in GroupJoin we would only look at the inner key selector to differentiate between groups. However in some cases (specifically 1-Many navigations coming from the many side - e.g. order.Customer) outer element constantly changes, while inner element stays the same. If later in the pipeline outer elements are being used, e.g. in projection, we get incorrect results. Fix is to also look at changes between outer elements and create a new group every time either outer element or a inner key changes. Also as part of this change we improve result verification for the navigation tests (previously we would only verify that the result count is as expected). --- .../Internal/BufferedEntityShaper`.cs | 9 + .../Internal/BufferedOffsetEntityShaper.cs | 3 + .../ExpressionVisitors/Internal/IShaper.cs | 1 + .../Internal/QueryFlattener.cs | 52 ++ .../Internal/UnbufferedEntityShaper.cs | 3 + .../Internal/ValueBufferShaper.cs | 3 + .../Query/QueryMethodProvider.cs | 11 + .../QueryNavigationsTestBase.cs | 802 +++++++++--------- .../TestModels/Northwind/NorthwindData.cs | 61 +- .../ChangeTracking/Internal/IIdentityMap.cs | 2 + .../ChangeTracking/Internal/IStateManager.cs | 2 + .../ChangeTracking/Internal/IdentityMap.cs | 15 + .../ChangeTracking/Internal/StateManager.cs | 3 + .../Query/Internal/IQueryBuffer.cs | 6 + .../Query/Internal/QueryBuffer.cs | 6 + .../QueryNavigationsSqlServerTest.cs | 14 +- .../DbContextTest.cs | 5 + 17 files changed, 565 insertions(+), 433 deletions(-) diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedEntityShaper`.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedEntityShaper`.cs index 6bf7161d4fa..8cb9ed75c11 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedEntityShaper`.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedEntityShaper`.cs @@ -25,6 +25,15 @@ public BufferedEntityShaper( public override Type Type => typeof(TEntity); + public virtual object GetKey(QueryContext queryContext, ValueBuffer valueBuffer) + { + return queryContext.QueryBuffer.GetEntityKey( + Key, + new EntityLoadInfo(valueBuffer, Materializer), + queryStateManager: IsTrackingQuery, + throwOnNullKey: !AllowNullResult); + } + public virtual TEntity Shape(QueryContext queryContext, ValueBuffer valueBuffer) { Debug.Assert(queryContext != null); diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedOffsetEntityShaper.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedOffsetEntityShaper.cs index bfcad876ef3..fc53c684e96 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedOffsetEntityShaper.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/BufferedOffsetEntityShaper.cs @@ -22,6 +22,9 @@ public BufferedOffsetEntityShaper( { } + public override object GetKey(QueryContext queryContext, ValueBuffer valueBuffer) + => base.GetKey(queryContext, valueBuffer.WithOffset(ValueBufferOffset)); + public override TEntity Shape(QueryContext queryContext, ValueBuffer valueBuffer) => base.Shape(queryContext, valueBuffer.WithOffset(ValueBufferOffset)); diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/IShaper.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/IShaper.cs index be04709b441..1b0c8a1db8c 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/IShaper.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/IShaper.cs @@ -10,6 +10,7 @@ namespace Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal { public interface IShaper { + object GetKey([NotNull] QueryContext queryContext, ValueBuffer valueBuffer); T Shape([NotNull] QueryContext queryContext, ValueBuffer valueBuffer); bool IsShaperForQuerySource([NotNull] IQuerySource querySource); void SaveAccessorExpression([NotNull] QuerySourceMapping querySourceMapping); diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/QueryFlattener.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/QueryFlattener.cs index b52a3c8db9d..7e6c378896c 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/QueryFlattener.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/QueryFlattener.cs @@ -170,6 +170,13 @@ public CompositeShaper( _materializer = materializer; } + public object GetKey(QueryContext queryContext, ValueBuffer valueBuffer) + { + return new CompositeKey( + _outerShaper.GetKey(queryContext, valueBuffer), + _innerShaper.GetKey(queryContext, valueBuffer)); + } + public TResult Shape(QueryContext queryContext, ValueBuffer valueBuffer) => _materializer( _outerShaper.Shape(queryContext, valueBuffer), @@ -192,5 +199,50 @@ public override Expression GetAccessorExpression(IQuerySource querySource) => _outerShaper.GetAccessorExpression(querySource) ?? _innerShaper.GetAccessorExpression(querySource); } + + private class CompositeKey + { + private readonly object _outerKey; + private readonly object _innerKey; + + public CompositeKey(object outerKey, object innerKey) + { + _outerKey = outerKey; + _innerKey = innerKey; + } + + public override bool Equals(object obj) + { + var other = obj as CompositeKey; + if (other == null) + { + return false; + } + + return _outerKey != null + ? _outerKey.Equals(other._outerKey) : other._outerKey == null + && _innerKey != null + ? _innerKey.Equals(other._innerKey) : other._innerKey == null; + } + + public override int GetHashCode() + { + unchecked + { + var hashCode = 0; + if (_outerKey != null) + { + hashCode = _outerKey.GetHashCode(); + } + + if (_innerKey != null) + { + hashCode += (hashCode * 397) ^ _innerKey.GetHashCode(); + } + + return hashCode; + } + } + } } } diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/UnbufferedEntityShaper.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/UnbufferedEntityShaper.cs index fadabdfdf09..517afa2b6ae 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/UnbufferedEntityShaper.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/UnbufferedEntityShaper.cs @@ -24,6 +24,9 @@ public UnbufferedEntityShaper( public override Type Type => typeof(TEntity); + public virtual object GetKey(QueryContext queryContext, ValueBuffer valueBuffer) => + queryContext.StateManager.TryGetEntryKey(Key, valueBuffer, !AllowNullResult); + public virtual TEntity Shape(QueryContext queryContext, ValueBuffer valueBuffer) { if (IsTrackingQuery) diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/ValueBufferShaper.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/ValueBufferShaper.cs index b46f8112be3..2b6e260d0a6 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/ValueBufferShaper.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/ExpressionVisitors/Internal/ValueBufferShaper.cs @@ -17,6 +17,9 @@ public ValueBufferShaper([NotNull] IQuerySource querySource) public override Type Type => typeof(ValueBuffer); + public virtual object GetKey(QueryContext queryContext, ValueBuffer valueBuffer) + => null; + public virtual ValueBuffer Shape(QueryContext queryContext, ValueBuffer valueBuffer) => valueBuffer; } diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/QueryMethodProvider.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/QueryMethodProvider.cs index c9a7d01a1a7..f419a7d9ffc 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/QueryMethodProvider.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/QueryMethodProvider.cs @@ -240,6 +240,7 @@ private static IEnumerable _GroupJoin( } else { + var currentOuterEntityKey = outerShaper.GetKey(queryContext, sourceEnumerator.Current); var currentGroupKey = innerKeySelector(inner); innerGroupJoinInclude?.Include(inner); @@ -255,6 +256,16 @@ private static IEnumerable _GroupJoin( break; } + if (currentOuterEntityKey != null) + { + var nextOuterEntityKey = outerShaper.GetKey(queryContext, sourceEnumerator.Current); + + if (!currentOuterEntityKey.Equals(nextOuterEntityKey)) + { + break; + } + } + inner = innerShaper.Shape(queryContext, sourceEnumerator.Current); if (inner == null) diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryNavigationsTestBase.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryNavigationsTestBase.cs index 8ff4ddbf25e..2c26b357049 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryNavigationsTestBase.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryNavigationsTestBase.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -21,15 +22,11 @@ public abstract class QueryNavigationsTestBase : IClassFixture() - where o.Customer.City == "Seattle" - select o).ToList(); - - Assert.Equal(14, orders.Count); - } + AssertQuery( + os => from o in os + where o.Customer.City == "Seattle" + select o, + entryCount: 14); } [ConditionalFact] @@ -81,85 +78,64 @@ from o2 in context.Set().Where(o => o.OrderID < 10400) [ConditionalFact] public virtual void Select_Where_Navigation_Client() { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - where o.Customer.IsLondon - select o).ToList(); - - Assert.Equal(46, orders.Count); - } + AssertQuery( + os => from o in os + where o.Customer.IsLondon + select o, + entryCount: 46); } [ConditionalFact] public virtual void Select_Where_Navigation_Deep() { - using (var context = CreateContext()) - { - var orderDetails - = (from od in context.Set() - where od.Order.Customer.City == "Seattle" - select od).Take(1).ToList(); + AssertQuery( + ods => (from od in ods + where od.Order.Customer.City == "Seattle" + orderby od.OrderID, od.ProductID + select od).Take(1), + asserter: (l2oItems, efItems) => + { + var matchingPairs = + from dynamic l2oItem in l2oItems + join dynamic efItem in efItems on new { l2oItem.OrderID, l2oItem.ProductID } equals new { efItem.OrderID, efItem.ProductID } + select new { l2oItem, efItem }; - Assert.Equal(1, orderDetails.Count); - } + Assert.Equal(matchingPairs.Count(), l2oItems.Count); + }, + entryCount: 1); } [ConditionalFact] public virtual void Select_Where_Navigation_Null() { - using (var context = CreateContext()) - { - var employees - = (from e in context.Set() - where e.Manager == null - select e).ToList(); - - Assert.Equal(1, employees.Count); - } + AssertQuery( + es => from e in es + where e.Manager == null + select e, + entryCount: 1); } [ConditionalFact] public virtual void Select_Where_Navigation_Null_Reverse() { - using (var context = CreateContext()) - { - var query = from e in context.Set() - where null == e.Manager - select e; - - var result = query.ToList(); - - Assert.Equal(1, result.Count); - } + AssertQuery( + es => from e in es + where null == e.Manager + select e, + entryCount: 1); } [ConditionalFact] public virtual void Select_Where_Navigation_Null_Deep() { - List expected; - using (var context = CreateContext()) - { - expected = context.Employees.Include(e => e.Manager.Manager).ToList() - .Where(e => e.Manager == null || e.Manager.Manager == null).ToList(); - } - - ClearLog(); - - using (var context = CreateContext()) - { - var employees - = (from e in context.Set() - where e.Manager.Manager == null - select e).ToList(); - - Assert.Equal(expected.Count, employees.Count); - foreach (var employee in employees) - { - Assert.True(expected.Select(e => e.EmployeeID).Contains(employee.EmployeeID)); - } - } + AssertQuery( + es => from e in es + where e.Manager.Manager == null + select e, + es => from e in es + where (e.Manager != null ? e.Manager.Manager : null) == null + select e, + entryCount: 6); } // issue 4539 @@ -181,65 +157,48 @@ from o2 in context.Set() [ConditionalFact] public virtual void Select_Where_Navigation_Included() { + Func, IQueryable> queryFunc = + os => from o in os.Include(o => o.Customer) + where o.Customer.City == "Seattle" + select o; + using (var context = CreateContext()) { - var query = from o in context.Set().Include(o => o.Customer) - where o.Customer.City == "Seattle" - select o; - - var result = query.ToList(); + var result = queryFunc(context.Orders).ToList(); Assert.Equal(14, result.Count); Assert.True(result.All(o => o.Customer != null)); } + + ClearLog(); + + AssertQuery( + queryFunc, + os => from o in os + where o.Customer.City == "Seattle" + select o, + entryCount: 15); } [ConditionalFact] public virtual void Singleton_Navigation_With_Member_Access() { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - where o.Customer.City == "Seattle" - where o.Customer.Phone != "555 555 5555" - select new { B = o.Customer.City }).ToList(); - - Assert.Equal(14, orders.Count); - Assert.True(orders.All(o => o.B != null)); - } + AssertQuery( + os => from o in os + where o.Customer.City == "Seattle" + where o.Customer.Phone != "555 555 5555" + select new { B = o.Customer.City }); } [ConditionalFact] public virtual void Select_Singleton_Navigation_With_Member_Access() { - List expected; - using (var context = CreateContext()) - { - expected = context.Orders.Include(o => o.Customer) - .ToList() - .Where(o => o.Customer?.City == "Seattle") - .Where(o => o.Customer?.Phone != "555 555 5555") - .ToList(); - } - - ClearLog(); - - using (var context = CreateContext()) - { - var query = from o in context.Set() - where o.Customer.City == "Seattle" - where o.Customer.Phone != "555 555 5555" - select new { A = o.Customer, B = o.Customer.City }; - - var result = query.ToList(); - - Assert.Equal(expected.Count, result.Count); - foreach (var resultElement in result) - { - Assert.True(expected.Any(e => e.CustomerID == resultElement.A.CustomerID && e.Customer?.City == resultElement.B)); - } - } + AssertQuery( + os => from o in os + where o.Customer.City == "Seattle" + where o.Customer.Phone != "555 555 5555" + select new { A = o.Customer, B = o.Customer.City }, + entryCount: 1); } [ConditionalFact] @@ -258,222 +217,160 @@ var orders } } - [ConditionalFact] - public virtual void Select_Where_Navigations() - { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - where (o.Customer.City == "Seattle") - && (o.Customer.Phone != "555 555 5555") - select o).ToList(); - - Assert.Equal(14, orders.Count); - } - } - [ConditionalFact] public virtual void Select_Where_Navigation_Multiple_Access() { - List expected; - using (var context = CreateContext()) - { - expected = context.Orders.Include(o => o.Customer).ToList() - .Where(o => o.Customer?.City == "Seattle" - && o.Customer?.Phone != "555 555 5555") - .Select(e => e.CustomerID) - .ToList(); - } - - ClearLog(); - - using (var context = CreateContext()) - { - var query = from o in context.Set() - where (o.Customer.City == "Seattle") - && (o.Customer.Phone != "555 555 5555") - select o; - - var result = query.ToList(); - - Assert.Equal(expected.Count, result.Count); - foreach (var resultElement in result) - { - expected.Contains(resultElement.CustomerID); - } - } + AssertQuery( + os => from o in os + where o.Customer.City == "Seattle" + && o.Customer.Phone != "555 555 5555" + select o, + entryCount: 14); } [ConditionalFact] public virtual void Select_Navigation() { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - select o.Customer).ToList(); - - Assert.Equal(830, orders.Count); - Assert.True(orders.All(o => o != null)); - } + AssertQuery( + os => from o in os + select o.Customer, + entryCount: 89); } [ConditionalFact] public virtual void Select_Navigations() { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - select new { A = o.Customer, B = o.Customer }).ToList(); - - Assert.Equal(830, orders.Count); - Assert.True(orders.All(o => (o.A != null) && (o.B != null))); - } + AssertQuery( + os => from o in os + select new { A = o.Customer, B = o.Customer }, + entryCount: 89); } [ConditionalFact] public virtual void Select_Navigations_Where_Navigations() { - using (var context = CreateContext()) - { - var orders - = (from o in context.Set() - where o.Customer.City == "Seattle" - where o.Customer.Phone != "555 555 5555" - select new { A = o.Customer, B = o.Customer }).ToList(); - - Assert.Equal(14, orders.Count); - Assert.True(orders.All(o => (o.A != null) && (o.B != null))); - } + AssertQuery( + os => from o in os + where o.Customer.City == "Seattle" + where o.Customer.Phone != "555 555 5555" + select new { A = o.Customer, B = o.Customer }, + entryCount: 1); } [ConditionalFact] public virtual void Select_collection_navigation_simple() { - using (var context = CreateContext()) - { - var query = from c in context.Customers - where c.CustomerID.StartsWith("A") - select new { c.Orders }; - - var results = query.ToList(); - - Assert.Equal(4, results.Count); - Assert.True(results.All(r => r.Orders.Count > 0)); - } + AssertQuery( + cs => from c in cs + where c.CustomerID.StartsWith("A") + select new { c.CustomerID, c.Orders }, + asserter: (l2oItems, efItems) => + { + foreach (var pair in + from dynamic l2oItem in l2oItems + join dynamic efItem in efItems on l2oItem.CustomerID equals efItem.CustomerID + select new { l2oItem, efItem }) + { + Assert.Equal(pair.l2oItem.Orders, pair.efItem.Orders); + } + }); } [ConditionalFact] public virtual void Select_collection_navigation_multi_part() { - using (var context = CreateContext()) - { - var query = from o in context.Orders - where o.CustomerID == "ALFKI" - select new { o.Customer.Orders }; - - var results = query.ToList(); - - Assert.Equal(6, results.Count); - Assert.True(results.All(r => r.Orders.Count > 0)); - } + AssertQuery( + os => from o in os + where o.CustomerID == "ALFKI" + select new { o.OrderID, o.Customer.Orders }, + asserter: (l2oItems, efItems) => + { + foreach (var pair in + from dynamic l2oItem in l2oItems + join dynamic efItem in efItems on l2oItem.OrderID equals efItem.OrderID + select new { l2oItem, efItem }) + { + Assert.Equal(pair.l2oItem.Orders, pair.efItem.Orders); + } + }); } [ConditionalFact] public virtual void Collection_select_nav_prop_any() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { Any = c.Orders.Any() }).ToList(); - - Assert.Equal(91, customers.Count); - Assert.Equal(89, customers.Count(c => c.Any)); - } + AssertQuery( + cs => from c in cs + select new { Any = c.Orders.Any() }, + cs => from c in cs + select new { Any = (c.Orders ?? new List()).Any()}); } [ConditionalFact] public virtual void Collection_select_nav_prop_predicate() { - using (var context = CreateContext()) - { - var result = context.Customers - .Select(c => c.Orders.Count > 0) - .ToList(); - } + AssertQuery( + cs => cs.Select(c => c.Orders.Count > 0), + cs => cs.Select(c => (c.Orders ?? new List()).Count > 0)); } [ConditionalFact] public virtual void Collection_where_nav_prop_any() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where c.Orders.Any() - select c).ToList(); + AssertQuery( + cs => from c in cs + where c.Orders.Any() + select c, - Assert.Equal(89, customers.Count); - } + cs => from c in cs + where (c.Orders ?? new List()).Any() + select c, + entryCount: 89); } [ConditionalFact] public virtual void Collection_where_nav_prop_any_predicate() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where c.Orders.Any(o => o.OrderID > 0) - select c).ToList(); - - Assert.Equal(89, customers.Count); - } + AssertQuery( + cs => from c in cs + where c.Orders.Any(o => o.OrderID > 0) + select c, + cs => from c in cs + where (c.Orders ?? new List()).Any(o => o.OrderID > 0) + select c, + entryCount: 89); } [ConditionalFact] public virtual void Collection_select_nav_prop_all() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { All = c.Orders.All(o => o.CustomerID == "ALFKI") }) - .ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { All = c.Orders.All(o => o.CustomerID == "ALFKI") }, + cs => from c in cs + select new { All = (c.Orders ?? new List()).All(o => o.CustomerID == "ALFKI") }); } [ConditionalFact] public virtual void Collection_select_nav_prop_all_client() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { All = c.Orders.All(o => o.ShipCity == "London") }) - .ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { All = c.Orders.All(o => o.ShipCity == "London") }, + cs => from c in cs + select new { All = (c.Orders ?? new List()).All(o => o.ShipCity == "London") }); } [ConditionalFact] public virtual void Collection_where_nav_prop_all() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where c.Orders.All(o => o.CustomerID == "ALFKI") - select c).ToList(); - - Assert.Equal(3, customers.Count); - } + AssertQuery( + cs => from c in cs + where c.Orders.All(o => o.CustomerID == "ALFKI") + select c, + cs => from c in cs + where (c.Orders ?? new List()).All(o => o.CustomerID == "ALFKI") + select c, + entryCount: 3); } [ConditionalFact] @@ -493,119 +390,101 @@ where c.Orders.All(o => o.ShipCity == "London") [ConditionalFact] public virtual void Collection_select_nav_prop_count() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { c.Orders.Count }).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { c.Orders.Count }, + cs => from c in cs + select new { (c.Orders ?? new List()).Count }); } [ConditionalFact] public virtual void Collection_where_nav_prop_count() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where c.Orders.Count() > 5 - select c).ToList(); - - Assert.Equal(63, customers.Count); - } + AssertQuery( + cs => from c in cs + where c.Orders.Count() > 5 + select c, + cs => from c in cs + where (c.Orders ?? new List()).Count() > 5 + select c, + entryCount: 63); } [ConditionalFact] public virtual void Collection_where_nav_prop_count_reverse() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where 5 < c.Orders.Count() - select c).ToList(); - - Assert.Equal(63, customers.Count); - } + AssertQuery( + cs => from c in cs + where 5 < c.Orders.Count() + select c, + cs => from c in cs + where 5 < (c.Orders ?? new List()).Count() + select c, + entryCount: 63); } [ConditionalFact] public virtual void Collection_orderby_nav_prop_count() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - orderby c.Orders.Count() - select c).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + orderby c.Orders.Count() + select c, + cs => from c in cs + orderby (c.Orders ?? new List()).Count() + select c, + entryCount: 91); } [ConditionalFact] public virtual void Collection_select_nav_prop_long_count() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { C = c.Orders.LongCount() }).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { C = c.Orders.LongCount() }, + cs => from c in cs + select new { C = (c.Orders ?? new List()).LongCount() }); } [ConditionalFact] public virtual void Select_multiple_complex_projections() { - using (var context = CreateContext()) - { - var customers - = (from o in context.Orders - where o.CustomerID.StartsWith("A") - select new - { - collection1 = o.OrderDetails.Count(), - scalar1 = o.OrderDate, - any = o.OrderDetails.Select(od => od.UnitPrice).Any(up => up > 10), - conditional = o.CustomerID == "ALFKI" ? "50" : "10", - scalar2 = (int?)o.OrderID, - all = o.OrderDetails.All(od => od.OrderID == 42), - collection2 = o.OrderDetails.LongCount() - }).ToList(); - - Assert.Equal(30, customers.Count); - } + AssertQuery( + os => from o in os + where o.CustomerID.StartsWith("A") + select new + { + collection1 = o.OrderDetails.Count(), + scalar1 = o.OrderDate, + any = o.OrderDetails.Select(od => od.UnitPrice).Any(up => up > 10), + conditional = o.CustomerID == "ALFKI" ? "50" : "10", + scalar2 = (int?)o.OrderID, + all = o.OrderDetails.All(od => od.OrderID == 42), + collection2 = o.OrderDetails.LongCount(), + }); } [ConditionalFact] public virtual void Collection_select_nav_prop_sum() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { Sum = c.Orders.Sum(o => o.OrderID) }).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { Sum = c.Orders.Sum(o => o.OrderID) }, + cs => from c in cs + select new { Sum = (c.Orders ?? new List()).Sum(o => o.OrderID) }); } [ConditionalFact] public virtual void Collection_where_nav_prop_sum() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - where c.Orders.Sum(o => o.OrderID) > 1000 - select c).ToList(); - - Assert.Equal(89, customers.Count); - } + AssertQuery( + cs => from c in cs + where c.Orders.Sum(o => o.OrderID) > 1000 + select c, + cs => from c in cs + where (c.Orders ?? new List()).Sum(o => o.OrderID) > 1000 + select c, + entryCount: 89); } [ConditionalFact] @@ -625,107 +504,103 @@ where c.Orders.Sum(o => o.OrderID) > 1000 [ConditionalFact] public virtual void Collection_select_nav_prop_first_or_default() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { First = c.Orders.FirstOrDefault() }).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { First = c.Orders.FirstOrDefault() }, + cs => from c in cs + select new { First = (c.Orders ?? new List()).FirstOrDefault() }); } - [ConditionalFact] + // issue #5191 + ////[ConditionalFact] public virtual void Collection_select_nav_prop_first_or_default_then_nav_prop() { - using (var context = CreateContext()) - { - var customers - = (from c in context.Set() - select new { c.Orders.FirstOrDefault().Customer }).ToList(); - - Assert.Equal(91, customers.Count); - } + AssertQuery( + cs => from c in cs + select new { c.Orders.FirstOrDefault().Customer }, + cs => from c in cs + select new { Customer = c.Orders != null ? c.Orders.FirstOrDefault().Customer : null }); } [ConditionalFact] public virtual void Navigation_fk_based_inside_contains() { - using (var context = CreateContext()) - { - var query - = from o in context.Orders + AssertQuery( + os => from o in os where new[] { "ALFKI" }.Contains(o.Customer.CustomerID) - select o; - - var result = query.ToList(); - - Assert.Equal(6, result.Count); - Assert.True(result.All(e => e.CustomerID == "ALFKI")); - } + select o, + entryCount: 6); } [ConditionalFact] public virtual void Navigation_inside_contains() { - using (var context = CreateContext()) - { - var query - = from o in context.Orders + AssertQuery( + os => from o in os where new[] { "Novigrad", "Seattle" }.Contains(o.Customer.City) - select o; - - var result = query.ToList(); - - Assert.Equal(14, result.Count); - } + select o, + entryCount: 14); } [ConditionalFact] public virtual void Navigation_inside_contains_nested() { - using (var context = CreateContext()) - { - var query - = from od in context.OrderDetails - where new[] { "Novigrad", "Seattle" }.Contains(od.Order.Customer.City) - select od; + AssertQuery( + ods => from od in ods + where new[] { "Novigrad", "Seattle" }.Contains(od.Order.Customer.City) + select od, + asserter: (l2oItems, efItems) => + { + var l2oIds = l2oItems.Select(i => new { i.OrderID, i.ProductID }); + var efIds = efItems.Select(i => new { i.OrderID, i.ProductID }); - var result = query.ToList(); + foreach (var efId in efIds) + { + Assert.True(l2oIds.Contains(efId)); + } - Assert.Equal(40, result.Count); - } + foreach (var l2oId in l2oIds) + { + Assert.True(efIds.Contains(l2oId)); + } + }, + entryCount: 40); } [ConditionalFact] public virtual void Navigation_from_join_clause_inside_contains() { - using (var context = CreateContext()) - { - var query = from od in context.OrderDetails - join o in context.Orders on od.OrderID equals o.OrderID - where new[] { "USA", "Redania" }.Contains(o.Customer.Country) - select od; + AssertQuery( + (ods, os) => from od in ods + join o in os on od.OrderID equals o.OrderID + where new[] { "USA", "Redania" }.Contains(o.Customer.Country) + select od, + asserter: (l2oItems, efItems) => + { + var l2oIds = l2oItems.Select(i => new { i.OrderID, i.ProductID }); + var efIds = efItems.Select(i => new { i.OrderID, i.ProductID }); - var result = query.ToList(); + foreach (var efId in efIds) + { + Assert.True(l2oIds.Contains(efId)); + } - Assert.Equal(352, result.Count); - } + foreach (var l2oId in l2oIds) + { + Assert.True(efIds.Contains(l2oId)); + } + }, + entryCount: 352); } [ConditionalFact] public virtual void Where_subquery_on_navigation() { - using (var context = CreateContext()) - { - var query = from p in context.Products - where p.OrderDetails.Contains(context.OrderDetails.FirstOrDefault(orderDetail => orderDetail.Quantity == 1)) - select p; - - var result = query.ToList(); - - Assert.Equal(1, result.Count); - } + AssertQuery( + (ps, ods) => from p in ps + where p.OrderDetails.Contains(ods.FirstOrDefault(orderDetail => orderDetail.Quantity == 1)) + select p, + entryCount: 1); } // issue #4547 @@ -750,16 +625,22 @@ public virtual void Navigation_in_subquery_referencing_outer_query() [ConditionalFact] public virtual void GroupBy_on_nav_prop() { - using (var context = CreateContext()) + AssertQuery>( + os => from o in os + group o by o.Customer.City into og + select og, + asserter: (l2oItems, efItems) => { - var query = from o in context.Orders - group o by o.Customer.City into og - select og; - - var result = query.ToList(); - - Assert.Equal(69, result.Count); - } + foreach (var pair in + from l2oItem in l2oItems + join efItem in efItems on l2oItem.Key equals efItem.Key + select new { l2oItem, efItem }) + { + Assert.Equal( + pair.l2oItem.Select(i => i.OrderID).OrderBy(i => i), + pair.efItem.Select(i => i.OrderID).OrderBy(i => i)); + } + }); } protected QueryNavigationsTestBase(TFixture fixture) @@ -774,5 +655,96 @@ protected QueryNavigationsTestBase(TFixture fixture) protected virtual void ClearLog() { } + + protected void AssertQuery( + Func, IQueryable> query, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem : class + { + AssertQuery(query, query, assertOrder, entryCount, asserter); + } + + protected void AssertQuery( + Func, IQueryable, IQueryable> query, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem1: class + where TItem2: class + { + AssertQuery(query, query, assertOrder, entryCount, asserter); + } + + + protected void AssertQuery( + Func, IQueryable> query, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem : class + { + AssertQuery(query, query, assertOrder, entryCount, asserter); + } + + protected void AssertQuery( + Func, IQueryable> efQuery, + Func, IQueryable> l2oQuery, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem : class + { + using (var context = CreateContext()) + { + TestHelpers.AssertResults( + l2oQuery(NorthwindData.Set()).ToArray(), + efQuery(context.Set()).ToArray(), + assertOrder, + asserter); + + Assert.Equal(entryCount, context.ChangeTracker.Entries().Count()); + } + } + + protected void AssertQuery( + Func, IQueryable> efQuery, + Func, IQueryable> l2oQuery, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem : class + { + using (var context = CreateContext()) + { + TestHelpers.AssertResults( + l2oQuery(NorthwindData.Set()).ToArray(), + efQuery(context.Set()).ToArray(), + assertOrder, + asserter); + } + } + + protected void AssertQuery( + Func, IQueryable, IQueryable> efQuery, + Func, IQueryable, IQueryable> l2oQuery, + bool assertOrder = false, + int entryCount = 0, + Action, IList> asserter = null) + where TItem1 : class + where TItem2 : class + { + using (var context = CreateContext()) + { + TestHelpers.AssertResults( + l2oQuery(NorthwindData.Set(), NorthwindData.Set()).ToArray(), + efQuery(context.Set(), context.Set()).ToArray(), + assertOrder, + asserter); + + Assert.Equal(entryCount, context.ChangeTracker.Entries().Count()); + } + } } } diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/TestModels/Northwind/NorthwindData.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/TestModels/Northwind/NorthwindData.cs index e5f318e1c36..ac1a9f73149 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/TestModels/Northwind/NorthwindData.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/TestModels/Northwind/NorthwindData.cs @@ -18,6 +18,57 @@ namespace Microsoft.EntityFrameworkCore.Specification.Tests.TestModels.Northwind { public static class NorthwindData { + static NorthwindData() + { + _customers = CreateCustomers(); + _employees = CreateEmployees(); + _products = CreateProducts(); + _orders = CreateOrders(); + _orderDetails = CreateOrderDetails(); + + foreach (var order in _orders) + { + var customer = _customers.Where(c => c.CustomerID == order.CustomerID).First(); + order.Customer = customer; + + if (customer.Orders == null) + { + customer.Orders = new List(); + } + + customer.Orders.Add(order); + } + + foreach (var orderDetail in _orderDetails) + { + var order = _orders.Where(o => o.OrderID == orderDetail.OrderID).First(); + var product = _products.Where(p => p.ProductID == orderDetail.ProductID).First(); + orderDetail.Order = order; + orderDetail.Product = product; + + if (order.OrderDetails == null) + { + order.OrderDetails = new List(); + } + + order.OrderDetails.Add(orderDetail); + + if (product.OrderDetails == null) + { + product.OrderDetails = new List(); + } + + product.OrderDetails.Add(orderDetail); + + } + + foreach (var employee in _employees) + { + var manager = _employees.Where(e => employee.ReportsTo == e.EmployeeID).FirstOrDefault(); + employee.Manager = manager; + } + } + public static IQueryable Set() { if (typeof(T) == typeof(Customer)) @@ -138,7 +189,7 @@ private static DateTime ParseDate(string date) #region Customers - private static readonly Customer[] _customers = CreateCustomers(); + private static readonly Customer[] _customers; public static Customer[] CreateCustomers() => new[] @@ -1423,7 +1474,7 @@ public static Customer[] CreateCustomers() #region Employees - private static readonly Employee[] _employees = CreateEmployees(); + private static readonly Employee[] _employees; public static Employee[] CreateEmployees() => new[] @@ -1652,7 +1703,7 @@ public static Employee[] CreateEmployees() #region Products - private static readonly Product[] _products = CreateProducts(); + private static readonly Product[] _products; public static Product[] CreateProducts() => new[] @@ -2664,7 +2715,7 @@ public static Product[] CreateProducts() #region Orders - private static readonly Order[] _orders = CreateOrders(); + private static readonly Order[] _orders; public static Order[] CreateOrders() => new[] @@ -16785,7 +16836,7 @@ public static Order[] CreateOrders() #region OrderDetails - private static readonly OrderDetail[] _orderDetails = CreateOrderDetails(); + private static readonly OrderDetail[] _orderDetails; public static OrderDetail[] CreateOrderDetails() => new[] diff --git a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IIdentityMap.cs b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IIdentityMap.cs index 6b74e2159cd..ab46ab559f1 100644 --- a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IIdentityMap.cs +++ b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IIdentityMap.cs @@ -15,6 +15,8 @@ public interface IIdentityMap bool Contains([NotNull] IForeignKey foreignKey, ValueBuffer valueBuffer); + object TryGetEntryKey(ValueBuffer valueBuffer, bool throwOnNullKey); + InternalEntityEntry TryGetEntry(ValueBuffer valueBuffer, bool throwOnNullKey); InternalEntityEntry TryGetEntry([NotNull] IForeignKey foreignKey, [NotNull] InternalEntityEntry dependentEntry); diff --git a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IStateManager.cs b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IStateManager.cs index 55af26e464b..68788ffdc7a 100644 --- a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IStateManager.cs +++ b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IStateManager.cs @@ -19,6 +19,8 @@ public interface IStateManager void BeginTrackingQuery(); + object TryGetEntryKey([NotNull] IKey key, ValueBuffer valueBuffer, bool throwOnNullKey); + InternalEntityEntry TryGetEntry([NotNull] IKey key, ValueBuffer valueBuffer, bool throwOnNullKey); InternalEntityEntry TryGetEntry([NotNull] object entity); diff --git a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IdentityMap.cs b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IdentityMap.cs index b5a9bb0a592..c1e64778479 100644 --- a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IdentityMap.cs +++ b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/IdentityMap.cs @@ -43,6 +43,7 @@ public IdentityMap( public virtual bool Contains(ValueBuffer valueBuffer) { var key = PrincipalKeyValueFactory.CreateFromBuffer(valueBuffer); + return key != null && _identityMap.ContainsKey((TKey)key); } @@ -53,6 +54,18 @@ public virtual bool Contains(IForeignKey foreignKey, ValueBuffer valueBuffer) && _identityMap.ContainsKey(key); } + public virtual object TryGetEntryKey(ValueBuffer valueBuffer, bool throwOnNullKey) + { + var key = PrincipalKeyValueFactory.CreateFromBuffer(valueBuffer); + if (key == null + && throwOnNullKey) + { + throw new InvalidOperationException(CoreStrings.InvalidKeyValue(Key.DeclaringEntityType.DisplayName())); + } + + return key; + } + public virtual InternalEntityEntry TryGetEntry(ValueBuffer valueBuffer, bool throwOnNullKey) { InternalEntityEntry entry; @@ -62,6 +75,7 @@ public virtual InternalEntityEntry TryGetEntry(ValueBuffer valueBuffer, bool thr { throw new InvalidOperationException(CoreStrings.InvalidKeyValue(Key.DeclaringEntityType.DisplayName())); } + return key != null && _identityMap.TryGetValue((TKey)key, out entry) ? entry : null; } @@ -79,6 +93,7 @@ public virtual InternalEntityEntry TryGetEntryUsingRelationshipSnapshot(IForeign { TKey key; InternalEntityEntry entry; + return foreignKey.GetDependentKeyValueFactory().TryCreateFromRelationshipSnapshot(dependentEntry, out key) && _identityMap.TryGetValue(key, out entry) ? entry diff --git a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/StateManager.cs b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/StateManager.cs index e623742ba3f..531c3e3b228 100644 --- a/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/StateManager.cs +++ b/src/Microsoft.EntityFrameworkCore/ChangeTracking/Internal/StateManager.cs @@ -116,6 +116,9 @@ public virtual InternalEntityEntry StartTrackingFromQuery( return newEntry; } + public virtual object TryGetEntryKey(IKey key, ValueBuffer valueBuffer, bool throwOnNullKey) + => GetOrCreateIdentityMap(key).TryGetEntryKey(valueBuffer, throwOnNullKey); + public virtual InternalEntityEntry TryGetEntry(IKey key, ValueBuffer valueBuffer, bool throwOnNullKey) => GetOrCreateIdentityMap(key).TryGetEntry(valueBuffer, throwOnNullKey); diff --git a/src/Microsoft.EntityFrameworkCore/Query/Internal/IQueryBuffer.cs b/src/Microsoft.EntityFrameworkCore/Query/Internal/IQueryBuffer.cs index 3fa29bd9898..2774e6e7bd3 100644 --- a/src/Microsoft.EntityFrameworkCore/Query/Internal/IQueryBuffer.cs +++ b/src/Microsoft.EntityFrameworkCore/Query/Internal/IQueryBuffer.cs @@ -11,6 +11,12 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal { public interface IQueryBuffer { + object GetEntityKey( + [NotNull] IKey key, + EntityLoadInfo entityLoadInfo, + bool queryStateManager, + bool throwOnNullKey); + object GetEntity( [NotNull] IKey key, EntityLoadInfo entityLoadInfo, diff --git a/src/Microsoft.EntityFrameworkCore/Query/Internal/QueryBuffer.cs b/src/Microsoft.EntityFrameworkCore/Query/Internal/QueryBuffer.cs index fa948f94714..11f6adee018 100644 --- a/src/Microsoft.EntityFrameworkCore/Query/Internal/QueryBuffer.cs +++ b/src/Microsoft.EntityFrameworkCore/Query/Internal/QueryBuffer.cs @@ -37,6 +37,12 @@ public QueryBuffer( _changeDetector = changeDetector; } + public virtual object GetEntityKey( + IKey key, EntityLoadInfo entityLoadInfo, bool queryStateManager, bool throwOnNullKey) + { + return _stateManager.TryGetEntryKey(key, entityLoadInfo.ValueBuffer, throwOnNullKey); + } + public virtual object GetEntity( IKey key, EntityLoadInfo entityLoadInfo, bool queryStateManager, bool throwOnNullKey) { diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryNavigationsSqlServerTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryNavigationsSqlServerTest.cs index 48cb1f8a9ef..f1aca4706cd 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryNavigationsSqlServerTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/QueryNavigationsSqlServerTest.cs @@ -32,7 +32,7 @@ public override void Select_Where_Navigation_Deep() FROM [Order Details] AS [od] INNER JOIN [Orders] AS [od.Order] ON [od].[OrderID] = [od.Order].[OrderID] LEFT JOIN [Customers] AS [od.Order.Customer] ON [od.Order].[CustomerID] = [od.Order.Customer].[CustomerID] -ORDER BY [od.Order].[CustomerID]", +ORDER BY [od].[OrderID], [od].[ProductID], [od.Order].[CustomerID]", Sql); } @@ -73,18 +73,6 @@ FROM [Orders] AS [o] Sql); } - public override void Select_Where_Navigations() - { - base.Select_Where_Navigations(); - - Assert.Equal( - @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate], [o.Customer].[CustomerID], [o.Customer].[Address], [o.Customer].[City], [o.Customer].[CompanyName], [o.Customer].[ContactName], [o.Customer].[ContactTitle], [o.Customer].[Country], [o.Customer].[Fax], [o.Customer].[Phone], [o.Customer].[PostalCode], [o.Customer].[Region] -FROM [Orders] AS [o] -LEFT JOIN [Customers] AS [o.Customer] ON [o].[CustomerID] = [o.Customer].[CustomerID] -ORDER BY [o].[CustomerID]", - Sql); - } - public override void Select_Where_Navigation_Multiple_Access() { base.Select_Where_Navigation_Multiple_Access(); diff --git a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs index 54c9a6e72d2..e624409706d 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs @@ -300,6 +300,11 @@ public void BeginTrackingQuery() throw new NotImplementedException(); } + public object TryGetEntryKey(IKey key, ValueBuffer valueBuffer, bool throwOnNullKey) + { + throw new NotImplementedException(); + } + public InternalEntityEntry TryGetEntry(IKey key, ValueBuffer valueBuffer, bool throwOnNullKey) { throw new NotImplementedException();