Skip to content

Commit

Permalink
ExecuteUpdate: Convert to join for query with unsupported operations
Browse files Browse the repository at this point in the history
Resolves #28661
  • Loading branch information
smitpatel committed Aug 16, 2022
1 parent 0577a3a commit 20d353f
Show file tree
Hide file tree
Showing 4 changed files with 760 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

EntityShaperExpression? entityShaperExpression = null;
var setColumnValues = new List<SetColumnValue>();
foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions)
{
var left = RemapLambdaBody(source, propertyExpression);
Expand All @@ -1148,28 +1147,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
entityShaperExpression.EntityType.DisplayName(), ese.EntityType.DisplayName()));
return null;
}

var right = RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating sothat value infer tye type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
var setter = Infrastructure.ExpressionExtensions.CreateEqualsExpression(left, right);
var translation = _sqlTranslator.Translate(setter);
if (translation is SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: ColumnExpression column } sqlBinaryExpression)
{
setColumnValues.Add(new SetColumnValue(column, sqlBinaryExpression.Right));
}
else
{
// We would reach here only if the property is unmapped or value fails to translate.
AddTranslationErrorDetails(RelationalStrings.UnableToTranslateSetProperty(
propertyExpression.Print(), valueExpression.Print(), _sqlTranslator.TranslationErrorDetails));
return null;
}
}

Check.DebugAssert(entityShaperExpression != null, "EntityShaperExpression should have a value.");
Expand Down Expand Up @@ -1203,10 +1180,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var selectExpression = (SelectExpression)source.QueryExpression;
if (IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out var tableExpression))
{
selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();

return new NonQueryExpression(new UpdateExpression(tableExpression, selectExpression, setColumnValues));
return TranslateSetPropertyExpressions(this, source, selectExpression, tableExpression, propertyValueLambdaExpressions);
}

// We need to convert to join with original query using PK
Expand All @@ -1220,31 +1194,85 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return null;
}

//var clrType = entityType.ClrType;
//var entityParameter = Expression.Parameter(clrType);
//Expression predicateBody;
//if (pk.Properties.Count == 1)
//{
// predicateBody = Expression.Call(
// QueryableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter);
//}
//else
//{
// var innerParameter = Expression.Parameter(clrType);
// predicateBody = Expression.Call(
// QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
// source,
// Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter)));
//}

//var newSource = Expression.Call(
// QueryableMethods.Where.MakeGenericMethod(clrType),
// new EntityQueryRootExpression(entityType),
// Expression.Quote(Expression.Lambda(predicateBody, entityParameter)));

//return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource));
var outer = (ShapedQueryExpression)Visit(new EntityQueryRootExpression(entityType));
var inner = source;
var outerParameter = Expression.Parameter(entityType.ClrType);
var outerKeySelector = Expression.Lambda(outerParameter.CreateKeyValuesExpression(pk.Properties), outerParameter);
var firstPropertyLambdaExpression = propertyValueLambdaExpressions[0].Item1;
var entitySource = GetEntitySource(firstPropertyLambdaExpression.Body);
var innerKeySelector = Expression.Lambda(
entitySource.CreateKeyValuesExpression(pk.Properties), firstPropertyLambdaExpression.Parameters);

var joinPredicate = CreateJoinPredicate(outer, outerKeySelector, inner, innerKeySelector);

Check.DebugAssert(joinPredicate != null, "Join predicate shouldn't be null");

var outerSelectExpression = (SelectExpression)outer.QueryExpression;
var outerShaperExpression = outerSelectExpression.AddInnerJoin(inner, joinPredicate, outer.ShaperExpression);
outer = outer.UpdateShaperExpression(outerShaperExpression);
var transparentIdentifierType = outer.ShaperExpression.Type;
var transparentIdentifierParameter = Expression.Parameter(transparentIdentifierType);

var propertyReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer");
var valueReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner");
for (var i = 0; i < propertyValueLambdaExpressions.Count; i++)
{
var (propertyExpression, valueExpression) = propertyValueLambdaExpressions[i];
propertyExpression = Expression.Lambda(
ReplacingExpressionVisitor.Replace(propertyExpression.Parameters[0], propertyReplacement, propertyExpression.Body),
transparentIdentifierParameter);
valueExpression = Expression.Lambda(
ReplacingExpressionVisitor.Replace(valueExpression.Parameters[0], valueReplacement, valueExpression.Body),
transparentIdentifierParameter);
propertyValueLambdaExpressions[i] = (propertyExpression, valueExpression);
}

tableExpression = (TableExpression)outerSelectExpression.Tables[0];

return TranslateSetPropertyExpressions(this, outer, outerSelectExpression, tableExpression, propertyValueLambdaExpressions);

static NonQueryExpression? TranslateSetPropertyExpressions(
RelationalQueryableMethodTranslatingExpressionVisitor visitor,
ShapedQueryExpression source,
SelectExpression selectExpression,
TableExpression tableExpression,
List<(LambdaExpression, LambdaExpression)> propertyValueLambdaExpressions)
{
var setColumnValues = new List<SetColumnValue>();
foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions)
{
var left = visitor.RemapLambdaBody(source, propertyExpression);
left = left.UnwrapTypeConversion(out _);
var right = visitor.RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating sothat value infer tye type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
var setter = Infrastructure.ExpressionExtensions.CreateEqualsExpression(left, right);
var translation = visitor._sqlTranslator.Translate(setter);
if (translation is SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: ColumnExpression column } sqlBinaryExpression)
{
setColumnValues.Add(new SetColumnValue(column, sqlBinaryExpression.Right));
}
else
{
// We would reach here only if the property is unmapped or value fails to translate.
visitor.AddTranslationErrorDetails(RelationalStrings.UnableToTranslateSetProperty(
propertyExpression.Print(), valueExpression.Print(), visitor._sqlTranslator.TranslationErrorDetails));
return null;
}
}

selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();

return new NonQueryExpression(new UpdateExpression(tableExpression, selectExpression, setColumnValues));
}


return null;

void PopulateSetPropertyStatements(
Expression expression, List<(LambdaExpression, LambdaExpression)> list, ParameterExpression parameter)
Expand Down Expand Up @@ -1273,6 +1301,8 @@ when methodCallExpression.Method.IsGenericMethod
}
}



static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out EntityShaperExpression? entityShaperExpression)
{
if (expression is MemberExpression { Expression: EntityShaperExpression ese })
Expand All @@ -1292,6 +1322,18 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
entityShaperExpression = null;
return false;
}

static Expression GetEntitySource(Expression propertyAccessExpression)
{
propertyAccessExpression = propertyAccessExpression.UnwrapTypeConversion(out _);
if (propertyAccessExpression is MethodCallExpression mce
&& mce.TryGetEFPropertyArguments(out var source, out _))
{
return source;
}

return ((MemberExpression)propertyAccessExpression).Expression!;
}
}

/// <summary>
Expand Down
Loading

0 comments on commit 20d353f

Please sign in to comment.