Skip to content

Commit

Permalink
Query: Convert Entity comparison with null to comparison on key prope…
Browse files Browse the repository at this point in the history
…rties

Safe-guard ValueBuffer.get_item calls when value buffer could be null
  • Loading branch information
smitpatel committed Oct 10, 2016
1 parent 7aa7432 commit ab805ab
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ private static Expression Invert(Expression test)
return null;
}

private static Expression TryRemoveNullCheck(ConditionalExpression node)
private Expression TryRemoveNullCheck(ConditionalExpression node)
{
var binaryTest = node.Test as BinaryExpression;

Expand Down Expand Up @@ -340,7 +340,7 @@ private static Expression TryRemoveNullCheck(ConditionalExpression node)
var testExpression = isLeftNullConstant ? binaryTest.Right : binaryTest.Left;
var resultExpression = binaryTest.NodeType == ExpressionType.Equal ? node.IfFalse : node.IfTrue;

var nullCheckRemovalTestingVisitor = new NullCheckRemovalTestingVisitor();
var nullCheckRemovalTestingVisitor = new NullCheckRemovalTestingVisitor(_queryModelVisitor.QueryCompilationContext.Model);

return nullCheckRemovalTestingVisitor.CanRemoveNullCheck(testExpression, resultExpression)
? resultExpression
Expand All @@ -350,9 +350,15 @@ private static Expression TryRemoveNullCheck(ConditionalExpression node)
private class NullCheckRemovalTestingVisitor : ExpressionVisitorBase
{
private IQuerySource _querySource;
private IModel _model;
private string _propertyName;
private bool? _canRemoveNullCheck;

public NullCheckRemovalTestingVisitor(IModel model)
{
_model = model;
}

public bool CanRemoveNullCheck(Expression testExpression, Expression resultExpression)
{
AnalyzeTestExpression(testExpression);
Expand Down Expand Up @@ -404,6 +410,10 @@ private void AnalyzeTestExpression(Expression expression)
{
_querySource = querySourceCaller.ReferencedQuerySource;
_propertyName = (string)propertyNameExpression.Value;
if (_model.FindEntityType(_querySource.ItemType)?.FindProperty(_propertyName)?.IsPrimaryKey() ?? false)
{
_propertyName = null;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,32 @@ join l2_Required_Reverse in context.LevelOne on l2.Level1_Required_Id equals l2_
[ConditionalFact]
public virtual void Query_source_materialization_bug_4547()
{
List<int> expected;
using (var context = CreateContext())
{
expected = (from e3 in context.LevelThree.ToList()
join e1 in context.LevelOne.ToList()
on
(int?)e3.Id
equals
(
from subQuery2 in context.LevelTwo.ToList()
join subQuery3 in context.LevelThree.ToList()
on
subQuery2 != null ? (int?)subQuery2.Id : null
equals
subQuery3.Level2_Optional_Id
into
grouping
from subQuery3 in grouping.DefaultIfEmpty()
select subQuery3 != null ? (int?)subQuery3.Id : null
).FirstOrDefault()
select e1.Id).ToList();

}

ClearLog();

using (var context = CreateContext())
{
var query = from e3 in context.LevelThree
Expand All @@ -2610,6 +2636,8 @@ from subQuery3 in grouping.DefaultIfEmpty()
select e1.Id;

var result = query.ToList();

Assert.Equal(expected, result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ from c in cs
select c.CustomerID);
}

[ConditionalFact]
public virtual void Entity_equality_null()
{
AssertQuery<Customer>(cs =>
from c in cs
where c == null
select c.CustomerID);
}

[ConditionalFact]
public virtual void Entity_equality_not_null()
{
AssertQuery<Customer>(cs =>
from c in cs
where c != null
select c.CustomerID);
}

[ConditionalFact]
public virtual void Null_conditional_simple()
{
Expand Down Expand Up @@ -6320,6 +6338,27 @@ public virtual void DefaultIfEmpty_without_group_join()
}
}

[ConditionalFact]
public virtual void DefaultIfEmpty_in_subquery()
{
AssertQuery<Customer, Order>((cs, os) =>
(from c in cs
from o in os.Where(o => o.CustomerID == c.CustomerID).DefaultIfEmpty()
where o != null
select new { c.CustomerID, o.OrderID }));
}

[ConditionalFact]
public virtual void DefaultIfEmpty_in_subquery_nested()
{
AssertQuery<Customer, Order>((cs, os) =>
(from c in cs.Where(c => c.City == "Seattle")
from o1 in os.Where(o => o.OrderID > 11000).DefaultIfEmpty()
from o2 in os.Where(o => o.CustomerID == c.CustomerID).DefaultIfEmpty()
where o1 != null && o2 != null
select new { c.CustomerID, o1.OrderID, o2.OrderDate }));
}

protected NorthwindContext CreateContext() => Fixture.CreateContext();

protected QueryTestBase(TFixture fixture)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -58,32 +59,38 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
|| binaryExpression.NodeType == ExpressionType.NotEqual)
{
var constantExpression = newBinaryExpression.Left.RemoveConvert() as ConstantExpression;
var isLeftNullConstant = constantExpression != null && constantExpression.Value == null;

if (constantExpression != null
&& constantExpression.Value == null)
constantExpression = newBinaryExpression.Right.RemoveConvert() as ConstantExpression;
var isRightNullConstant = constantExpression != null && constantExpression.Value == null;

if (isLeftNullConstant && isRightNullConstant)
{
return newBinaryExpression;
}

constantExpression = newBinaryExpression.Right.RemoveConvert() as ConstantExpression;

if (constantExpression != null
&& constantExpression.Value == null)
var isNullComparison = isLeftNullConstant || isRightNullConstant;
var nonNullExpression = isLeftNullConstant ? newBinaryExpression.Right : newBinaryExpression.Left;
// If a navigation being compared to null then don't rewrite
if (isNullComparison
&& !(nonNullExpression is QuerySourceReferenceExpression))
{
return newBinaryExpression;
}

var entityType = _model.FindEntityType(newBinaryExpression.Left.Type);
var entityType = _model.FindEntityType(nonNullExpression.Type);

if (entityType != null)
{
var primaryKeyProperties = entityType.FindPrimaryKey().Properties;

var newLeftExpression
= CreateKeyAccessExpression(newBinaryExpression.Left, primaryKeyProperties);
var newLeftExpression = isLeftNullConstant
? Expression.Constant(null, typeof(object))
: CreateKeyAccessExpression(newBinaryExpression.Left, primaryKeyProperties, isNullComparison);

var newRightExpression
= CreateKeyAccessExpression(newBinaryExpression.Right, primaryKeyProperties);
var newRightExpression = isRightNullConstant
? Expression.Constant(null, typeof(object))
: CreateKeyAccessExpression(newBinaryExpression.Right, primaryKeyProperties, isNullComparison);

return Expression.MakeBinary(newBinaryExpression.NodeType, newLeftExpression, newRightExpression);
}
Expand All @@ -92,9 +99,10 @@ var newRightExpression
return newBinaryExpression;
}

private static Expression CreateKeyAccessExpression(Expression target, IReadOnlyList<IProperty> properties)
private static Expression CreateKeyAccessExpression(Expression target, IReadOnlyList<IProperty> properties, bool nullComparison)
{
return properties.Count == 1
// If comparing with null then we need only first PK property
return (properties.Count == 1) || nullComparison
? EntityQueryModelVisitor.CreatePropertyExpression(target, properties[0])
: Expression.New(
CompositeKey.CompositeKeyCtor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,38 +68,52 @@ protected override Expression VisitNew(NewExpression expression)
/// </summary>
protected override Expression VisitBinary(BinaryExpression node)
{
var newLeft = Visit(node.Left);
var newRight = Visit(node.Right);

if (newLeft.Type == typeof(ValueBuffer))

var leftConstantExpression = node.Left.RemoveConvert() as ConstantExpression;
var isLeftNullConstant = leftConstantExpression != null && leftConstantExpression.Value == null;

var rightConstantExpression = node.Right.RemoveConvert() as ConstantExpression;
var isRightNullConstant = rightConstantExpression != null && rightConstantExpression.Value == null;

if (isLeftNullConstant || isRightNullConstant)
{
if (node.NodeType == ExpressionType.Equal
|| node.NodeType == ExpressionType.NotEqual)
var nonNullExpression = isLeftNullConstant ? node.Right : node.Left;

var methodCallExpression = nonNullExpression as MethodCallExpression;
if (methodCallExpression != null)
{
var rightConstantExpression = newRight as ConstantExpression;
if (rightConstantExpression != null
&& rightConstantExpression.Value == null)
if (EntityQueryModelVisitor.IsPropertyMethod(methodCallExpression.Method))
{
return ValueBufferNullCheck(newLeft, node.NodeType == ExpressionType.Equal);
var firstArgument = methodCallExpression.Arguments[0];
var visitedArgument = Visit(firstArgument);
if (visitedArgument.Type == typeof(ValueBuffer))
{
var nullCheck = ValueBufferNullComparisonCheck(visitedArgument);
var propertyAccessExpression = Visit(nonNullExpression);

return Expression.MakeBinary(
node.NodeType,
Expression.Condition(
nullCheck,
propertyAccessExpression,
Expression.Constant(null, propertyAccessExpression.Type)),
Expression.Constant(null));
}
}
}
}

var newLeft = Visit(node.Left);
var newRight = Visit(node.Right);

if (newLeft.Type == typeof(ValueBuffer))
{
newLeft = _queryModelVisitor.BindReadValueMethod(node.Left.Type, newLeft, 0);
}

if (newRight.Type == typeof(ValueBuffer))
{
if (node.NodeType == ExpressionType.Equal
|| node.NodeType == ExpressionType.NotEqual)
{
var leftConstantExpression = newLeft as ConstantExpression;
if (leftConstantExpression != null
&& leftConstantExpression.Value == null)
{
return ValueBufferNullCheck(newRight, node.NodeType == ExpressionType.Equal);
}
}

newRight = _queryModelVisitor.BindReadValueMethod(node.Right.Type, newRight, 0);
}

Expand All @@ -108,16 +122,10 @@ protected override Expression VisitBinary(BinaryExpression node)
return node.Update(newLeft, newConversion, newRight);
}

private static Expression ValueBufferNullCheck(Expression valueBufferExpression, bool equality)
{
var equalsMethod = typeof(ValueBuffer).GetRuntimeMethod(nameof(ValueBuffer.Equals), new[] { typeof(object) });
var equalsExpression = Expression.Call(
private Expression ValueBufferNullComparisonCheck(Expression valueBufferExpression) => Expression.Not(
Expression.MakeMemberAccess(
valueBufferExpression,
equalsMethod,
Expression.Constant(null, typeof(object)));

return equality ? (Expression)equalsExpression : Expression.Not(equalsExpression);
}
typeof(ValueBuffer).GetRuntimeProperty(nameof(ValueBuffer.IsEmpty))));

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ public override void Query_source_materialization_bug_4547()
Sql);

Assert.Contains(
@"SELECT TOP(1) [subQuery2].[Id], [subQuery2].[Date], [subQuery2].[Level1_Optional_Id], [subQuery2].[Level1_Required_Id], [subQuery2].[Name], [subQuery2].[OneToMany_Optional_InverseId], [subQuery2].[OneToMany_Optional_Self_InverseId], [subQuery2].[OneToMany_Required_InverseId], [subQuery2].[OneToMany_Required_Self_InverseId], [subQuery2].[OneToOne_Optional_PK_InverseId], [subQuery2].[OneToOne_Optional_SelfId], [subQuery3].[Id], [subQuery3].[Level2_Optional_Id], [subQuery3].[Level2_Required_Id], [subQuery3].[Name], [subQuery3].[OneToMany_Optional_InverseId], [subQuery3].[OneToMany_Optional_Self_InverseId], [subQuery3].[OneToMany_Required_InverseId], [subQuery3].[OneToMany_Required_Self_InverseId], [subQuery3].[OneToOne_Optional_PK_InverseId], [subQuery3].[OneToOne_Optional_SelfId]
@"SELECT TOP(1) [subQuery3].[Id], [subQuery3].[Level2_Optional_Id], [subQuery3].[Level2_Required_Id], [subQuery3].[Name], [subQuery3].[OneToMany_Optional_InverseId], [subQuery3].[OneToMany_Optional_Self_InverseId], [subQuery3].[OneToMany_Required_InverseId], [subQuery3].[OneToMany_Required_Self_InverseId], [subQuery3].[OneToOne_Optional_PK_InverseId], [subQuery3].[OneToOne_Optional_SelfId]
FROM [Level2] AS [subQuery2]
LEFT JOIN [Level3] AS [subQuery3] ON [subQuery2].[Id] = [subQuery3].[Level2_Optional_Id]
ORDER BY [subQuery2].[Id]",
Expand Down
Loading

0 comments on commit ab805ab

Please sign in to comment.