Skip to content

Commit

Permalink
Query: FromSql: Adds support for interpolated strings!
Browse files Browse the repository at this point in the history
Interpolated variables are parameterized and queries are cached in the usual way.
  • Loading branch information
anpete committed May 12, 2017
1 parent 951e482 commit 7a18610
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 13 deletions.
97 changes: 91 additions & 6 deletions src/EFCore.Relational.Specification.Tests/FromSqlQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Specification.Tests.TestModels.Northwind;
using Xunit;
// ReSharper disable FormatStringProblem

// ReSharper disable InconsistentNaming

Expand Down Expand Up @@ -316,8 +317,7 @@ public virtual void From_sql_queryable_multiple_composed_with_parameters_and_clo
using (var context = CreateContext())
{
var actual
= (from c in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""City"" = {0}",
city)
= (from c in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""City"" = {0}", city)
from o in context.Set<Order>().FromSql(@"SELECT * FROM ""Orders"" WHERE ""OrderDate"" BETWEEN {0} AND {1}",
startDate,
endDate)
Expand All @@ -326,6 +326,21 @@ from o in context.Set<Order>().FromSql(@"SELECT * FROM ""Orders"" WHERE ""OrderD
.ToArray();

Assert.Equal(25, actual.Length);

city = "Berlin";
startDate = new DateTime(1998, 4, 1);
endDate = new DateTime(1998, 5, 1);

actual
= (from c in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""City"" = {0}", city)
from o in context.Set<Order>().FromSql(@"SELECT * FROM ""Orders"" WHERE ""OrderDate"" BETWEEN {0} AND {1}",
startDate,
endDate)
where c.CustomerID == o.CustomerID
select new { c, o })
.ToArray();

Assert.Equal(1, actual.Length);
}
}

Expand Down Expand Up @@ -400,6 +415,74 @@ public virtual void From_sql_queryable_with_parameters_inline()
}
}

[Fact]
public virtual void From_sql_queryable_with_parameters_interpolated()
{
var city = "London";
var contactTitle = "Sales Representative";

using (var context = CreateContext())
{
var actual = context.Set<Customer>()
.FromSql(
$@"SELECT * FROM ""Customers"" WHERE ""City"" = {city} AND ""ContactTitle"" = {contactTitle}")
.ToArray();

Assert.Equal(3, actual.Length);
Assert.True(actual.All(c => c.City == "London"));
Assert.True(actual.All(c => c.ContactTitle == "Sales Representative"));
}
}

[Fact]
public virtual void From_sql_queryable_with_parameters_inline_interpolated()
{
using (var context = CreateContext())
{
var actual = context.Set<Customer>()
.FromSql(
$@"SELECT * FROM ""Customers"" WHERE ""City"" = {"London"} AND ""ContactTitle"" = {"Sales Representative"}")
.ToArray();

Assert.Equal(3, actual.Length);
Assert.True(actual.All(c => c.City == "London"));
Assert.True(actual.All(c => c.ContactTitle == "Sales Representative"));
}
}

[Fact]
public virtual void From_sql_queryable_multiple_composed_with_parameters_and_closure_parameters_interpolated()
{
var city = "London";
var startDate = new DateTime(1997, 1, 1);
var endDate = new DateTime(1998, 1, 1);

using (var context = CreateContext())
{
var actual
= (from c in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""City"" = {0}", city)
from o in context.Set<Order>().FromSql($@"SELECT * FROM ""Orders"" WHERE ""OrderDate"" BETWEEN {startDate} AND {endDate}")
where c.CustomerID == o.CustomerID
select new { c, o })
.ToArray();

Assert.Equal(25, actual.Length);

city = "Berlin";
startDate = new DateTime(1998, 4, 1);
endDate = new DateTime(1998, 5, 1);

actual
= (from c in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""City"" = {0}", city)
from o in context.Set<Order>().FromSql($@"SELECT * FROM ""Orders"" WHERE ""OrderDate"" BETWEEN {startDate} AND {endDate}")
where c.CustomerID == o.CustomerID
select new { c, o })
.ToArray();

Assert.Equal(1, actual.Length);
}
}

[Fact]
public virtual void From_sql_queryable_with_null_parameter()
{
Expand Down Expand Up @@ -656,8 +739,7 @@ public virtual void From_sql_with_db_parameters_called_multiple_times()
var parameter = CreateDbParameter("@id", "ALFKI");

var query = context.Customers
.FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = @id",
parameter);
.FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = @id", parameter);

var result1 = query.ToList();

Expand All @@ -675,8 +757,11 @@ public virtual void From_sql_with_SelectMany_and_include()
{
using (var context = CreateContext())
{
var query = from c1 in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = 'ALFKI'")
from c2 in context.Set<Customer>().FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = 'AROUT'").Include(c => c.Orders)
var query = from c1 in context.Set<Customer>()
.FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = 'ALFKI'")
from c2 in context.Set<Customer>()
.FromSql(@"SELECT * FROM ""Customers"" WHERE ""CustomerID"" = 'AROUT'")
.Include(c => c.Orders)
select new { c1, c2 };

var result = query.ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public FromSqlExpressionNode(
[NotNull] Expression arguments)
: base(parseInfo, null, null)
{
_sql = (string)sql.Value;
_sql = ((RelationalQueryableExtensions.SqlFormat)sql.Value).Format;
_arguments = arguments;
}

Expand Down
64 changes: 60 additions & 4 deletions src/EFCore.Relational/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -18,7 +19,8 @@ public static class RelationalQueryableExtensions
{
internal static readonly MethodInfo FromSqlMethodInfo
= typeof(RelationalQueryableExtensions)
.GetTypeInfo().GetDeclaredMethod(nameof(FromSql));
.GetTypeInfo().GetDeclaredMethods(nameof(FromSql))
.Single(mi => mi.GetParameters().Length == 3);

/// <summary>
/// <para>
Expand All @@ -44,18 +46,21 @@ internal static readonly MethodInfo FromSqlMethodInfo
/// <param name="source">
/// An <see cref="IQueryable{T}" /> to use as the base of the raw SQL query (typically a <see cref="DbSet{TEntity}" />).
/// </param>
/// <param name="sql"> The raw SQL query. </param>
/// <param name="sql">
/// The raw SQL query. NB. A string literal may be passed here because <see cref="SqlFormat" />
/// is implicitly convertible to string.
/// </param>
/// <param name="parameters"> The values to be assigned to parameters. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the raw SQL query. </returns>
[StringFormatMethod("sql")]
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] [NotParameterized] string sql,
[NotParameterized] SqlFormat sql,
[NotNull] params object[] parameters)
where TEntity : class
{
Check.NotNull(source, nameof(source));
Check.NotEmpty(sql, nameof(sql));
Check.NotEmpty(sql.Format, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

return source.Provider.CreateQuery<TEntity>(
Expand All @@ -66,5 +71,56 @@ public static IQueryable<TEntity> FromSql<TEntity>(
Expression.Constant(sql),
Expression.Constant(parameters)));
}

/// <summary>
/// <para>
/// Creates a LINQ query based on an interpolated string SQL query.
/// </para>
/// <para>
/// If the database provider supports composing on the supplied SQL, you can compose on top of the raw SQL query using
/// LINQ operators - <code>context.Blogs.FromSql("SELECT * FROM dbo.Blogs").OrderBy(b => b.Name)</code>.
/// </para>
/// <para>
/// As with any API that accepts SQL it is important to parameterize any user input to protect against a SQL injection
/// attack. You can include interpolated parameter place holders in the SQL query string. Any interpolated parameter values
/// you supply will automatically be converted to a DbParameter -
/// <code>context.Blogs.FromSql($"SELECT * FROM [dbo].[SearchBlogs]({userSuppliedSearchTerm})")</code>.
/// </para>
/// </summary>
/// <typeparam name="TEntity"> The type of the elements of <paramref name="source" />. </typeparam>
/// <param name="source">
/// An <see cref="IQueryable{T}" /> to use as the base of the interpolated string SQL query (typically a <see cref="DbSet{TEntity}" />).
/// </param>
/// <param name="sql"> The interpolated string SQL query. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the interpolated string SQL query. </returns>
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] [NotParameterized] FormattableString sql)
where TEntity : class
{
Check.NotNull(source, nameof(source));
Check.NotNull(sql, nameof(sql));
Check.NotEmpty(sql.Format, nameof(source));

return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(new SqlFormat(sql.Format)),
Expression.Constant(sql.GetArguments())));
}

/// <summary>
/// A SQL format string. This type enables overload resolution between
/// the regular and interpolated FromSql overloads.
/// </summary>
public struct SqlFormat
{
public static implicit operator SqlFormat([NotNull] string s) => new SqlFormat(s);
public static implicit operator SqlFormat([NotNull] FormattableString fs) => default(SqlFormat);
public SqlFormat([NotNull] string s) => Format = s;
public string Format { get; }
}
}
}
5 changes: 5 additions & 0 deletions src/EFCore.Relational/breakingchanges.netcore.json
Original file line number Diff line number Diff line change
Expand Up @@ -1021,5 +1021,10 @@
"TypeId": "public interface Microsoft.EntityFrameworkCore.Metadata.IRelationalIndexAnnotations",
"MemberId": "System.String get_Filter()",
"Kind": "Addition"
},
{
"TypeId": "public static class Microsoft.EntityFrameworkCore.RelationalQueryableExtensions",
"MemberId": "public static System.Linq.IQueryable<T0> FromSql<T0>(this System.Linq.IQueryable<T0> source, System.String sql, params System.Object[] parameters) where T0 : class",
"Kind": "Removal"
}
]
5 changes: 5 additions & 0 deletions src/EFCore.Relational/breakingchanges.netframework.json
Original file line number Diff line number Diff line change
Expand Up @@ -1021,5 +1021,10 @@
"TypeId": "public interface Microsoft.EntityFrameworkCore.Metadata.IRelationalIndexAnnotations",
"MemberId": "System.String get_Filter()",
"Kind": "Addition"
},
{
"TypeId": "public static class Microsoft.EntityFrameworkCore.RelationalQueryableExtensions",
"MemberId": "public static System.Linq.IQueryable<T0> FromSql<T0>(this System.Linq.IQueryable<T0> source, System.String sql, params System.Object[] parameters) where T0 : class",
"Kind": "Removal"
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,40 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (newArgument.RemoveConvert() is ParameterExpression parameter)
{
newArgument = Expression.Constant(_parameterValues.RemoveParameter(parameter.Name));
var parameterValue = _parameterValues.RemoveParameter(parameter.Name);

if (parameter.Type == typeof(FormattableString))
{
if (Evaluate(methodCallExpression, out var _) is IQueryable queryable)
{
var oldInLambda = _inLambda;

_inLambda = false;

try
{
return ExtractParameters(queryable.Expression);
}
finally
{
_inLambda = oldInLambda;
}
}
}
else
{
var constantParameterValue = Expression.Constant(parameterValue);

if (newArgument is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
newArgument = unaryExpression.Update(constantParameterValue);
}
else
{
newArgument = constantParameterValue;
}
}
}
}

Expand Down
Loading

0 comments on commit 7a18610

Please sign in to comment.