diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs b/src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs index 8c70471c..853b45c8 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs @@ -264,6 +264,38 @@ protected override void Dispose(bool disposing) protected override DbCommand CreateDbCommand() => CreateCommand(); + /// + /// Create custom collation. + /// + /// Name of the collation. + /// Method that compares two strings. + public virtual void CreateCollation(string name, Comparison comparison) + => CreateCollation(name, null, comparison != null ? (_, s1, s2) => comparison(s1, s2) : (Func)null); + + /// + /// Create custom collation. + /// + /// The type of the state object. + /// Name of the collation. + /// State object passed to each invokation of the collation. + /// Method that compares two strings, using additional state. + public virtual void CreateCollation(string name, T state, Func comparison) + { + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentNullException(nameof(name)); + } + + if (State != ConnectionState.Open) + { + throw new InvalidOperationException(Resources.CallRequiresOpenConnection(nameof(CreateCollation))); + } + + delegate_collation collation = comparison != null ? (v, s1, s2) => comparison((T)v, s1, s2) : (delegate_collation)null; + var rc = raw.sqlite3_create_collation(_db, name, state, collation); + SqliteException.ThrowExceptionForRC(rc, _db); + } + /// /// Begins a transaction on the connection. /// diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs index 3abe90c7..bacf4c66 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs @@ -1,7 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Data; using System.IO; using Microsoft.Data.Sqlite.Properties; @@ -310,6 +311,78 @@ public void CreateCommand_returns_command() } } + [Fact] + public void CreateCollation_works() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + connection.CreateCollation("MY_NOCASE", (s1, s2) => string.Compare(s1, s2, StringComparison.OrdinalIgnoreCase)); + + Assert.Equal(1L, connection.ExecuteScalar("SELECT 'Νικοσ' = 'ΝΙΚΟΣ' COLLATE MY_NOCASE;")); + } + } + + [Fact] + public void CreateCollation_with_null_comparer_works() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + connection.CreateCollation("MY_NOCASE", (s1, s2) => string.Compare(s1, s2, StringComparison.OrdinalIgnoreCase)); + connection.CreateCollation("MY_NOCASE", null); + + var ex = Assert.Throws( + () => connection.ExecuteScalar("SELECT 'Νικοσ' = 'ΝΙΚΟΣ' COLLATE MY_NOCASE;")); + + Assert.Equal(raw.SQLITE_ERROR, ex.SqliteErrorCode); + } + } + + [Fact] + public void CreateCollation_throws_when_closed() + { + var connection = new SqliteConnection(); + + var ex = Assert.Throws(() => connection.CreateCollation("NOCOL", (s1, s2) => -1)); + + Assert.Equal(Resources.CallRequiresOpenConnection("CreateCollation"), ex.Message); + } + + [Fact] + public void CreateCollation_throws_with_empty_name() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + var ex = Assert.Throws(() => connection.CreateCollation(null, null)); + + Assert.Equal("name", ex.ParamName); + } + } + + [Fact] + public void CreateCollation_works_with_state() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + var list = new List(); + connection.CreateCollation>( + "MY_NOCASE", + list, + (l, s1, s2) => + { + l.Add("Invoked"); + return string.Compare(s1, s2, StringComparison.OrdinalIgnoreCase); + }); + + Assert.Equal(1L, connection.ExecuteScalar("SELECT 'Νικοσ' = 'ΝΙΚΟΣ' COLLATE MY_NOCASE;")); + Assert.Equal(1, list.Count); + Assert.Equal("Invoked", list[0]); + } + } + [Fact] public void BeginTransaction_throws_when_closed() {