Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Convert Entity comparison with null to comparison on key properties #6730

Merged
merged 1 commit into from
Oct 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ private static Expression Invert(Expression test)
return null;
}

private static Expression TryRemoveNullCheck(ConditionalExpression node)
private Expression TryRemoveNullCheck(ConditionalExpression node)
{
var binaryTest = node.Test as BinaryExpression;

Expand Down Expand Up @@ -340,7 +340,7 @@ private static Expression TryRemoveNullCheck(ConditionalExpression node)
var testExpression = isLeftNullConstant ? binaryTest.Right : binaryTest.Left;
var resultExpression = binaryTest.NodeType == ExpressionType.Equal ? node.IfFalse : node.IfTrue;

var nullCheckRemovalTestingVisitor = new NullCheckRemovalTestingVisitor();
var nullCheckRemovalTestingVisitor = new NullCheckRemovalTestingVisitor(_queryModelVisitor.QueryCompilationContext.Model);

return nullCheckRemovalTestingVisitor.CanRemoveNullCheck(testExpression, resultExpression)
? resultExpression
Expand All @@ -350,9 +350,15 @@ private static Expression TryRemoveNullCheck(ConditionalExpression node)
private class NullCheckRemovalTestingVisitor : ExpressionVisitorBase
{
private IQuerySource _querySource;
private IModel _model;
private string _propertyName;
private bool? _canRemoveNullCheck;

public NullCheckRemovalTestingVisitor(IModel model)
{
_model = model;
}

public bool CanRemoveNullCheck(Expression testExpression, Expression resultExpression)
{
AnalyzeTestExpression(testExpression);
Expand Down Expand Up @@ -404,6 +410,10 @@ private void AnalyzeTestExpression(Expression expression)
{
_querySource = querySourceCaller.ReferencedQuerySource;
_propertyName = (string)propertyNameExpression.Value;
if (_model.FindEntityType(_querySource.ItemType)?.FindProperty(_propertyName)?.IsPrimaryKey() ?? false)
{
_propertyName = null;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,32 @@ join l2_Required_Reverse in context.LevelOne on l2.Level1_Required_Id equals l2_
[ConditionalFact]
public virtual void Query_source_materialization_bug_4547()
{
List<int> expected;
using (var context = CreateContext())
{
expected = (from e3 in context.LevelThree.ToList()
join e1 in context.LevelOne.ToList()
on
(int?)e3.Id
equals
(
from subQuery2 in context.LevelTwo.ToList()
join subQuery3 in context.LevelThree.ToList()
on
subQuery2 != null ? (int?)subQuery2.Id : null
equals
subQuery3.Level2_Optional_Id
into
grouping
from subQuery3 in grouping.DefaultIfEmpty()
select subQuery3 != null ? (int?)subQuery3.Id : null
).FirstOrDefault()
select e1.Id).ToList();

}

ClearLog();

using (var context = CreateContext())
{
var query = from e3 in context.LevelThree
Expand All @@ -2610,6 +2636,8 @@ from subQuery3 in grouping.DefaultIfEmpty()
select e1.Id;

var result = query.ToList();

Assert.Equal(expected, result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ from c in cs
select c.CustomerID);
}

[ConditionalFact]
public virtual void Entity_equality_null()
{
AssertQuery<Customer>(cs =>
from c in cs
where c == null
select c.CustomerID);
}

[ConditionalFact]
public virtual void Entity_equality_not_null()
{
AssertQuery<Customer>(cs =>
from c in cs
where c != null
select c.CustomerID);
}

[ConditionalFact]
public virtual void Null_conditional_simple()
{
Expand Down Expand Up @@ -6332,6 +6350,27 @@ public virtual void DefaultIfEmpty_without_group_join()
}
}

[ConditionalFact]
public virtual void DefaultIfEmpty_in_subquery()
{
AssertQuery<Customer, Order>((cs, os) =>
(from c in cs
from o in os.Where(o => o.CustomerID == c.CustomerID).DefaultIfEmpty()
where o != null
select new { c.CustomerID, o.OrderID }));
}

[ConditionalFact]
public virtual void DefaultIfEmpty_in_subquery_nested()
{
AssertQuery<Customer, Order>((cs, os) =>
(from c in cs.Where(c => c.City == "Seattle")
from o1 in os.Where(o => o.OrderID > 11000).DefaultIfEmpty()
from o2 in os.Where(o => o.CustomerID == c.CustomerID).DefaultIfEmpty()
where o1 != null && o2 != null
select new { c.CustomerID, o1.OrderID, o2.OrderDate }));
}

protected NorthwindContext CreateContext() => Fixture.CreateContext();

protected QueryTestBase(TFixture fixture)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -58,32 +59,38 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
|| binaryExpression.NodeType == ExpressionType.NotEqual)
{
var constantExpression = newBinaryExpression.Left.RemoveConvert() as ConstantExpression;
var isLeftNullConstant = constantExpression != null && constantExpression.Value == null;

if (constantExpression != null
&& constantExpression.Value == null)
constantExpression = newBinaryExpression.Right.RemoveConvert() as ConstantExpression;
var isRightNullConstant = constantExpression != null && constantExpression.Value == null;

if (isLeftNullConstant && isRightNullConstant)
{
return newBinaryExpression;
}

constantExpression = newBinaryExpression.Right.RemoveConvert() as ConstantExpression;

if (constantExpression != null
&& constantExpression.Value == null)
var isNullComparison = isLeftNullConstant || isRightNullConstant;
var nonNullExpression = isLeftNullConstant ? newBinaryExpression.Right : newBinaryExpression.Left;
// If a navigation being compared to null then don't rewrite
if (isNullComparison
&& !(nonNullExpression is QuerySourceReferenceExpression))
{
return newBinaryExpression;
}

var entityType = _model.FindEntityType(newBinaryExpression.Left.Type);
var entityType = _model.FindEntityType(nonNullExpression.Type);

if (entityType != null)
{
var primaryKeyProperties = entityType.FindPrimaryKey().Properties;

var newLeftExpression
= CreateKeyAccessExpression(newBinaryExpression.Left, primaryKeyProperties);
var newLeftExpression = isLeftNullConstant
? Expression.Constant(null, typeof(object))
: CreateKeyAccessExpression(newBinaryExpression.Left, primaryKeyProperties, isNullComparison);

var newRightExpression
= CreateKeyAccessExpression(newBinaryExpression.Right, primaryKeyProperties);
var newRightExpression = isRightNullConstant
? Expression.Constant(null, typeof(object))
: CreateKeyAccessExpression(newBinaryExpression.Right, primaryKeyProperties, isNullComparison);

return Expression.MakeBinary(newBinaryExpression.NodeType, newLeftExpression, newRightExpression);
}
Expand All @@ -92,9 +99,10 @@ var newRightExpression
return newBinaryExpression;
}

private static Expression CreateKeyAccessExpression(Expression target, IReadOnlyList<IProperty> properties)
private static Expression CreateKeyAccessExpression(Expression target, IReadOnlyList<IProperty> properties, bool nullComparison)
{
return properties.Count == 1
// If comparing with null then we need only first PK property
return (properties.Count == 1) || nullComparison
? EntityQueryModelVisitor.CreatePropertyExpression(target, properties[0])
: Expression.New(
CompositeKey.CompositeKeyCtor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,38 +68,52 @@ protected override Expression VisitNew(NewExpression expression)
/// </summary>
protected override Expression VisitBinary(BinaryExpression node)
{
var newLeft = Visit(node.Left);
var newRight = Visit(node.Right);

if (newLeft.Type == typeof(ValueBuffer))

var leftConstantExpression = node.Left.RemoveConvert() as ConstantExpression;
var isLeftNullConstant = leftConstantExpression != null && leftConstantExpression.Value == null;

var rightConstantExpression = node.Right.RemoveConvert() as ConstantExpression;
var isRightNullConstant = rightConstantExpression != null && rightConstantExpression.Value == null;

if (isLeftNullConstant || isRightNullConstant)
{
if (node.NodeType == ExpressionType.Equal
|| node.NodeType == ExpressionType.NotEqual)
var nonNullExpression = isLeftNullConstant ? node.Right : node.Left;

var methodCallExpression = nonNullExpression as MethodCallExpression;
if (methodCallExpression != null)
{
var rightConstantExpression = newRight as ConstantExpression;
if (rightConstantExpression != null
&& rightConstantExpression.Value == null)
if (EntityQueryModelVisitor.IsPropertyMethod(methodCallExpression.Method))
{
return ValueBufferNullCheck(newLeft, node.NodeType == ExpressionType.Equal);
var firstArgument = methodCallExpression.Arguments[0];
var visitedArgument = Visit(firstArgument);
if (visitedArgument.Type == typeof(ValueBuffer))
{
var nullCheck = ValueBufferNullComparisonCheck(visitedArgument);
var propertyAccessExpression = Visit(nonNullExpression);

return Expression.MakeBinary(
node.NodeType,
Expression.Condition(
nullCheck,
propertyAccessExpression,
Expression.Constant(null, propertyAccessExpression.Type)),
Expression.Constant(null));
}
}
}
}

var newLeft = Visit(node.Left);
var newRight = Visit(node.Right);

if (newLeft.Type == typeof(ValueBuffer))
{
newLeft = _queryModelVisitor.BindReadValueMethod(node.Left.Type, newLeft, 0);
}

if (newRight.Type == typeof(ValueBuffer))
{
if (node.NodeType == ExpressionType.Equal
|| node.NodeType == ExpressionType.NotEqual)
{
var leftConstantExpression = newLeft as ConstantExpression;
if (leftConstantExpression != null
&& leftConstantExpression.Value == null)
{
return ValueBufferNullCheck(newRight, node.NodeType == ExpressionType.Equal);
}
}

newRight = _queryModelVisitor.BindReadValueMethod(node.Right.Type, newRight, 0);
}

Expand All @@ -108,16 +122,10 @@ protected override Expression VisitBinary(BinaryExpression node)
return node.Update(newLeft, newConversion, newRight);
}

private static Expression ValueBufferNullCheck(Expression valueBufferExpression, bool equality)
{
var equalsMethod = typeof(ValueBuffer).GetRuntimeMethod(nameof(ValueBuffer.Equals), new[] { typeof(object) });
var equalsExpression = Expression.Call(
private Expression ValueBufferNullComparisonCheck(Expression valueBufferExpression) => Expression.Not(
Expression.MakeMemberAccess(
valueBufferExpression,
equalsMethod,
Expression.Constant(null, typeof(object)));

return equality ? (Expression)equalsExpression : Expression.Not(equalsExpression);
}
typeof(ValueBuffer).GetRuntimeProperty(nameof(ValueBuffer.IsEmpty))));

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ public override void Query_source_materialization_bug_4547()
Sql);

Assert.Contains(
@"SELECT TOP(1) [subQuery2].[Id], [subQuery2].[Date], [subQuery2].[Level1_Optional_Id], [subQuery2].[Level1_Required_Id], [subQuery2].[Name], [subQuery2].[OneToMany_Optional_InverseId], [subQuery2].[OneToMany_Optional_Self_InverseId], [subQuery2].[OneToMany_Required_InverseId], [subQuery2].[OneToMany_Required_Self_InverseId], [subQuery2].[OneToOne_Optional_PK_InverseId], [subQuery2].[OneToOne_Optional_SelfId], [subQuery3].[Id], [subQuery3].[Level2_Optional_Id], [subQuery3].[Level2_Required_Id], [subQuery3].[Name], [subQuery3].[OneToMany_Optional_InverseId], [subQuery3].[OneToMany_Optional_Self_InverseId], [subQuery3].[OneToMany_Required_InverseId], [subQuery3].[OneToMany_Required_Self_InverseId], [subQuery3].[OneToOne_Optional_PK_InverseId], [subQuery3].[OneToOne_Optional_SelfId]
@"SELECT TOP(1) [subQuery3].[Id], [subQuery3].[Level2_Optional_Id], [subQuery3].[Level2_Required_Id], [subQuery3].[Name], [subQuery3].[OneToMany_Optional_InverseId], [subQuery3].[OneToMany_Optional_Self_InverseId], [subQuery3].[OneToMany_Required_InverseId], [subQuery3].[OneToMany_Required_Self_InverseId], [subQuery3].[OneToOne_Optional_PK_InverseId], [subQuery3].[OneToOne_Optional_SelfId]
FROM [Level2] AS [subQuery2]
LEFT JOIN [Level3] AS [subQuery3] ON [subQuery2].[Id] = [subQuery3].[Level2_Optional_Id]
ORDER BY [subQuery2].[Id]",
Expand Down
Loading