From afa3a930d2d7fe694aee330f736b6ec6b8e09e95 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 26 Jul 2024 03:31:48 -0700 Subject: [PATCH 01/30] feat: create `sqlx.toml` format --- Cargo.toml | 11 +- sqlx-core/Cargo.toml | 9 +- sqlx-core/src/config/common.rs | 38 ++++ sqlx-core/src/config/macros.rs | 296 ++++++++++++++++++++++++++++ sqlx-core/src/config/migrate.rs | 158 +++++++++++++++ sqlx-core/src/config/mod.rs | 206 +++++++++++++++++++ sqlx-core/src/config/reference.toml | 175 ++++++++++++++++ sqlx-core/src/config/tests.rs | 90 +++++++++ sqlx-core/src/lib.rs | 3 + sqlx-macros-core/Cargo.toml | 4 + sqlx-macros/Cargo.toml | 3 + src/lib.rs | 3 + 12 files changed, 993 insertions(+), 3 deletions(-) create mode 100644 sqlx-core/src/config/common.rs create mode 100644 sqlx-core/src/config/macros.rs create mode 100644 sqlx-core/src/config/migrate.rs create mode 100644 sqlx-core/src/config/mod.rs create mode 100644 sqlx-core/src/config/reference.toml create mode 100644 sqlx-core/src/config/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 316dc471e1..72a9d01c28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,16 +50,21 @@ authors.workspace = true repository.workspace = true [package.metadata.docs.rs] -features = ["all-databases", "_unstable-all-types"] +features = ["all-databases", "_unstable-all-types", "_unstable-doc"] rustdoc-args = ["--cfg", "docsrs"] [features] -default = ["any", "macros", "migrate", "json"] +default = ["any", "macros", "migrate", "json", "config-all"] derive = ["sqlx-macros/derive"] macros = ["derive", "sqlx-macros/macros"] migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] +# Enable parsing of `sqlx.toml` for configuring macros, migrations, or both. 
+config-macros = ["sqlx-macros?/config-macros"] +config-migrate = ["sqlx-macros?/config-migrate"] +config-all = ["config-macros", "config-migrate"] + # intended mainly for CI and docs all-databases = ["mysql", "sqlite", "postgres", "any"] _unstable-all-types = [ @@ -73,6 +78,8 @@ _unstable-all-types = [ "uuid", "bit-vec", ] +# Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`). +_unstable-doc = ["config-all", "sqlx-core/_unstable-doc"] # Base runtime features without TLS runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index d662861470..f70adde55e 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -12,7 +12,7 @@ features = ["offline"] [features] default = [] -migrate = ["sha2", "crc"] +migrate = ["sha2", "crc", "config-migrate"] any = [] @@ -31,6 +31,12 @@ _tls-none = [] # support offline/decoupled building (enables serialization of `Describe`) offline = ["serde", "either/serde"] +config = ["serde", "toml/parse"] +config-macros = ["config"] +config-migrate = ["config"] + +_unstable-doc = ["config-macros", "config-migrate"] + [dependencies] # Runtimes async-std = { workspace = true, optional = true } @@ -70,6 +76,7 @@ percent-encoding = "2.1.0" regex = { version = "1.5.5", optional = true } serde = { version = "1.0.132", features = ["derive", "rc"], optional = true } serde_json = { version = "1.0.73", features = ["raw_value"], optional = true } +toml = { version = "0.8.16", optional = true } sha2 = { version = "0.10.0", default-features = false, optional = true } #sqlformat = "0.2.0" thiserror = "2.0.0" diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs new file mode 100644 index 0000000000..8c774fc60f --- /dev/null +++ b/sqlx-core/src/config/common.rs @@ -0,0 +1,38 @@ +/// Configuration shared by multiple components. 
+#[derive(Debug, Default, serde::Deserialize)] +pub struct Config { + /// Override the database URL environment variable. + /// + /// This is used by both the macros and `sqlx-cli`. + /// + /// Case-sensitive. Defaults to `DATABASE_URL`. + /// + /// Example: Multi-Database Project + /// ------- + /// You can use multiple databases in the same project by breaking it up into multiple crates, + /// then using a different environment variable for each. + /// + /// For example, with two crates in the workspace named `foo` and `bar`: + /// + /// #### `foo/sqlx.toml` + /// ```toml + /// [macros] + /// database_url_var = "FOO_DATABASE_URL" + /// ``` + /// + /// #### `bar/sqlx.toml` + /// ```toml + /// [macros] + /// database_url_var = "BAR_DATABASE_URL" + /// ``` + /// + /// #### `.env` + /// ```text + /// FOO_DATABASE_URL=postgres://postgres@localhost:5432/foo + /// BAR_DATABASE_URL=postgres://postgres@localhost:5432/bar + /// ``` + /// + /// The query macros used in `foo` will use `FOO_DATABASE_URL`, + /// and the ones used in `bar` will use `BAR_DATABASE_URL`. + pub database_url_var: Option, +} diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs new file mode 100644 index 0000000000..5edd30dc15 --- /dev/null +++ b/sqlx-core/src/config/macros.rs @@ -0,0 +1,296 @@ +use std::collections::BTreeMap; + +/// Configuration for the `query!()` family of macros. +#[derive(Debug, Default, serde::Deserialize)] +#[serde(default)] +pub struct Config { + /// Specify the crate to use for mapping date/time types to Rust. + /// + /// The default behavior is to use whatever crate is enabled, + /// [`chrono`] or [`time`] (the latter takes precedent). 
+ /// + /// [`chrono`]: crate::types::chrono + /// [`time`]: crate::types::time + /// + /// Example: Always Use Chrono + /// ------- + /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable + /// the `time` feature of SQLx which will force it on for all crates using SQLx, + /// which will result in problems if your crate wants to use types from [`chrono`]. + /// + /// You can use the type override syntax (see `sqlx::query!` for details), + /// or you can force an override globally by setting this option. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros] + /// datetime_crate = "chrono" + /// ``` + /// + /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification + pub datetime_crate: DateTimeCrate, + + /// Specify global overrides for mapping SQL type names to Rust type names. + /// + /// Default type mappings are defined by the database driver. + /// Refer to the `sqlx::types` module for details. + /// + /// ## Note: Orthogonal to Nullability + /// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` + /// or not. They only override the inner type used. + /// + /// ## Note: Schema Qualification (Postgres) + /// Type names may be schema-qualified in Postgres. If so, the schema should be part + /// of the type string, e.g. `'foo.bar'` to reference type `bar` in schema `foo`. + /// + /// The schema and/or type name may additionally be quoted in the string + /// for a quoted identifier (see next section). + /// + /// Schema qualification should not be used for types in the search path. + /// + /// ## Note: Quoted Identifiers (Postgres) + /// Type names using [quoted identifiers in Postgres] must also be specified with quotes here. + /// + /// Note, however, that the TOML format parses way the outer pair of quotes, + /// so for quoted names in Postgres, double-quoting is necessary, + /// e.g. `'"Foo"'` for SQL type `"Foo"`. 
+ /// + /// To reference a schema-qualified type with a quoted name, use double-quotes after the + /// dot, e.g. `'foo."Bar"'` to reference type `"Bar"` of schema `foo`, and vice versa for + /// quoted schema names. + /// + /// We recommend wrapping all type names in single quotes, as shown below, + /// to avoid confusion. + /// + /// MySQL/MariaDB and SQLite do not support custom types, so quoting type names should + /// never be necessary. + /// + /// [quoted identifiers in Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + // Note: we wanted to be able to handle this intelligently, + // but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761 + // + // We decided to just encourage always quoting type names instead. + /// Example: Custom Wrapper Types + /// ------- + /// Does SQLx not support a type that you need? Do you want additional semantics not + /// implemented on the built-in types? You can create a custom wrapper, + /// or use an external crate. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.type_overrides] + /// # Override a built-in type + /// 'uuid' = "crate::types::MyUuid" + /// + /// # Support an external or custom wrapper type (e.g. from the `isn` Postgres extension) + /// # (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING) + /// 'isbn13' = "isn_rs::sqlx::ISBN13" + /// ``` + /// + /// Example: Custom Types in Postgres + /// ------- + /// If you have a custom type in Postgres that you want to map without needing to use + /// the type override syntax in `sqlx::query!()` every time, you can specify a global + /// override here. + /// + /// For example, a custom enum type `foo`: + /// + /// #### Migration or Setup SQL (e.g. 
`migrations/0_setup.sql`) + /// ```sql + /// CREATE TYPE foo AS ENUM ('Bar', 'Baz'); + /// ``` + /// + /// #### `src/types.rs` + /// ```rust,no_run + /// #[derive(sqlx::Type)] + /// pub enum Foo { + /// Bar, + /// Baz + /// } + /// ``` + /// + /// If you're not using `PascalCase` in your enum variants then you'll want to use + /// `#[sqlx(rename_all = "")]` on your enum. + /// See [`Type`][crate::type::Type] for details. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.type_overrides] + /// # Map SQL type `foo` to `crate::types::Foo` + /// 'foo' = "crate::types::Foo" + /// ``` + /// + /// Example: Schema-Qualified Types + /// ------- + /// (See `Note` section above for details.) + /// + /// ```toml + /// [macros.type_overrides] + /// # Map SQL type `foo.foo` to `crate::types::Foo` + /// 'foo.foo' = "crate::types::Foo" + /// ``` + /// + /// Example: Quoted Identifiers + /// ------- + /// If a type or schema uses quoted identifiers, + /// it must be wrapped in quotes _twice_ for SQLx to know the difference: + /// + /// ```toml + /// [macros.type_overrides] + /// # `"Foo"` in SQLx + /// '"Foo"' = "crate::types::Foo" + /// # **NOT** `"Foo"` in SQLx (parses as just `Foo`) + /// "Foo" = "crate::types::Foo" + /// + /// # Schema-qualified + /// '"foo".foo' = "crate::types::Foo" + /// 'foo."Foo"' = "crate::types::Foo" + /// '"foo"."Foo"' = "crate::types::Foo" + /// ``` + /// + /// (See `Note` section above for details.) + pub type_overrides: BTreeMap, + + /// Specify per-column overrides for mapping SQL types to Rust types. + /// + /// Default type mappings are defined by the database driver. + /// Refer to the `sqlx::types` module for details. + /// + /// The supported syntax is similar to [`type_overrides`][Self::type_overrides], + /// (with the same caveat for quoted names!) but column names must be qualified + /// by a separately quoted table name, which may optionally be schema-qualified. 
+ /// + /// Multiple columns for the same SQL table may be written in the same table in TOML + /// (see examples below). + /// + /// ## Note: Orthogonal to Nullability + /// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` + /// or not. They only override the inner type used. + /// + /// ## Note: Schema Qualification + /// Table names may be schema-qualified. If so, the schema should be part + /// of the table name string, e.g. `'foo.bar'` to reference table `bar` in schema `foo`. + /// + /// The schema and/or type name may additionally be quoted in the string + /// for a quoted identifier (see next section). + /// + /// Postgres users: schema qualification should not be used for tables in the search path. + /// + /// ## Note: Quoted Identifiers + /// Schema, table, or column names using quoted identifiers ([MySQL], [Postgres], [SQLite]) + /// in SQL must also be specified with quotes here. + /// + /// Postgres and SQLite use double-quotes (`"Foo"`) while MySQL uses backticks (`\`Foo\`). + /// + /// Note, however, that the TOML format parses way the outer pair of quotes, + /// so for quoted names in Postgres, double-quoting is necessary, + /// e.g. `'"Foo"'` for SQL name `"Foo"`. + /// + /// To reference a schema-qualified table with a quoted name, use the appropriate quotation + /// characters after the dot, e.g. `'foo."Bar"'` to reference table `"Bar"` of schema `foo`, + /// and vice versa for quoted schema names. + /// + /// We recommend wrapping all table and column names in single quotes, as shown below, + /// to avoid confusion. 
+ /// + /// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/identifiers.html + /// [Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + /// [SQLite]: https://sqlite.org/lang_keywords.html + // Note: we wanted to be able to handle this intelligently, + // but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761 + // + // We decided to just encourage always quoting type names instead. + /// + /// Example + /// ------- + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.column_overrides.'foo'] + /// # Map column `bar` of table `foo` to Rust type `crate::types::Foo`: + /// 'bar' = "crate::types::Bar" + /// + /// # Quoted column name + /// # Note: same quoting requirements as `macros.type_overrides` + /// '"Bar"' = "crate::types::Bar" + /// + /// # Note: will NOT work (parses as `Bar`) + /// # "Bar" = "crate::types::Bar" + /// + /// # Table name may be quoted (note the wrapping single-quotes) + /// [macros.column_overrides.'"Foo"'] + /// 'bar' = "crate::types::Bar" + /// '"Bar"' = "crate::types::Bar" + /// + /// # Table name may also be schema-qualified. + /// # Note how the dot is inside the quotes. + /// [macros.column_overrides.'my_schema.my_table'] + /// 'my_column' = "crate::types::MyType" + /// + /// # Quoted schema, table, and column names + /// [macros.column_overrides.'"My Schema"."My Table"'] + /// '"My Column"' = "crate::types::MyType" + /// ``` + pub column_overrides: BTreeMap>, +} + +/// The crate to use for mapping date/time types to Rust. +#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DateTimeCrate { + /// Use whichever crate is enabled (`time` then `chrono`). + #[default] + Inferred, + + /// Always use types from [`chrono`][crate::types::chrono]. + /// + /// ```toml + /// [macros] + /// datetime_crate = "chrono" + /// ``` + Chrono, + + /// Always use types from [`time`][crate::types::time]. 
+ /// + /// ```toml + /// [macros] + /// datetime_crate = "time" + /// ``` + Time, +} + +/// A SQL type name; may optionally be schema-qualified. +/// +/// See [`macros.type_overrides`][Config::type_overrides] for usages. +pub type SqlType = Box; + +/// A SQL table name; may optionally be schema-qualified. +/// +/// See [`macros.column_overrides`][Config::column_overrides] for usages. +pub type TableName = Box; + +/// A column in a SQL table. +/// +/// See [`macros.column_overrides`][Config::column_overrides] for usages. +pub type ColumnName = Box; + +/// A Rust type name or path. +/// +/// Should be a global path (not relative). +pub type RustType = Box; + +/// Internal getter methods. +impl Config { + /// Get the override for a given type name (optionally schema-qualified). + pub fn type_override(&self, type_name: &str) -> Option<&str> { + self.type_overrides.get(type_name).map(|s| &**s) + } + + /// Get the override for a given column and table name (optionally schema-qualified). + pub fn column_override(&self, table: &str, column: &str) -> Option<&str> { + self.column_overrides + .get(table) + .and_then(|by_column| by_column.get(column)) + .map(|s| &**s) + } +} diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs new file mode 100644 index 0000000000..5878f9a24f --- /dev/null +++ b/sqlx-core/src/config/migrate.rs @@ -0,0 +1,158 @@ +use std::collections::BTreeSet; + +/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. +/// +/// ### Note +/// A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these +/// configuration options. We recommend using `sqlx::migrate!()` instead. +/// +/// ### Warning: Potential Data Loss or Corruption! +/// Many of these options, if changed after migrations are set up, +/// can result in data loss or corruption of a production database +/// if the proper precautions are not taken. 
+/// +/// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_. +#[derive(Debug, Default, serde::Deserialize)] +#[serde(default)] +pub struct Config { + /// Override the name of the table used to track executed migrations. + /// + /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`. + /// + /// Potentially useful for multi-tenant databases. + /// + /// ### Warning: Potential Data Loss or Corruption! + /// Changing this option for a production database will likely result in data loss or corruption + /// as the migration machinery will no longer be aware of what migrations have been applied + /// and will attempt to re-run them. + /// + /// You should create the new table as a copy of the existing migrations table (with contents!), + /// and be sure all instances of your application have been migrated to the new + /// table before deleting the old one. + /// + /// ### Example + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// # Put `_sqlx_migrations` in schema `foo` + /// table_name = "foo._sqlx_migrations" + /// ``` + pub table_name: Option>, + + /// Override the directory used for migrations files. + /// + /// Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`. + pub migrations_dir: Option>, + + /// Specify characters that should be ignored when hashing migrations. + /// + /// Any characters contained in the given array will be dropped when a migration is hashed. + /// + /// ### Warning: May Change Hashes for Existing Migrations + /// Changing the characters considered in hashing migrations will likely + /// change the output of the hash. + /// + /// This may require manual rectification for deployed databases. + /// + /// ### Example: Ignore Carriage Return (`` | `\r`) + /// Line ending differences between platforms can result in migrations having non-repeatable + /// hashes. 
The most common culprit is the carriage return (`` | `\r`), which Windows + /// uses in its line endings alongside line feed (`` | `\n`), often written `CRLF` or `\r\n`, + /// whereas Linux and macOS use only line feeds. + /// + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// ignored_chars = ["\r"] + /// ``` + /// + /// For projects using Git, this can also be addressed using [`.gitattributes`]: + /// + /// ```text + /// # Force newlines in migrations to be line feeds on all platforms + /// migrations/*.sql text eol=lf + /// ``` + /// + /// This may require resetting or re-checking out the migrations files to take effect. + /// + /// [`.gitattributes`]: https://git-scm.com/docs/gitattributes + /// + /// ### Example: Ignore all Whitespace Characters + /// To make your migrations amenable to reformatting, you may wish to tell SQLx to ignore + /// _all_ whitespace characters in migrations. + /// + /// ##### Warning: Beware Syntatically Significant Whitespace! + /// If your migrations use string literals or quoted identifiers which contain whitespace, + /// this configuration will cause the migration machinery to ignore some changes to these. + /// This may result in a mismatch between the development and production versions of + /// your database. + /// + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// # Ignore common whitespace characters when hashing + /// ignored_chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF + /// ``` + // Likely lower overhead for small sets than `HashSet`. + pub ignored_chars: BTreeSet, + + /// Specify the default type of migration that `sqlx migrate create` should create by default. + /// + /// ### Example: Use Reversible Migrations by Default + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// default_type = "reversible" + /// ``` + pub default_type: DefaultMigrationType, + + /// Specify the default scheme that `sqlx migrate create` should use for version integers. 
+ /// + /// ### Example: Use Sequential Versioning by Default + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// default_versioning = "sequential" + /// ``` + pub default_versioning: DefaultVersioning, +} + +/// The default type of migration that `sqlx migrate create` should create by default. +#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DefaultMigrationType { + /// Create the same migration type as that of the latest existing migration, + /// or `Simple` otherwise. + #[default] + Inferred, + + /// Create a non-reversible migration (`_.sql`). + Simple, + + /// Create a reversible migration (`_.up.sql` and `[...].down.sql`). + Reversible, +} + +/// The default scheme that `sqlx migrate create` should use for version integers. +#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DefaultVersioning { + /// Infer the versioning scheme from existing migrations: + /// + /// * If the versions of the last two migrations differ by `1`, infer `Sequential`. + /// * If only one migration exists and has version `1`, infer `Sequential`. + /// * Otherwise, infer `Timestamp`. + #[default] + Inferred, + + /// Use UTC timestamps for migration versions. + /// + /// This is the recommended versioning format as it's less likely to collide when multiple + /// developers are creating migrations on different branches. + /// + /// The exact timestamp format is unspecified. + Timestamp, + + /// Use sequential integers for migration versions. + Sequential, +} diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs new file mode 100644 index 0000000000..979477241f --- /dev/null +++ b/sqlx-core/src/config/mod.rs @@ -0,0 +1,206 @@ +//! (Exported for documentation only) Guide and reference for `sqlx.toml` files. +//! +//! To use, create a `sqlx.toml` file in your crate root (the same directory as your `Cargo.toml`). +//! 
The configuration in a `sqlx.toml` configures SQLx *only* for the current crate. +//! +//! See the [`Config`] type and its fields for individual configuration options. +//! +//! See the [reference][`_reference`] for the full `sqlx.toml` file. + +use std::fmt::Debug; +use std::io; +use std::path::{Path, PathBuf}; + +// `std::sync::OnceLock` doesn't have a stable `.get_or_try_init()` +// because it's blocked on a stable `Try` trait. +use once_cell::sync::OnceCell; + +/// Configuration shared by multiple components. +/// +/// See [`common::Config`] for details. +pub mod common; + +/// Configuration for the `query!()` family of macros. +/// +/// See [`macros::Config`] for details. +#[cfg(feature = "config-macros")] +pub mod macros; + +/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. +/// +/// See [`migrate::Config`] for details. +#[cfg(feature = "config-migrate")] +pub mod migrate; + +/// Reference for `sqlx.toml` files +/// +/// Source: `sqlx-core/src/config/reference.toml` +/// +/// ```toml +#[doc = include_str!("reference.toml")] +/// ``` +pub mod _reference {} + +#[cfg(test)] +mod tests; + +/// The parsed structure of a `sqlx.toml` file. +#[derive(Debug, Default, serde::Deserialize)] +pub struct Config { + /// Configuration shared by multiple components. + /// + /// See [`common::Config`] for details. + pub common: common::Config, + + /// Configuration for the `query!()` family of macros. + /// + /// See [`macros::Config`] for details. + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "config-all", feature = "config-macros"))) + )] + #[cfg(feature = "config-macros")] + pub macros: macros::Config, + + /// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. + /// + /// See [`migrate::Config`] for details. 
+ #[cfg_attr( + docsrs, + doc(cfg(any(feature = "config-all", feature = "config-migrate"))) + )] + #[cfg(feature = "config-migrate")] + pub migrate: migrate::Config, +} + +/// Error returned from various methods of [`Config`]. +#[derive(thiserror::Error, Debug)] +pub enum ConfigError { + /// The loading method expected `CARGO_MANIFEST_DIR` to be set and it wasn't. + /// + /// This is necessary to locate the root of the crate currently being compiled. + /// + /// See [the "Environment Variables" page of the Cargo Book][cargo-env] for details. + /// + /// [cargo-env]: https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates + #[error("environment variable `CARGO_MANIFEST_DIR` must be set and valid")] + Env( + #[from] + #[source] + std::env::VarError, + ), + + /// An I/O error occurred while attempting to read the config file at `path`. + /// + /// This includes [`io::ErrorKind::NotFound`]. + /// + /// [`Self::not_found_path()`] will return the path if the file was not found. + #[error("error reading config file {path:?}")] + Read { + path: PathBuf, + #[source] + error: io::Error, + }, + + /// An error in the TOML was encountered while parsing the config file at `path`. + /// + /// The error gives line numbers and context when printed with `Display`/`ToString`. + #[error("error parsing config file {path:?}")] + Parse { + path: PathBuf, + #[source] + error: toml::de::Error, + }, +} + +impl ConfigError { + /// If this error means the file was not found, return the path that was attempted. + pub fn not_found_path(&self) -> Option<&Path> { + match self { + ConfigError::Read { path, error } if error.kind() == io::ErrorKind::NotFound => { + Some(path) + } + _ => None, + } + } +} + +static CACHE: OnceCell = OnceCell::new(); + +/// Internal methods for loading a `Config`. +#[allow(clippy::result_large_err)] +impl Config { + /// Get the cached config, or attempt to read `$CARGO_MANIFEST_DIR/sqlx.toml`. 
+ /// + /// On success, the config is cached in a `static` and returned by future calls. + /// + /// Returns `Config::default()` if the file does not exist. + /// + /// ### Panics + /// If the file exists but an unrecoverable error was encountered while parsing it. + pub fn from_crate() -> &'static Self { + Self::try_from_crate().unwrap_or_else(|e| { + if let Some(path) = e.not_found_path() { + // Non-fatal + tracing::debug!("Not reading config, file {path:?} not found (error: {e})"); + CACHE.get_or_init(Config::default) + } else { + // In the case of migrations, + // we can't proceed with defaults as they may be completely wrong. + panic!("failed to read sqlx config: {e}") + } + }) + } + + /// Get the cached config, or to read `$CARGO_MANIFEST_DIR/sqlx.toml`. + /// + /// On success, the config is cached in a `static` and returned by future calls. + /// + /// Errors if `CARGO_MANIFEST_DIR` is not set, or if the config file could not be read. + pub fn try_from_crate() -> Result<&'static Self, ConfigError> { + Self::try_get_with(|| { + let mut path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR")?); + path.push("sqlx.toml"); + Ok(path) + }) + } + + /// Get the cached config, or attempt to read `sqlx.toml` from the current working directory. + /// + /// On success, the config is cached in a `static` and returned by future calls. + /// + /// Errors if the config file does not exist, or could not be read. + pub fn try_from_current_dir() -> Result<&'static Self, ConfigError> { + Self::try_get_with(|| Ok("sqlx.toml".into())) + } + + /// Get the cached config, or attempt to read it from the path returned by the closure. + /// + /// On success, the config is cached in a `static` and returned by future calls. + /// + /// Errors if the config file does not exist, or could not be read. 
+ pub fn try_get_with( + make_path: impl FnOnce() -> Result, + ) -> Result<&'static Self, ConfigError> { + CACHE.get_or_try_init(|| { + let path = make_path()?; + Self::read_from(path) + }) + } + + fn read_from(path: PathBuf) -> Result { + // The `toml` crate doesn't provide an incremental reader. + let toml_s = match std::fs::read_to_string(&path) { + Ok(toml) => toml, + Err(error) => { + return Err(ConfigError::Read { path, error }); + } + }; + + // TODO: parse and lint TOML structure before deserializing + // Motivation: https://github.com/toml-rs/toml/issues/761 + tracing::debug!("read config TOML from {path:?}:\n{toml_s}"); + + toml::from_str(&toml_s).map_err(|error| ConfigError::Parse { path, error }) + } +} diff --git a/sqlx-core/src/config/reference.toml b/sqlx-core/src/config/reference.toml new file mode 100644 index 0000000000..fae92f3422 --- /dev/null +++ b/sqlx-core/src/config/reference.toml @@ -0,0 +1,175 @@ +# `sqlx.toml` reference. +# +# Note: shown values are *not* defaults. +# They are explicitly set to non-default values to test parsing. +# Refer to the comment for a given option for its default value. + +############################################################################################### + +# Configuration shared by multiple components. +[common] +# Change the environment variable to get the database URL. +# +# This is used by both the macros and `sqlx-cli`. +# +# If not specified, defaults to `DATABASE_URL` +database_url_var = "FOO_DATABASE_URL" + +############################################################################################### + +# Configuration for the `query!()` family of macros. +[macros] +# Force the macros to use the `chrono` crate for date/time types, even if `time` is enabled. +# +# Defaults to "inferred": use whichever crate is enabled (`time` takes precedence over `chrono`). 
+datetime_crate = "chrono" + +# Or, ensure the macros always prefer `time` +# in case new date/time crates are added in the future: +# datetime_crate = "time" + +# Set global overrides for mapping SQL types to Rust types. +# +# Default type mappings are defined by the database driver. +# Refer to the `sqlx::types` module for details. +# +# Postgres users: schema qualification should not be used for types in the search path. +# +# ### Note: Orthogonal to Nullability +# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` +# or not. They only override the inner type used. +[macros.type_overrides] +# Override a built-in type (map all `UUID` columns to `crate::types::MyUuid`) +'uuid' = "crate::types::MyUuid" + +# Support an external or custom wrapper type (e.g. from the `isn` Postgres extension) +# (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING) +'isbn13' = "isn_rs::isbn::ISBN13" + +# SQL type `foo` to Rust type `crate::types::Foo`: +'foo' = "crate::types::Foo" + +# SQL type `"Bar"` to Rust type `crate::types::Bar`; notice the extra pair of quotes: +'"Bar"' = "crate::types::Bar" + +# Will NOT work (the first pair of quotes are parsed by TOML) +# "Bar" = "crate::types::Bar" + +# Schema qualified +'foo.bar' = "crate::types::Bar" + +# Schema qualified and quoted +'foo."Bar"' = "crate::schema::foo::Bar" + +# Quoted schema name +'"Foo".bar' = "crate::schema::foo::Bar" + +# Quoted schema and type name +'"Foo"."Bar"' = "crate::schema::foo::Bar" + +# Set per-column overrides for mapping SQL types to Rust types. +# +# Note: table name is required in the header. +# +# Postgres users: schema qualification should not be used for types in the search path. +# +# ### Note: Orthogonal to Nullability +# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` +# or not. They only override the inner type used. 
+[macros.column_overrides.'foo'] +# Map column `bar` of table `foo` to Rust type `crate::types::Foo`: +'bar' = "crate::types::Bar" + +# Quoted column name +# Note: same quoting requirements as `macros.type_overrides` +'"Bar"' = "crate::types::Bar" + +# Note: will NOT work (parses as `Bar`) +# "Bar" = "crate::types::Bar" + +# Table name may be quoted (note the wrapping single-quotes) +[macros.column_overrides.'"Foo"'] +'bar' = "crate::types::Bar" +'"Bar"' = "crate::types::Bar" + +# Table name may also be schema-qualified. +# Note how the dot is inside the quotes. +[macros.column_overrides.'my_schema.my_table'] +'my_column' = "crate::types::MyType" + +# Quoted schema, table, and column names +[macros.column_overrides.'"My Schema"."My Table"'] +'"My Column"' = "crate::types::MyType" + +############################################################################################### + +# Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. +# +# ### Note +# A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these +# configuration options. We recommend using `sqlx::migrate!()` instead. +# +# ### Warning: Potential Data Loss or Corruption! +# Many of these options, if changed after migrations are set up, +# can result in data loss or corruption of a production database +# if the proper precautions are not taken. +# +# Be sure you know what you are doing and that you read all relevant documentation _thoroughly_. +[migrate] +# Override the name of the table used to track executed migrations. +# +# May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`. +# +# Potentially useful for multi-tenant databases. +# +# ### Warning: Potential Data Loss or Corruption! +# Changing this option for a production database will likely result in data loss or corruption +# as the migration machinery will no longer be aware of what migrations have been applied +# and will attempt to re-run them. 
+# +# You should create the new table as a copy of the existing migrations table (with contents!), +# and be sure all instances of your application have been migrated to the new +# table before deleting the old one. +table_name = "foo._sqlx_migrations" + +# Override the directory used for migrations files. +# +# Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`. +migrations_dir = "foo/migrations" + +# Specify characters that should be ignored when hashing migrations. +# +# Any characters contained in the given set will be dropped when a migration is hashed. +# +# Defaults to an empty array (don't drop any characters). +# +# ### Warning: May Change Hashes for Existing Migrations +# Changing the characters considered in hashing migrations will likely +# change the output of the hash. +# +# This may require manual rectification for deployed databases. +# ignored_chars = [] + +# Ignore Carriage Returns (`<CR>` | `\r`) +# Note that the TOML format requires double-quoted strings to process escapes. +# ignored_chars = ["\r"] + +# Ignore common whitespace characters (beware syntatically significant whitespace!) +ignored_chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF + +# Specify reversible migrations by default (for `sqlx migrate create`). +# +# Defaults to "inferred": uses the type of the last migration, or "simple" otherwise. +default_type = "reversible" + +# Specify simple (non-reversible) migrations by default. +# default_type = "simple" + +# Specify sequential versioning by default (for `sqlx migrate create`). +# +# Defaults to "inferred": guesses the versioning scheme from the latest migrations, +# or "timestamp" otherwise. +default_versioning = "sequential" + +# Specify timestamp versioning by default. 
+# default_versioning = "timestamp" diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs new file mode 100644 index 0000000000..bf042069a2 --- /dev/null +++ b/sqlx-core/src/config/tests.rs @@ -0,0 +1,90 @@ +use crate::config::{self, Config}; +use std::collections::BTreeSet; + +#[test] +fn reference_parses_as_config() { + let config: Config = toml::from_str(include_str!("reference.toml")) + // The `Display` impl of `toml::Error` is *actually* more useful than `Debug` + .unwrap_or_else(|e| panic!("expected reference.toml to parse as Config: {e}")); + + assert_common_config(&config.common); + + #[cfg(feature = "config-macros")] + assert_macros_config(&config.macros); + + #[cfg(feature = "config-migrate")] + assert_migrate_config(&config.migrate); +} + +fn assert_common_config(config: &config::common::Config) { + assert_eq!(config.database_url_var.as_deref(), Some("FOO_DATABASE_URL")); +} + +#[cfg(feature = "config-macros")] +fn assert_macros_config(config: &config::macros::Config) { + use config::macros::*; + + assert_eq!(config.datetime_crate, DateTimeCrate::Chrono); + + // Type overrides + // Don't need to cover everything, just some important canaries. 
+ assert_eq!(config.type_override("foo"), Some("crate::types::Foo")); + + assert_eq!(config.type_override(r#""Bar""#), Some("crate::types::Bar"),); + + assert_eq!( + config.type_override(r#""Foo".bar"#), + Some("crate::schema::foo::Bar"), + ); + + assert_eq!( + config.type_override(r#""Foo"."Bar""#), + Some("crate::schema::foo::Bar"), + ); + + // Column overrides + assert_eq!( + config.column_override("foo", "bar"), + Some("crate::types::Bar"), + ); + + assert_eq!( + config.column_override("foo", r#""Bar""#), + Some("crate::types::Bar"), + ); + + assert_eq!( + config.column_override(r#""Foo""#, "bar"), + Some("crate::types::Bar"), + ); + + assert_eq!( + config.column_override(r#""Foo""#, r#""Bar""#), + Some("crate::types::Bar"), + ); + + assert_eq!( + config.column_override("my_schema.my_table", "my_column"), + Some("crate::types::MyType"), + ); + + assert_eq!( + config.column_override(r#""My Schema"."My Table""#, r#""My Column""#), + Some("crate::types::MyType"), + ); +} + +#[cfg(feature = "config-migrate")] +fn assert_migrate_config(config: &config::migrate::Config) { + use config::migrate::*; + + assert_eq!(config.table_name.as_deref(), Some("foo._sqlx_migrations")); + assert_eq!(config.migrations_dir.as_deref(), Some("foo/migrations")); + + let ignored_chars = BTreeSet::from([' ', '\t', '\r', '\n']); + + assert_eq!(config.ignored_chars, ignored_chars); + + assert_eq!(config.default_type, DefaultMigrationType::Reversible); + assert_eq!(config.default_versioning, DefaultVersioning::Sequential); +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index df4b2cc27d..8b831ecaff 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -91,6 +91,9 @@ pub mod any; #[cfg(feature = "migrate")] pub mod testing; +#[cfg(feature = "config")] +pub mod config; + pub use error::{Error, Result}; pub use either::Either; diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 46786b7d8d..46b7dfbf93 100644 --- a/sqlx-macros-core/Cargo.toml +++ 
b/sqlx-macros-core/Cargo.toml @@ -26,6 +26,10 @@ derive = [] macros = [] migrate = ["sqlx-core/migrate"] +config = ["sqlx-core/config"] +config-macros = ["config", "sqlx-core/config-macros"] +config-migrate = ["config", "sqlx-core/config-migrate"] + # database mysql = ["sqlx-mysql"] postgres = ["sqlx-postgres"] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 5617d3f251..1d1b0bcd48 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -27,6 +27,9 @@ derive = ["sqlx-macros-core/derive"] macros = ["sqlx-macros-core/macros"] migrate = ["sqlx-macros-core/migrate"] +config-macros = ["sqlx-macros-core/config-macros"] +config-migrate = ["sqlx-macros-core/config-migrate"] + # database mysql = ["sqlx-macros-core/mysql"] postgres = ["sqlx-macros-core/postgres"] diff --git a/src/lib.rs b/src/lib.rs index 870fa703c5..19142f6666 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,3 +164,6 @@ pub mod prelude { pub use super::Statement; pub use super::Type; } + +#[cfg(feature = "_unstable-doc")] +pub use sqlx_core::config; From 062a06fc78469ac20c2e6af4bf24e976c0ba02c4 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 9 Sep 2024 00:24:01 -0700 Subject: [PATCH 02/30] feat: add support for ignored_chars config to sqlx_core::migrate --- sqlx-core/src/migrate/migration.rs | 59 +++++++++- sqlx-core/src/migrate/migrator.rs | 22 ++++ sqlx-core/src/migrate/mod.rs | 4 +- sqlx-core/src/migrate/source.rs | 171 +++++++++++++++++++++++++++-- 4 files changed, 239 insertions(+), 17 deletions(-) diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 9bd7f569d8..df7a11d78b 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -1,6 +1,5 @@ -use std::borrow::Cow; - use sha2::{Digest, Sha384}; +use std::borrow::Cow; use super::MigrationType; @@ -22,8 +21,26 @@ impl Migration { sql: Cow<'static, str>, no_tx: bool, ) -> Self { - let checksum = 
Cow::Owned(Vec::from(Sha384::digest(sql.as_bytes()).as_slice())); + let checksum = checksum(&sql); + + Self::with_checksum( + version, + description, + migration_type, + sql, + checksum.into(), + no_tx, + ) + } + pub(crate) fn with_checksum( + version: i64, + description: Cow<'static, str>, + migration_type: MigrationType, + sql: Cow<'static, str>, + checksum: Cow<'static, [u8]>, + no_tx: bool, + ) -> Self { Migration { version, description, @@ -40,3 +57,39 @@ pub struct AppliedMigration { pub version: i64, pub checksum: Cow<'static, [u8]>, } + +pub fn checksum(sql: &str) -> Vec<u8> { + Vec::from(Sha384::digest(sql).as_slice()) +} + +pub fn checksum_fragments<'a>(fragments: impl Iterator<Item = &'a str>) -> Vec<u8> { + let mut digest = Sha384::new(); + + for fragment in fragments { + digest.update(fragment); + } + + digest.finalize().to_vec() +} + +#[test] +fn fragments_checksum_equals_full_checksum() { + // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql` + let sql = "\ + create table comment (\r\n\ + \tcomment_id uuid primary key default gen_random_uuid(),\r\n\ + \tpost_id uuid not null references post(post_id),\r\n\ + \tuser_id uuid not null references \"user\"(user_id),\r\n\ + \tcontent text not null,\r\n\ + \tcreated_at timestamptz not null default now()\r\n\ + );\r\n\ + \r\n\ + create index on comment(post_id, created_at);\r\n\ + "; + + // Should yield a string for each character + let fragments_checksum = checksum_fragments(sql.split("")); + let full_checksum = checksum(sql); + + assert_eq!(fragments_checksum, full_checksum); +} diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 3209ba6e45..42cc3095f8 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -23,6 +23,8 @@ pub struct Migrator { pub locking: bool, #[doc(hidden)] pub no_tx: bool, + #[doc(hidden)] + pub table_name: Cow<'static, str>, } fn validate_applied_migrations( @@ -51,6 +53,7 @@ impl Migrator { ignore_missing: 
false, no_tx: false, locking: true, + table_name: Cow::Borrowed("_sqlx_migrations"), }; /// Creates a new instance with the given source. @@ -81,6 +84,25 @@ impl Migrator { }) } + /// Override the name of the table used to track executed migrations. + /// + /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`. + /// + /// Potentially useful for multi-tenant databases. + /// + /// ### Warning: Potential Data Loss or Corruption! + /// Changing this option for a production database will likely result in data loss or corruption + /// as the migration machinery will no longer be aware of what migrations have been applied + /// and will attempt to re-run them. + /// + /// You should create the new table as a copy of the existing migrations table (with contents!), + /// and be sure all instances of your application have been migrated to the new + /// table before deleting the old one. + pub fn dangerous_set_table_name(&mut self, table_name: impl Into<Cow<'static, str>>) -> &Self { + self.table_name = table_name.into(); + self + } + /// Specify whether applied migrations that are missing from the resolved migrations should be ignored. 
pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self { self.ignore_missing = ignore_missing; diff --git a/sqlx-core/src/migrate/mod.rs b/sqlx-core/src/migrate/mod.rs index f035b8d3c1..39347cf421 100644 --- a/sqlx-core/src/migrate/mod.rs +++ b/sqlx-core/src/migrate/mod.rs @@ -11,7 +11,7 @@ pub use migrate::{Migrate, MigrateDatabase}; pub use migration::{AppliedMigration, Migration}; pub use migration_type::MigrationType; pub use migrator::Migrator; -pub use source::MigrationSource; +pub use source::{MigrationSource, ResolveConfig, ResolveWith}; #[doc(hidden)] -pub use source::resolve_blocking; +pub use source::{resolve_blocking, resolve_blocking_with_config}; diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs index d0c23b43cd..6c3d780bb3 100644 --- a/sqlx-core/src/migrate/source.rs +++ b/sqlx-core/src/migrate/source.rs @@ -1,8 +1,9 @@ use crate::error::BoxDynError; -use crate::migrate::{Migration, MigrationType}; +use crate::migrate::{migration, Migration, MigrationType}; use futures_core::future::BoxFuture; use std::borrow::Cow; +use std::collections::BTreeSet; use std::fmt::Debug; use std::fs; use std::io; @@ -28,19 +29,48 @@ pub trait MigrationSource<'s>: Debug { impl<'s> MigrationSource<'s> for &'s Path { fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> { + // Behavior changed from previous because `canonicalize()` is potentially blocking + // since it might require going to disk to fetch filesystem data. 
+ self.to_owned().resolve() + } +} + +impl MigrationSource<'static> for PathBuf { + fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> { + // Technically this could just be `Box::pin(spawn_blocking(...))` + // but that would actually be a breaking behavior change because it would call + // `spawn_blocking()` on the current thread Box::pin(async move { - let canonical = self.canonicalize()?; - let migrations_with_paths = - crate::rt::spawn_blocking(move || resolve_blocking(&canonical)).await?; + crate::rt::spawn_blocking(move || { + let migrations_with_paths = resolve_blocking(&self)?; - Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect()) + Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect()) + }) + .await }) } } -impl MigrationSource<'static> for PathBuf { - fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> { - Box::pin(async move { self.as_path().resolve().await }) +/// A [`MigrationSource`] implementation with configurable resolution. +/// +/// `S` may be `PathBuf`, `&Path` or any type that implements `Into<PathBuf>`. +/// +/// See [`ResolveConfig`] for details. +#[derive(Debug)] +pub struct ResolveWith<S>(pub S, pub ResolveConfig); + +impl<'s, S: Debug + Into<PathBuf> + Send + 's> MigrationSource<'s> for ResolveWith<S> { + fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> { + Box::pin(async move { + let path = self.0.into(); + let config = self.1; + + let migrations_with_paths = + crate::rt::spawn_blocking(move || resolve_blocking_with_config(&path, &config)) + .await?; + + Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect()) + }) } } @@ -52,11 +82,87 @@ pub struct ResolveError { message: String, source: Option<io::Error>, } +/// Configuration for migration resolution using [`ResolveWith`]. +#[derive(Debug, Default)] +pub struct ResolveConfig { + ignored_chars: BTreeSet<char>, +} + +impl ResolveConfig { + /// Return a default, empty configuration. 
+ pub fn new() -> Self { + ResolveConfig { + ignored_chars: BTreeSet::new(), + } + } + + /// Ignore a character when hashing migrations. + /// + /// The migration SQL string itself will still contain the character, + /// but it will not be included when calculating the checksum. + /// + /// This can be used to ignore whitespace characters so changing formatting + /// does not change the checksum. + /// + /// Adding the same `char` more than once is a no-op. + /// + /// ### Note: Changes Migration Checksum + /// This will change the checksum of resolved migrations, + /// which may cause problems with existing deployments. + /// + /// **Use at your own risk.** + pub fn ignore_char(&mut self, c: char) -> &mut Self { + self.ignored_chars.insert(c); + self + } + + /// Ignore one or more characters when hashing migrations. + /// + /// The migration SQL string itself will still contain these characters, + /// but they will not be included when calculating the checksum. + /// + /// This can be used to ignore whitespace characters so changing formatting + /// does not change the checksum. + /// + /// Adding the same `char` more than once is a no-op. + /// + /// ### Note: Changes Migration Checksum + /// This will change the checksum of resolved migrations, + /// which may cause problems with existing deployments. + /// + /// **Use at your own risk.** + pub fn ignore_chars(&mut self, chars: impl IntoIterator<Item = char>) -> &mut Self { + self.ignored_chars.extend(chars); + self + } + + /// Iterate over the set of ignored characters. + /// + /// Duplicate `char`s are not included. + pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ { + self.ignored_chars.iter().copied() + } +} + // FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly // since it's `#[non_exhaustive]`. 
+#[doc(hidden)] pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> { - let s = fs::read_dir(path).map_err(|e| ResolveError { - message: format!("error reading migration directory {}: {e}", path.display()), + resolve_blocking_with_config(path, &ResolveConfig::new()) +} + +#[doc(hidden)] +pub fn resolve_blocking_with_config( + path: &Path, + config: &ResolveConfig, +) -> Result<Vec<(Migration, PathBuf)>, ResolveError> { + let path = path.canonicalize().map_err(|e| ResolveError { + message: format!("error canonicalizing path {}", path.display()), + source: Some(e), + })?; + + let s = fs::read_dir(&path).map_err(|e| ResolveError { + message: format!("error reading migration directory {}", path.display()), source: Some(e), })?; @@ -65,7 +171,7 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv for res in s { let entry = res.map_err(|e| ResolveError { message: format!( - "error reading contents of migration directory {}: {e}", + "error reading contents of migration directory {}", path.display() ), source: Some(e), @@ -126,12 +232,15 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv // opt-out of migration transaction let no_tx = sql.starts_with("-- no-transaction"); + let checksum = checksum_with(&sql, &config.ignored_chars); + migrations.push(( - Migration::new( + Migration::with_checksum( version, Cow::Owned(description), migration_type, Cow::Owned(sql), + checksum.into(), no_tx, ), entry_path, @@ -143,3 +252,41 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv Ok(migrations) } + +fn checksum_with(sql: &str, ignored_chars: &BTreeSet<char>) -> Vec<u8> { + if ignored_chars.is_empty() { + // This is going to be much faster because it doesn't have to UTF-8 decode `sql`. 
+ return migration::checksum(sql); + } + + migration::checksum_fragments(sql.split(|c| ignored_chars.contains(&c))) +} + +#[test] +fn checksum_with_ignored_chars() { + // Ensure that `checksum_with` returns the same digest for a given set of ignored chars + // as the equivalent string with the characters removed. + let ignored_chars = [' ', '\t', '\r', '\n']; + + // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql` + let sql = "\ + create table comment (\r\n\ + \tcomment_id uuid primary key default gen_random_uuid(),\r\n\ + \tpost_id uuid not null references post(post_id),\r\n\ + \tuser_id uuid not null references \"user\"(user_id),\r\n\ + \tcontent text not null,\r\n\ + \tcreated_at timestamptz not null default now()\r\n\ + );\r\n\ + \r\n\ + create index on comment(post_id, created_at);\r\n\ + "; + + let stripped_sql = sql.replace(&ignored_chars[..], ""); + + let ignored_chars = BTreeSet::from(ignored_chars); + + let digest_ignored = checksum_with(sql, &ignored_chars); + let digest_stripped = migration::checksum(&stripped_sql); + + assert_eq!(digest_ignored, digest_stripped); +} From 9f34fc8dd21c16b108d52b50c7da77bb35336ecf Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 9 Sep 2024 00:49:20 -0700 Subject: [PATCH 03/30] chore: test ignored_chars with `U+FEFF` (ZWNBSP/BOM) https://en.wikipedia.org/wiki/Byte_order_mark --- sqlx-core/src/config/reference.toml | 6 ++++- sqlx-core/src/config/tests.rs | 2 +- sqlx-core/src/migrate/migration.rs | 2 +- sqlx-core/src/migrate/source.rs | 34 +++++++++++++++++------------ 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sqlx-core/src/config/reference.toml b/sqlx-core/src/config/reference.toml index fae92f3422..6d52f615eb 100644 --- a/sqlx-core/src/config/reference.toml +++ b/sqlx-core/src/config/reference.toml @@ -155,7 +155,11 @@ migrations_dir = "foo/migrations" # ignored_chars = ["\r"] # Ignore common whitespace characters (beware syntatically significant whitespace!) 
-ignored_chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF +# Space, tab, CR, LF, zero-width non-breaking space (U+FEFF) +# +# U+FEFF is added by some editors as a magic number at the beginning of a text file indicating it is UTF-8 encoded, +# where it is known as a byte-order mark (BOM): https://en.wikipedia.org/wiki/Byte_order_mark +ignored_chars = [" ", "\t", "\r", "\n", "\uFEFF"] # Specify reversible migrations by default (for `sqlx migrate create`). # diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs index bf042069a2..521e7074b3 100644 --- a/sqlx-core/src/config/tests.rs +++ b/sqlx-core/src/config/tests.rs @@ -81,7 +81,7 @@ fn assert_migrate_config(config: &config::migrate::Config) { assert_eq!(config.table_name.as_deref(), Some("foo._sqlx_migrations")); assert_eq!(config.migrations_dir.as_deref(), Some("foo/migrations")); - let ignored_chars = BTreeSet::from([' ', '\t', '\r', '\n']); + let ignored_chars = BTreeSet::from([' ', '\t', '\r', '\n', '\u{FEFF}']); assert_eq!(config.ignored_chars, ignored_chars); diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index df7a11d78b..1f1175ce58 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -76,7 +76,7 @@ pub fn checksum_fragments<'a>(fragments: impl Iterator) -> Vec for PathBuf { } /// A [`MigrationSource`] implementation with configurable resolution. -/// +/// /// `S` may be `PathBuf`, `&Path` or any type that implements `Into`. -/// +/// /// See [`ResolveConfig`] for details. #[derive(Debug)] pub struct ResolveWith(pub S, pub ResolveConfig); @@ -97,20 +97,20 @@ impl ResolveConfig { } /// Ignore a character when hashing migrations. - /// + /// /// The migration SQL string itself will still contain the character, /// but it will not be included when calculating the checksum. - /// + /// /// This can be used to ignore whitespace characters so changing formatting /// does not change the checksum. 
- /// + /// /// Adding the same `char` more than once is a no-op. - /// + /// /// ### Note: Changes Migration Checksum - /// This will change the checksum of resolved migrations, + /// This will change the checksum of resolved migrations, /// which may cause problems with existing deployments. /// - /// **Use at your own risk.** + /// **Use at your own risk.** pub fn ignore_char(&mut self, c: char) -> &mut Self { self.ignored_chars.insert(c); self @@ -123,21 +123,21 @@ impl ResolveConfig { /// /// This can be used to ignore whitespace characters so changing formatting /// does not change the checksum. - /// + /// /// Adding the same `char` more than once is a no-op. /// /// ### Note: Changes Migration Checksum - /// This will change the checksum of resolved migrations, + /// This will change the checksum of resolved migrations, /// which may cause problems with existing deployments. /// - /// **Use at your own risk.** + /// **Use at your own risk.** pub fn ignore_chars(&mut self, chars: impl IntoIterator) -> &mut Self { self.ignored_chars.extend(chars); self } /// Iterate over the set of ignored characters. - /// + /// /// Duplicate `char`s are not included. pub fn ignored_chars(&self) -> impl Iterator + '_ { self.ignored_chars.iter().copied() @@ -266,11 +266,17 @@ fn checksum_with(sql: &str, ignored_chars: &BTreeSet) -> Vec { fn checksum_with_ignored_chars() { // Ensure that `checksum_with` returns the same digest for a given set of ignored chars // as the equivalent string with the characters removed. 
- let ignored_chars = [' ', '\t', '\r', '\n']; + let ignored_chars = [ + ' ', '\t', '\r', '\n', + // Zero-width non-breaking space (ZWNBSP), often added as a magic-number at the beginning + // of UTF-8 encoded files as a byte-order mark (BOM): + // https://en.wikipedia.org/wiki/Byte_order_mark + '\u{FEFF}', + ]; // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql` let sql = "\ - create table comment (\r\n\ + \u{FEFF}create table comment (\r\n\ \tcomment_id uuid primary key default gen_random_uuid(),\r\n\ \tpost_id uuid not null references post(post_id),\r\n\ \tuser_id uuid not null references \"user\"(user_id),\r\n\ From e775d2a3eb1e6a70ddc78d12e49b0854edf1dd99 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 18 Sep 2024 01:54:22 -0700 Subject: [PATCH 04/30] refactor: make `Config` always compiled simplifies usage while still making parsing optional for less generated code --- Cargo.toml | 10 ++-- sqlx-cli/Cargo.toml | 5 +- sqlx-core/Cargo.toml | 12 ++-- sqlx-core/src/config/common.rs | 9 ++- sqlx-core/src/config/macros.rs | 12 ++-- sqlx-core/src/config/migrate.rs | 20 +++++-- sqlx-core/src/config/mod.rs | 95 ++++++++++++++++++++----------- sqlx-core/src/config/tests.rs | 2 - sqlx-core/src/lib.rs | 1 - sqlx-macros-core/Cargo.toml | 4 +- sqlx-macros-core/src/query/mod.rs | 9 ++- sqlx-macros/Cargo.toml | 3 +- 12 files changed, 117 insertions(+), 65 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 72a9d01c28..cf3352cedd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,16 +54,14 @@ features = ["all-databases", "_unstable-all-types", "_unstable-doc"] rustdoc-args = ["--cfg", "docsrs"] [features] -default = ["any", "macros", "migrate", "json", "config-all"] +default = ["any", "macros", "migrate", "json", "sqlx-toml"] derive = ["sqlx-macros/derive"] macros = ["derive", "sqlx-macros/macros"] migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] -# Enable 
parsing of `sqlx.toml` for configuring macros, migrations, or both. -config-macros = ["sqlx-macros?/config-macros"] -config-migrate = ["sqlx-macros?/config-migrate"] -config-all = ["config-macros", "config-migrate"] +# Enable parsing of `sqlx.toml` for configuring macros and migrations. +sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-macros?/sqlx-toml"] # intended mainly for CI and docs all-databases = ["mysql", "sqlite", "postgres", "any"] @@ -79,7 +77,7 @@ _unstable-all-types = [ "bit-vec", ] # Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`). -_unstable-doc = ["config-all", "sqlx-core/_unstable-doc"] +_unstable-doc = [] # Base runtime features without TLS runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index 0b047ab136..4ece226338 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -49,7 +49,8 @@ filetime = "0.2" backoff = { version = "0.4.0", features = ["futures", "tokio"] } [features] -default = ["postgres", "sqlite", "mysql", "native-tls", "completions"] +default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"] + rustls = ["sqlx/runtime-tokio-rustls"] native-tls = ["sqlx/runtime-tokio-native-tls"] @@ -64,6 +65,8 @@ openssl-vendored = ["openssl/vendored"] completions = ["dep:clap_complete"] +sqlx-toml = ["sqlx/sqlx-toml"] + [dev-dependencies] assert_cmd = "2.0.11" tempfile = "3.10.1" diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index f70adde55e..ee6e344efa 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -12,7 +12,7 @@ features = ["offline"] [features] default = [] -migrate = ["sha2", "crc", "config-migrate"] +migrate = ["sha2", "crc"] any = [] @@ -31,11 +31,13 @@ _tls-none = [] # support offline/decoupled building (enables serialization of `Describe`) offline = ["serde", "either/serde"] -config = ["serde", "toml/parse"] -config-macros = ["config"] -config-migrate = 
["config"] +# Enable parsing of `sqlx.toml`. +# For simplicity, the `config` module is always enabled, +# but disabling this disables the `serde` derives and the `toml` crate, +# which is a good bit less code to compile if the feature isn't being used. +sqlx-toml = ["serde", "toml/parse"] -_unstable-doc = ["config-macros", "config-migrate"] +_unstable-doc = ["sqlx-toml"] [dependencies] # Runtimes diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs index 8c774fc60f..1468f24abd 100644 --- a/sqlx-core/src/config/common.rs +++ b/sqlx-core/src/config/common.rs @@ -1,5 +1,6 @@ /// Configuration shared by multiple components. -#[derive(Debug, Default, serde::Deserialize)] +#[derive(Debug, Default)] +#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize))] pub struct Config { /// Override the database URL environment variable. /// @@ -36,3 +37,9 @@ pub struct Config { /// and the ones used in `bar` will use `BAR_DATABASE_URL`. pub database_url_var: Option, } + +impl Config { + pub fn database_url_var(&self) -> &str { + self.database_url_var.as_deref().unwrap_or("DATABASE_URL") + } +} \ No newline at end of file diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs index 5edd30dc15..142f059da4 100644 --- a/sqlx-core/src/config/macros.rs +++ b/sqlx-core/src/config/macros.rs @@ -1,8 +1,8 @@ use std::collections::BTreeMap; /// Configuration for the `query!()` family of macros. -#[derive(Debug, Default, serde::Deserialize)] -#[serde(default)] +#[derive(Debug, Default)] +#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize), serde(default))] pub struct Config { /// Specify the crate to use for mapping date/time types to Rust. /// @@ -235,8 +235,12 @@ pub struct Config { } /// The crate to use for mapping date/time types to Rust. 
-#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Default, PartialEq, Eq)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(rename_all = "snake_case") +)] pub enum DateTimeCrate { /// Use whichever crate is enabled (`time` then `chrono`). #[default] diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index 5878f9a24f..efc03a0155 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -12,8 +12,8 @@ use std::collections::BTreeSet; /// if the proper precautions are not taken. /// /// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_. -#[derive(Debug, Default, serde::Deserialize)] -#[serde(default)] +#[derive(Debug, Default)] +#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize), serde(default))] pub struct Config { /// Override the name of the table used to track executed migrations. /// @@ -118,8 +118,12 @@ pub struct Config { } /// The default type of migration that `sqlx migrate create` should create by default. -#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Default, PartialEq, Eq)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(rename_all = "snake_case") +)] pub enum DefaultMigrationType { /// Create the same migration type as that of the latest existing migration, /// or `Simple` otherwise. @@ -134,8 +138,12 @@ pub enum DefaultMigrationType { } /// The default scheme that `sqlx migrate create` should use for version integers. 
-#[derive(Debug, Default, PartialEq, Eq, serde::Deserialize)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Default, PartialEq, Eq)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(rename_all = "snake_case") +)] pub enum DefaultVersioning { /// Infer the versioning scheme from existing migrations: /// diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index 979477241f..3bbde5c2f1 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -7,6 +7,7 @@ //! //! See the [reference][`_reference`] for the full `sqlx.toml` file. +use std::error::Error; use std::fmt::Debug; use std::io; use std::path::{Path, PathBuf}; @@ -23,13 +24,11 @@ pub mod common; /// Configuration for the `query!()` family of macros. /// /// See [`macros::Config`] for details. -#[cfg(feature = "config-macros")] pub mod macros; /// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. /// /// See [`migrate::Config`] for details. -#[cfg(feature = "config-migrate")] pub mod migrate; /// Reference for `sqlx.toml` files @@ -41,11 +40,12 @@ pub mod migrate; /// ``` pub mod _reference {} -#[cfg(test)] +#[cfg(all(test, feature = "sqlx-toml"))] mod tests; /// The parsed structure of a `sqlx.toml` file. -#[derive(Debug, Default, serde::Deserialize)] +#[derive(Debug, Default)] +#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize))] pub struct Config { /// Configuration shared by multiple components. /// @@ -55,21 +55,11 @@ pub struct Config { /// Configuration for the `query!()` family of macros. /// /// See [`macros::Config`] for details. - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "config-all", feature = "config-macros"))) - )] - #[cfg(feature = "config-macros")] pub macros: macros::Config, /// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`. /// /// See [`migrate::Config`] for details. 
- #[cfg_attr( - docsrs, - doc(cfg(any(feature = "config-all", feature = "config-migrate"))) - )] - #[cfg(feature = "config-migrate")] pub migrate: migrate::Config, } @@ -90,13 +80,17 @@ pub enum ConfigError { std::env::VarError, ), + /// No configuration file was found. Not necessarily fatal. + #[error("config file {path:?} not found")] + NotFound { + path: PathBuf, + }, + /// An I/O error occurred while attempting to read the config file at `path`. + /// - /// This includes [`io::ErrorKind::NotFound`]. - /// - /// [`Self::not_found_path()`] will return the path if the file was not found. + /// If the error is [`io::ErrorKind::NotFound`], [`Self::NotFound`] is returned instead. #[error("error reading config file {path:?}")] - Read { + Io { path: PathBuf, #[source] error: io::Error, @@ -105,22 +99,41 @@ pub enum ConfigError { /// An error in the TOML was encountered while parsing the config file at `path`. /// /// The error gives line numbers and context when printed with `Display`/`ToString`. + /// + /// Only returned if the `sqlx-toml` feature is enabled. #[error("error parsing config file {path:?}")] Parse { path: PathBuf, + /// Type-erased [`toml::de::Error`]. #[source] - error: toml::de::Error, + error: Box<dyn Error + Send + Sync + 'static>, + }, + + /// A `sqlx.toml` file was found or specified, but the `sqlx-toml` feature is not enabled. + #[error("SQLx found config file at {path:?} but the `sqlx-toml` feature was not enabled")] + ParseDisabled { + path: PathBuf }, } impl ConfigError { + /// Create a [`ConfigError`] from a [`std::io::Error`]. + /// + /// Maps to either `NotFound` or `Io`. + pub fn from_io(path: PathBuf, error: io::Error) -> Self { + if error.kind() == io::ErrorKind::NotFound { + Self::NotFound { path } + } else { + Self::Io { path, error } + } + } + /// If this error means the file was not found, return the path that was attempted. 
pub fn not_found_path(&self) -> Option<&Path> { - match self { - ConfigError::Read { path, error } if error.kind() == io::ErrorKind::NotFound => { - Some(path) - } - _ => None, + if let Self::NotFound { path } = self { + Some(path) + } else { + None } } } @@ -140,14 +153,22 @@ impl Config { /// If the file exists but an unrecoverable error was encountered while parsing it. pub fn from_crate() -> &'static Self { Self::try_from_crate().unwrap_or_else(|e| { - if let Some(path) = e.not_found_path() { - // Non-fatal - tracing::debug!("Not reading config, file {path:?} not found (error: {e})"); - CACHE.get_or_init(Config::default) - } else { + match e { + ConfigError::NotFound { path } => { + // Non-fatal + tracing::debug!("Not reading config, file {path:?} not found"); + CACHE.get_or_init(Config::default) + } + // FATAL ERRORS BELOW: // In the case of migrations, // we can't proceed with defaults as they may be completely wrong. - panic!("failed to read sqlx config: {e}") + e @ ConfigError::ParseDisabled { .. } => { + // Only returned if the file exists but the feature is not enabled. + panic!("{e}") + } + e => { + panic!("failed to read sqlx config: {e}") + } } }) } @@ -188,12 +209,13 @@ impl Config { }) } + #[cfg(feature = "sqlx-toml")] fn read_from(path: PathBuf) -> Result<Self, ConfigError> { // The `toml` crate doesn't provide an incremental reader. 
let toml_s = match std::fs::read_to_string(&path) { Ok(toml) => toml, Err(error) => { - return Err(ConfigError::Read { path, error }); + return Err(ConfigError::from_io(path, error)); } }; @@ -201,6 +223,15 @@ impl Config { // Motivation: https://github.com/toml-rs/toml/issues/761 tracing::debug!("read config TOML from {path:?}:\n{toml_s}"); - toml::from_str(&toml_s).map_err(|error| ConfigError::Parse { path, error }) + toml::from_str(&toml_s).map_err(|error| ConfigError::Parse { path, error: Box::new(error) }) + } + + #[cfg(not(feature = "sqlx-toml"))] + fn read_from(path: PathBuf) -> Result { + match path.try_exists() { + Ok(true) => Err(ConfigError::ParseDisabled { path }), + Ok(false) => Err(ConfigError::NotFound { path }), + Err(e) => Err(ConfigError::from_io(path, e)) + } } } diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs index 521e7074b3..e5033bb459 100644 --- a/sqlx-core/src/config/tests.rs +++ b/sqlx-core/src/config/tests.rs @@ -20,7 +20,6 @@ fn assert_common_config(config: &config::common::Config) { assert_eq!(config.database_url_var.as_deref(), Some("FOO_DATABASE_URL")); } -#[cfg(feature = "config-macros")] fn assert_macros_config(config: &config::macros::Config) { use config::macros::*; @@ -74,7 +73,6 @@ fn assert_macros_config(config: &config::macros::Config) { ); } -#[cfg(feature = "config-migrate")] fn assert_migrate_config(config: &config::migrate::Config) { use config::migrate::*; diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 8b831ecaff..09f2900ba8 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -91,7 +91,6 @@ pub mod any; #[cfg(feature = "migrate")] pub mod testing; -#[cfg(feature = "config")] pub mod config; pub use error::{Error, Result}; diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 46b7dfbf93..ad1a8e18ed 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -26,9 +26,7 @@ derive = [] macros = [] migrate = ["sqlx-core/migrate"] 
-config = ["sqlx-core/config"] -config-macros = ["config", "sqlx-core/config-macros"] -config-migrate = ["config", "sqlx-core/config-migrate"] +sqlx-toml = ["sqlx-core/sqlx-toml"] # database mysql = ["sqlx-mysql"] diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs index 09acff9bd2..190d272d14 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -16,6 +16,7 @@ use crate::query::data::{hash_string, DynQueryData, QueryData}; use crate::query::input::RecordType; use either::Either; use url::Url; +use sqlx_core::config::Config; mod args; mod data; @@ -138,8 +139,12 @@ static METADATA: Lazy = Lazy::new(|| { let offline = env("SQLX_OFFLINE") .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - - let database_url = env("DATABASE_URL").ok(); + + let var_name = Config::from_crate() + .common + .database_url_var(); + + let database_url = env(var_name).ok(); Metadata { manifest_dir, diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 1d1b0bcd48..6792af6ecc 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -27,8 +27,7 @@ derive = ["sqlx-macros-core/derive"] macros = ["sqlx-macros-core/macros"] migrate = ["sqlx-macros-core/migrate"] -config-macros = ["sqlx-macros-core/config-macros"] -config-migrate = ["sqlx-macros-core/config-migrate"] +sqlx-toml = ["sqlx-macros-core/sqlx-toml"] # database mysql = ["sqlx-macros-core/mysql"] From bf90a477a1328c654944783b80c30c41c4b04ecf Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 18 Sep 2024 01:55:59 -0700 Subject: [PATCH 05/30] refactor: add origin information to `Column` --- sqlx-core/src/column.rs | 54 ++++++++++++++++++++ sqlx-mysql/src/column.rs | 7 +++ sqlx-mysql/src/connection/executor.rs | 21 ++++++++ sqlx-mysql/src/protocol/text/column.rs | 16 ++++-- sqlx-postgres/src/column.rs | 9 ++++ sqlx-postgres/src/connection/describe.rs | 59 +++++++++++++++++++++ sqlx-postgres/src/connection/mod.rs | 8 +++ 
sqlx-sqlite/src/column.rs | 7 +++ sqlx-sqlite/src/connection/describe.rs | 3 ++ sqlx-sqlite/src/statement/handle.rs | 65 +++++++++++++++++++++++- src/lib.rs | 1 + 11 files changed, 243 insertions(+), 7 deletions(-) diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819ed6..7483375765 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -2,6 +2,7 @@ use crate::database::Database; use crate::error::Error; use std::fmt::Debug; +use std::sync::Arc; pub trait Column: 'static + Send + Sync + Debug { type Database: Database; @@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug { /// Gets the type information for the column. fn type_info(&self) -> &::TypeInfo; + + /// If this column comes from a table, return the table and original column name. + /// + /// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression + /// or else the source table could not be determined. + /// + /// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information, + /// or has not overridden this method. + // This method returns an owned value instead of a reference, + // to give the implementor more flexibility. + fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown } +} + +/// A [`Column`] that originates from a table. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct TableColumn { + /// The name of the table (optionally schema-qualified) that the column comes from. + pub table: Arc, + /// The original name of the column. + pub name: Arc, +} + +/// The possible statuses for our knowledge of the origin of a [`Column`]. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub enum ColumnOrigin { + /// The column is known to originate from a table. + /// + /// Included is the table name and original column name. 
+ Table(TableColumn), + /// The column originates from an expression, or else its origin could not be determined. + Expression, + /// The database driver does not know the column origin at this time. + /// + /// This may happen if: + /// * The connection is in the middle of executing a query, + /// and cannot query the catalog to fetch this information. + /// * The connection does not have access to the database catalog. + /// * The implementation of [`Column`] did not override [`Column::origin()`]. + #[default] + Unknown, +} + +impl ColumnOrigin { + /// Returns the true column origin, if known. + pub fn table_column(&self) -> Option<&TableColumn> { + if let Self::Table(table_column) = self { + Some(table_column) + } else { + None + } + } } /// A type that can be used to index into a [`Row`] or [`Statement`]. diff --git a/sqlx-mysql/src/column.rs b/sqlx-mysql/src/column.rs index 1bb841b9a1..457cf991d3 100644 --- a/sqlx-mysql/src/column.rs +++ b/sqlx-mysql/src/column.rs @@ -10,6 +10,9 @@ pub struct MySqlColumn { pub(crate) name: UStr, pub(crate) type_info: MySqlTypeInfo, + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) flags: Option, } @@ -28,4 +31,8 @@ impl Column for MySqlColumn { fn type_info(&self) -> &MySqlTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 07c7979b08..6baad5ccab 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -23,6 +23,7 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::{borrow::Cow, sync::Arc}; +use sqlx_core::column::{ColumnOrigin, TableColumn}; impl MySqlConnection { async fn prepare_statement<'c>( @@ -382,11 +383,30 @@ async fn recv_result_columns( fn recv_next_result_column(def: 
&ColumnDefinition, ordinal: usize) -> Result { // if the alias is empty, use the alias // only then use the name + let column_name = def.name()?; + let name = match (def.name()?, def.alias()?) { (_, alias) if !alias.is_empty() => UStr::new(alias), (name, _) => UStr::new(name), }; + let table = def.table()?; + + let origin = if table.is_empty() { + ColumnOrigin::Expression + } else { + let schema = def.schema()?; + + ColumnOrigin::Table(TableColumn { + table: if !schema.is_empty() { + format!("{schema}.{table}").into() + } else { + table.into() + }, + name: column_name.into(), + }) + }; + let type_info = MySqlTypeInfo::from_column(def); Ok(MySqlColumn { @@ -394,6 +414,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result Result<&str, Error> { + str::from_utf8(&self.schema).map_err(Error::protocol) + } + + pub(crate) fn table(&self) -> Result<&str, Error> { + str::from_utf8(&self.table).map_err(Error::protocol) + } + pub(crate) fn name(&self) -> Result<&str, Error> { - from_utf8(&self.name).map_err(Error::protocol) + str::from_utf8(&self.name).map_err(Error::protocol) } pub(crate) fn alias(&self) -> Result<&str, Error> { - from_utf8(&self.alias).map_err(Error::protocol) + str::from_utf8(&self.alias).map_err(Error::protocol) } } diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index a838c27b75..bd08e27db0 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -2,6 +2,7 @@ use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; pub(crate) use sqlx_core::column::{Column, ColumnIndex}; +use sqlx_core::column::ColumnOrigin; #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] @@ -9,6 +10,10 @@ pub struct PgColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: PgTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) 
relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] @@ -51,4 +56,8 @@ impl Column for PgColumn { fn type_info(&self) -> &PgTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c56c..53affe5dc3 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,3 +1,4 @@ +use std::collections::btree_map; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; @@ -13,6 +14,9 @@ use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; +use sqlx_core::column::{ColumnOrigin, TableColumn}; +use sqlx_core::hash_map; +use crate::connection::TableColumns; /// Describes the type of the `pg_type.typtype` column /// @@ -121,6 +125,12 @@ impl PgConnection { let type_info = self .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) .await?; + + let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) { + self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await? 
+ } else { + ColumnOrigin::Expression + }; let column = PgColumn { ordinal: index, @@ -128,6 +138,7 @@ impl PgConnection { type_info, relation_id: field.relation_id, relation_attribute_no: field.relation_attribute_no, + origin, }; columns.push(column); @@ -189,6 +200,54 @@ impl PgConnection { Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) } } + + async fn maybe_fetch_column_origin( + &mut self, + relation_id: Oid, + attribute_no: i16, + should_fetch: bool, + ) -> Result { + let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) { + hash_map::Entry::Occupied(table_columns) => { + table_columns.into_mut() + }, + hash_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let table_name: String = query_scalar("SELECT $1::oid::regclass::text") + .bind(relation_id) + .fetch_one(&mut *self) + .await?; + + vacant.insert(TableColumns { + table_name: table_name.into(), + columns: Default::default(), + }) + } + }; + + let column_name = match table_columns.columns.entry(attribute_no) { + btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()), + btree_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let column_name: String = query_scalar( + "SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2" + ) + .bind(relation_id) + .bind(attribute_no) + .fetch_one(&mut *self) + .await?; + + Arc::clone(vacant.insert(column_name.into())) + } + }; + + Ok(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name + })) + } async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result { let (name, typ_type, category, relation_id, element, base_type): ( diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e53d..3cb9ecaf62 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::fmt::{self, 
Debug, Formatter}; use std::sync::Arc; @@ -61,6 +62,7 @@ pub struct PgConnectionInner { cache_type_info: HashMap, cache_type_oid: HashMap, cache_elem_type_to_array: HashMap, + cache_table_to_column_names: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, @@ -72,6 +74,12 @@ pub struct PgConnectionInner { log_settings: LogSettings, } +pub(crate) struct TableColumns { + table_name: Arc, + /// Attribute number -> name. + columns: BTreeMap>, +} + impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { diff --git a/sqlx-sqlite/src/column.rs b/sqlx-sqlite/src/column.rs index 00b3bc360c..390f3687fb 100644 --- a/sqlx-sqlite/src/column.rs +++ b/sqlx-sqlite/src/column.rs @@ -9,6 +9,9 @@ pub struct SqliteColumn { pub(crate) name: UStr, pub(crate) ordinal: usize, pub(crate) type_info: SqliteTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin } impl Column for SqliteColumn { @@ -25,4 +28,8 @@ impl Column for SqliteColumn { fn type_info(&self) -> &SqliteTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 0f4da33ccc..9ba9f8c3b1 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result Result); unsafe impl Send for StatementHandle {} +// Most of the getters below allocate internally, and unsynchronized access is undefined. +// unsafe impl !Sync for StatementHandle {} + macro_rules! 
expect_ret_valid { ($fn_name:ident($($args:tt)*)) => {{ let val = $fn_name($($args)*); @@ -110,6 +113,64 @@ impl StatementHandle { } } + pub(crate) fn column_origin(&self, index: usize) -> ColumnOrigin { + if let Some((table, name)) = + self.column_table_name(index).zip(self.column_origin_name(index)) + { + let table: Arc = self + .column_db_name(index) + .filter(|&db| db != "main") + .map_or_else( + || table.into(), + // TODO: check that SQLite returns the names properly quoted if necessary + |db| format!("{db}.{table}").into(), + ); + + ColumnOrigin::Table(TableColumn { + table, + name: name.into() + }) + } else { + ColumnOrigin::Expression + } + } + + fn column_db_name(&self, index: usize) -> Option<&str> { + unsafe { + let db_name = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index)); + + if !db_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(db_name).to_bytes())) + } else { + None + } + } + } + + fn column_table_name(&self, index: usize) -> Option<&str> { + unsafe { + let table_name = sqlite3_column_table_name(self.0.as_ptr(), check_col_idx!(index)); + + if !table_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(table_name).to_bytes())) + } else { + None + } + } + } + + fn column_origin_name(&self, index: usize) -> Option<&str> { + unsafe { + let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), check_col_idx!(index)); + + if !origin_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(origin_name).to_bytes())) + } else { + None + } + } + } + pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo { SqliteTypeInfo(DataType::from_code(self.column_type(index))) } diff --git a/src/lib.rs b/src/lib.rs index 19142f6666..a357753b96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; +pub use sqlx_core::column::ColumnOrigin; pub 
use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe; From 5cb3de38b9706091fea0b3724a892f7bbf3a4c2d Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 18 Sep 2024 18:17:43 -0700 Subject: [PATCH 06/30] feat(macros): implement `type_override` and `column_override` from `sqlx.toml` --- sqlx-macros-core/src/query/output.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/sqlx-macros-core/src/query/output.rs b/sqlx-macros-core/src/query/output.rs index 5e7cc5058d..d9dc79a366 100644 --- a/sqlx-macros-core/src/query/output.rs +++ b/sqlx-macros-core/src/query/output.rs @@ -2,7 +2,7 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::Type; -use sqlx_core::column::Column; +use sqlx_core::column::{Column, ColumnOrigin}; use sqlx_core::describe::Describe; use crate::database::DatabaseExt; @@ -12,6 +12,8 @@ use sqlx_core::type_checking::TypeChecking; use std::fmt::{self, Display, Formatter}; use syn::parse::{Parse, ParseStream}; use syn::Token; +use sqlx_core::config::Config; +use sqlx_core::type_info::TypeInfo; pub struct RustColumn { pub(super) ident: Ident, @@ -229,8 +231,24 @@ pub fn quote_query_scalar( } fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { + if let ColumnOrigin::Table(origin) = column.origin() { + if let Some(column_override) = Config::from_crate() + .macros + .column_override(&origin.table, &origin.name) + { + return column_override.parse().unwrap(); + } + } + let type_info = column.type_info(); + if let Some(type_override) = Config::from_crate() + .macros + .type_override(type_info.name()) + { + return type_override.parse().unwrap(); + } + ::return_type_for_id(type_info).map_or_else( || { let message = From 8604b51ae3fcf46eb86e461c13428c2e3a7bebe7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 19 Sep 2024 19:23:03 -0700 Subject: [PATCH 07/30] 
refactor(sqlx.toml): make all keys kebab-case, create `macros.preferred-crates` --- sqlx-core/src/config/common.rs | 14 ++- sqlx-core/src/config/macros.rs | 179 +++++++++++++++++++++------- sqlx-core/src/config/migrate.rs | 47 +++++--- sqlx-core/src/config/mod.rs | 6 +- sqlx-core/src/config/reference.toml | 42 ++++--- sqlx-core/src/config/tests.rs | 10 +- 6 files changed, 206 insertions(+), 92 deletions(-) diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs index 1468f24abd..c09ed80d7f 100644 --- a/sqlx-core/src/config/common.rs +++ b/sqlx-core/src/config/common.rs @@ -1,6 +1,10 @@ /// Configuration shared by multiple components. #[derive(Debug, Default)] -#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize))] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] pub struct Config { /// Override the database URL environment variable. /// @@ -17,14 +21,14 @@ pub struct Config { /// /// #### `foo/sqlx.toml` /// ```toml - /// [macros] - /// database_url_var = "FOO_DATABASE_URL" + /// [common] + /// database-url-var = "FOO_DATABASE_URL" /// ``` /// /// #### `bar/sqlx.toml` /// ```toml - /// [macros] - /// database_url_var = "BAR_DATABASE_URL" + /// [common] + /// database-url-var = "BAR_DATABASE_URL" /// ``` /// /// #### `.env` diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs index 142f059da4..9f4cf4524f 100644 --- a/sqlx-core/src/config/macros.rs +++ b/sqlx-core/src/config/macros.rs @@ -2,33 +2,16 @@ use std::collections::BTreeMap; /// Configuration for the `query!()` family of macros. #[derive(Debug, Default)] -#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize), serde(default))] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] pub struct Config { - /// Specify the crate to use for mapping date/time types to Rust. 
- /// - /// The default behavior is to use whatever crate is enabled, - /// [`chrono`] or [`time`] (the latter takes precedent). - /// - /// [`chrono`]: crate::types::chrono - /// [`time`]: crate::types::time - /// - /// Example: Always Use Chrono - /// ------- - /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable - /// the `time` feature of SQLx which will force it on for all crates using SQLx, - /// which will result in problems if your crate wants to use types from [`chrono`]. - /// - /// You can use the type override syntax (see `sqlx::query!` for details), - /// or you can force an override globally by setting this option. - /// - /// #### `sqlx.toml` - /// ```toml - /// [macros] - /// datetime_crate = "chrono" - /// ``` - /// - /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification - pub datetime_crate: DateTimeCrate, + /// Specify which crates' types to use when types from multiple crates apply. + /// + /// See [`PreferredCrates`] for details. + pub preferred_crates: PreferredCrates, /// Specify global overrides for mapping SQL type names to Rust type names. /// @@ -78,7 +61,7 @@ pub struct Config { /// /// #### `sqlx.toml` /// ```toml - /// [macros.type_overrides] + /// [macros.type-overrides] /// # Override a built-in type /// 'uuid' = "crate::types::MyUuid" /// @@ -115,7 +98,7 @@ pub struct Config { /// /// #### `sqlx.toml` /// ```toml - /// [macros.type_overrides] + /// [macros.type-overrides] /// # Map SQL type `foo` to `crate::types::Foo` /// 'foo' = "crate::types::Foo" /// ``` @@ -125,7 +108,7 @@ pub struct Config { /// (See `Note` section above for details.) 
/// /// ```toml - /// [macros.type_overrides] + /// [macros.type-overrides] /// # Map SQL type `foo.foo` to `crate::types::Foo` /// 'foo.foo' = "crate::types::Foo" /// ``` @@ -136,7 +119,7 @@ pub struct Config { /// it must be wrapped in quotes _twice_ for SQLx to know the difference: /// /// ```toml - /// [macros.type_overrides] + /// [macros.type-overrides] /// # `"Foo"` in SQLx /// '"Foo"' = "crate::types::Foo" /// # **NOT** `"Foo"` in SQLx (parses as just `Foo`) @@ -151,7 +134,7 @@ pub struct Config { /// (See `Note` section above for details.) pub type_overrides: BTreeMap, - /// Specify per-column overrides for mapping SQL types to Rust types. + /// Specify per-table and per-column overrides for mapping SQL types to Rust types. /// /// Default type mappings are defined by the database driver. /// Refer to the `sqlx::types` module for details. @@ -206,7 +189,7 @@ pub struct Config { /// /// #### `sqlx.toml` /// ```toml - /// [macros.column_overrides.'foo'] + /// [macros.table-overrides.'foo'] /// # Map column `bar` of table `foo` to Rust type `crate::types::Foo`: /// 'bar' = "crate::types::Bar" /// @@ -218,23 +201,83 @@ pub struct Config { /// # "Bar" = "crate::types::Bar" /// /// # Table name may be quoted (note the wrapping single-quotes) - /// [macros.column_overrides.'"Foo"'] + /// [macros.table-overrides.'"Foo"'] /// 'bar' = "crate::types::Bar" /// '"Bar"' = "crate::types::Bar" /// /// # Table name may also be schema-qualified. /// # Note how the dot is inside the quotes. 
- /// [macros.column_overrides.'my_schema.my_table'] + /// [macros.table-overrides.'my_schema.my_table'] /// 'my_column' = "crate::types::MyType" /// /// # Quoted schema, table, and column names - /// [macros.column_overrides.'"My Schema"."My Table"'] + /// [macros.table-overrides.'"My Schema"."My Table"'] /// '"My Column"' = "crate::types::MyType" /// ``` - pub column_overrides: BTreeMap>, + pub table_overrides: BTreeMap>, +} + +#[derive(Debug, Default)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(rename_all = "kebab-case") +)] +pub struct PreferredCrates { + /// Specify the crate to use for mapping date/time types to Rust. + /// + /// The default behavior is to use whatever crate is enabled, + /// [`chrono`] or [`time`] (the latter takes precedent). + /// + /// [`chrono`]: crate::types::chrono + /// [`time`]: crate::types::time + /// + /// Example: Always Use Chrono + /// ------- + /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable + /// the `time` feature of SQLx which will force it on for all crates using SQLx, + /// which will result in problems if your crate wants to use types from [`chrono`]. + /// + /// You can use the type override syntax (see `sqlx::query!` for details), + /// or you can force an override globally by setting this option. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.preferred-crates] + /// date-time = "chrono" + /// ``` + /// + /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification + pub date_time: DateTimeCrate, + + /// Specify the crate to use for mapping `NUMERIC` types to Rust. + /// + /// The default behavior is to use whatever crate is enabled, + /// [`bigdecimal`] or [`rust_decimal`] (the latter takes precedent). 
+ /// + /// [`bigdecimal`]: crate::types::bigdecimal + /// [`rust_decimal`]: crate::types::rust_decimal + /// + /// Example: Always Use `bigdecimal` + /// ------- + /// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable + /// the `rust_decimal` feature of SQLx which will force it on for all crates using SQLx, + /// which will result in problems if your crate wants to use types from [`bigdecimal`]. + /// + /// You can use the type override syntax (see `sqlx::query!` for details), + /// or you can force an override globally by setting this option. + /// + /// #### `sqlx.toml` + /// ```toml + /// [macros.preferred-crates] + /// numeric = "bigdecimal" + /// ``` + /// + /// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification + pub numeric: NumericCrate, } -/// The crate to use for mapping date/time types to Rust. +/// The preferred crate to use for mapping date/time types to Rust. #[derive(Debug, Default, PartialEq, Eq)] #[cfg_attr( feature = "sqlx-toml", @@ -249,33 +292,63 @@ pub enum DateTimeCrate { /// Always use types from [`chrono`][crate::types::chrono]. /// /// ```toml - /// [macros] - /// datetime_crate = "chrono" + /// [macros.preferred-crates] + /// date-time = "chrono" /// ``` Chrono, /// Always use types from [`time`][crate::types::time]. /// /// ```toml - /// [macros] - /// datetime_crate = "time" + /// [macros.preferred-crates] + /// date-time = "time" /// ``` Time, } +/// The preferred crate to use for mapping `NUMERIC` types to Rust. +#[derive(Debug, Default, PartialEq, Eq)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(rename_all = "snake_case") +)] +pub enum NumericCrate { + /// Use whichever crate is enabled (`rust_decimal` then `bigdecimal`). + #[default] + Inferred, + + /// Always use types from [`bigdecimal`][crate::types::bigdecimal]. 
+ /// + /// ```toml + /// [macros.preferred-crates] + /// numeric = "bigdecimal" + /// ``` + #[cfg_attr(feature = "sqlx-toml", serde(rename = "bigdecimal"))] + BigDecimal, + + /// Always use types from [`rust_decimal`][crate::types::rust_decimal]. + /// + /// ```toml + /// [macros.preferred-crates] + /// numeric = "rust_decimal" + /// ``` + RustDecimal, +} + /// A SQL type name; may optionally be schema-qualified. /// -/// See [`macros.type_overrides`][Config::type_overrides] for usages. +/// See [`macros.type-overrides`][Config::type_overrides] for usages. pub type SqlType = Box; /// A SQL table name; may optionally be schema-qualified. /// -/// See [`macros.column_overrides`][Config::column_overrides] for usages. +/// See [`macros.table-overrides`][Config::table_overrides] for usages. pub type TableName = Box; /// A column in a SQL table. /// -/// See [`macros.column_overrides`][Config::column_overrides] for usages. +/// See [`macros.table-overrides`][Config::table_overrides] for usages. pub type ColumnName = Box; /// A Rust type name or path. @@ -292,9 +365,25 @@ impl Config { /// Get the override for a given column and table name (optionally schema-qualified). 
pub fn column_override(&self, table: &str, column: &str) -> Option<&str> { - self.column_overrides + self.table_overrides .get(table) .and_then(|by_column| by_column.get(column)) .map(|s| &**s) } } + +impl DateTimeCrate { + /// Returns `self == Self::Inferred` + #[inline(always)] + pub fn is_inferred(&self) -> bool { + *self == Self::Inferred + } +} + +impl NumericCrate { + /// Returns `self == Self::Inferred` + #[inline(always)] + pub fn is_inferred(&self) -> bool { + *self == Self::Inferred + } +} \ No newline at end of file diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index efc03a0155..d0e55b35d8 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -13,7 +13,11 @@ use std::collections::BTreeSet; /// /// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_. #[derive(Debug, Default)] -#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize), serde(default))] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] pub struct Config { /// Override the name of the table used to track executed migrations. /// @@ -35,7 +39,7 @@ pub struct Config { /// ```toml /// [migrate] /// # Put `_sqlx_migrations` in schema `foo` - /// table_name = "foo._sqlx_migrations" + /// table-name = "foo._sqlx_migrations" /// ``` pub table_name: Option>, @@ -63,7 +67,7 @@ pub struct Config { /// `sqlx.toml`: /// ```toml /// [migrate] - /// ignored_chars = ["\r"] + /// ignored-chars = ["\r"] /// ``` /// /// For projects using Git, this can also be addressed using [`.gitattributes`]: @@ -91,33 +95,44 @@ pub struct Config { /// ```toml /// [migrate] /// # Ignore common whitespace characters when hashing - /// ignored_chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF + /// ignored-chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF /// ``` // Likely lower overhead for small sets than `HashSet`. 
pub ignored_chars: BTreeSet, - /// Specify the default type of migration that `sqlx migrate create` should create by default. + /// Specify default options for new migrations created with `sqlx migrate add`. + pub defaults: MigrationDefaults, +} + +#[derive(Debug, Default)] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] +pub struct MigrationDefaults { + /// Specify the default type of migration that `sqlx migrate add` should create by default. /// /// ### Example: Use Reversible Migrations by Default /// `sqlx.toml`: /// ```toml - /// [migrate] - /// default_type = "reversible" + /// [migrate.defaults] + /// migration-type = "reversible" /// ``` - pub default_type: DefaultMigrationType, + pub migration_type: DefaultMigrationType, - /// Specify the default scheme that `sqlx migrate create` should use for version integers. + /// Specify the default scheme that `sqlx migrate add` should use for version integers. /// /// ### Example: Use Sequential Versioning by Default /// `sqlx.toml`: /// ```toml - /// [migrate] - /// default_versioning = "sequential" + /// [migrate.defaults] + /// migration-versioning = "sequential" /// ``` - pub default_versioning: DefaultVersioning, + pub migration_versioning: DefaultVersioning, } -/// The default type of migration that `sqlx migrate create` should create by default. +/// The default type of migration that `sqlx migrate add` should create by default. #[derive(Debug, Default, PartialEq, Eq)] #[cfg_attr( feature = "sqlx-toml", @@ -130,14 +145,14 @@ pub enum DefaultMigrationType { #[default] Inferred, - /// Create a non-reversible migration (`_.sql`). + /// Create non-reversible migrations (`_.sql`) by default. Simple, - /// Create a reversible migration (`_.up.sql` and `[...].down.sql`). + /// Create reversible migrations (`_.up.sql` and `[...].down.sql`) by default. Reversible, } -/// The default scheme that `sqlx migrate create` should use for version integers. 
+/// The default scheme that `sqlx migrate add` should use for version integers. #[derive(Debug, Default, PartialEq, Eq)] #[cfg_attr( feature = "sqlx-toml", diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index 3bbde5c2f1..696752a51b 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -45,7 +45,11 @@ mod tests; /// The parsed structure of a `sqlx.toml` file. #[derive(Debug, Default)] -#[cfg_attr(feature = "sqlx-toml", derive(serde::Deserialize))] +#[cfg_attr( + feature = "sqlx-toml", + derive(serde::Deserialize), + serde(default, rename_all = "kebab-case") +)] pub struct Config { /// Configuration shared by multiple components. /// diff --git a/sqlx-core/src/config/reference.toml b/sqlx-core/src/config/reference.toml index 6d52f615eb..e042824c72 100644 --- a/sqlx-core/src/config/reference.toml +++ b/sqlx-core/src/config/reference.toml @@ -13,20 +13,24 @@ # This is used by both the macros and `sqlx-cli`. # # If not specified, defaults to `DATABASE_URL` -database_url_var = "FOO_DATABASE_URL" +database-url-var = "FOO_DATABASE_URL" ############################################################################################### # Configuration for the `query!()` family of macros. [macros] + +[macros.preferred-crates] # Force the macros to use the `chrono` crate for date/time types, even if `time` is enabled. # # Defaults to "inferred": use whichever crate is enabled (`time` takes precedence over `chrono`). -datetime_crate = "chrono" +date-time = "chrono" # Or, ensure the macros always prefer `time` # in case new date/time crates are added in the future: -# datetime_crate = "time" +# date-time = "time" + + # Set global overrides for mapping SQL types to Rust types. # @@ -38,7 +42,7 @@ datetime_crate = "chrono" # ### Note: Orthogonal to Nullability # These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` # or not. They only override the inner type used. 
-[macros.type_overrides] +[macros.type-overrides] # Override a built-in type (map all `UUID` columns to `crate::types::MyUuid`) 'uuid' = "crate::types::MyUuid" @@ -67,7 +71,7 @@ datetime_crate = "chrono" # Quoted schema and type name '"Foo"."Bar"' = "crate::schema::foo::Bar" -# Set per-column overrides for mapping SQL types to Rust types. +# Set per-table and per-column overrides for mapping SQL types to Rust types. # # Note: table name is required in the header. # @@ -76,7 +80,7 @@ datetime_crate = "chrono" # ### Note: Orthogonal to Nullability # These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` # or not. They only override the inner type used. -[macros.column_overrides.'foo'] +[macros.table-overrides.'foo'] # Map column `bar` of table `foo` to Rust type `crate::types::Foo`: 'bar' = "crate::types::Bar" @@ -88,17 +92,17 @@ datetime_crate = "chrono" # "Bar" = "crate::types::Bar" # Table name may be quoted (note the wrapping single-quotes) -[macros.column_overrides.'"Foo"'] +[macros.table-overrides.'"Foo"'] 'bar' = "crate::types::Bar" '"Bar"' = "crate::types::Bar" # Table name may also be schema-qualified. # Note how the dot is inside the quotes. -[macros.column_overrides.'my_schema.my_table'] +[macros.table-overrides.'my_schema.my_table'] 'my_column' = "crate::types::MyType" # Quoted schema, table, and column names -[macros.column_overrides.'"My Schema"."My Table"'] +[macros.table-overrides.'"My Schema"."My Table"'] '"My Column"' = "crate::types::MyType" ############################################################################################### @@ -130,12 +134,12 @@ datetime_crate = "chrono" # You should create the new table as a copy of the existing migrations table (with contents!), # and be sure all instances of your application have been migrated to the new # table before deleting the old one. -table_name = "foo._sqlx_migrations" +table-name = "foo._sqlx_migrations" # Override the directory used for migrations files. 
# # Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`. -migrations_dir = "foo/migrations" +migrations-dir = "foo/migrations" # Specify characters that should be ignored when hashing migrations. # @@ -148,32 +152,34 @@ migrations_dir = "foo/migrations" # change the output of the hash. # # This may require manual rectification for deployed databases. -# ignored_chars = [] +# ignored-chars = [] # Ignore Carriage Returns (`` | `\r`) # Note that the TOML format requires double-quoted strings to process escapes. -# ignored_chars = ["\r"] +# ignored-chars = ["\r"] # Ignore common whitespace characters (beware syntatically significant whitespace!) # Space, tab, CR, LF, zero-width non-breaking space (U+FEFF) # # U+FEFF is added by some editors as a magic number at the beginning of a text file indicating it is UTF-8 encoded, # where it is known as a byte-order mark (BOM): https://en.wikipedia.org/wiki/Byte_order_mark -ignored_chars = [" ", "\t", "\r", "\n", "\uFEFF"] +ignored-chars = [" ", "\t", "\r", "\n", "\uFEFF"] +# Set default options for new migrations. +[migrate.defaults] # Specify reversible migrations by default (for `sqlx migrate create`). # # Defaults to "inferred": uses the type of the last migration, or "simple" otherwise. -default_type = "reversible" +migration-type = "reversible" # Specify simple (non-reversible) migrations by default. -# default_type = "simple" +# migration-type = "simple" # Specify sequential versioning by default (for `sqlx migrate create`). # # Defaults to "inferred": guesses the versioning scheme from the latest migrations, # or "timestamp" otherwise. -default_versioning = "sequential" +migration-versioning = "sequential" # Specify timestamp versioning by default. 
-# default_versioning = "timestamp" +# migration-versioning = "timestamp" diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs index e5033bb459..6c2883d58b 100644 --- a/sqlx-core/src/config/tests.rs +++ b/sqlx-core/src/config/tests.rs @@ -8,11 +8,7 @@ fn reference_parses_as_config() { .unwrap_or_else(|e| panic!("expected reference.toml to parse as Config: {e}")); assert_common_config(&config.common); - - #[cfg(feature = "config-macros")] assert_macros_config(&config.macros); - - #[cfg(feature = "config-migrate")] assert_migrate_config(&config.migrate); } @@ -23,7 +19,7 @@ fn assert_common_config(config: &config::common::Config) { fn assert_macros_config(config: &config::macros::Config) { use config::macros::*; - assert_eq!(config.datetime_crate, DateTimeCrate::Chrono); + assert_eq!(config.preferred_crates.date_time, DateTimeCrate::Chrono); // Type overrides // Don't need to cover everything, just some important canaries. @@ -83,6 +79,6 @@ fn assert_migrate_config(config: &config::migrate::Config) { assert_eq!(config.ignored_chars, ignored_chars); - assert_eq!(config.default_type, DefaultMigrationType::Reversible); - assert_eq!(config.default_versioning, DefaultVersioning::Sequential); + assert_eq!(config.defaults.migration_type, DefaultMigrationType::Reversible); + assert_eq!(config.defaults.migration_versioning, DefaultVersioning::Sequential); } From 13f6ef0ab060bf2b9c0b6eea4e130f1f48d05262 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 19 Sep 2024 22:54:48 -0700 Subject: [PATCH 08/30] feat: make macros aware of `macros.preferred-crates` --- sqlx-core/src/column.rs | 22 +- sqlx-core/src/config/common.rs | 4 +- sqlx-core/src/config/macros.rs | 39 ++- sqlx-core/src/config/migrate.rs | 4 +- sqlx-core/src/config/mod.rs | 23 +- sqlx-core/src/config/reference.toml | 11 +- sqlx-core/src/config/tests.rs | 13 +- sqlx-core/src/type_checking.rs | 298 ++++++++++++++++++++-- sqlx-macros-core/src/query/args.rs | 78 ++++-- 
sqlx-macros-core/src/query/mod.rs | 38 ++- sqlx-macros-core/src/query/output.rs | 143 +++++++---- sqlx-mysql/src/connection/executor.rs | 6 +- sqlx-mysql/src/protocol/text/column.rs | 2 +- sqlx-mysql/src/type_checking.rs | 58 ++--- sqlx-postgres/src/column.rs | 4 +- sqlx-postgres/src/connection/describe.rs | 104 ++++---- sqlx-postgres/src/connection/establish.rs | 4 +- sqlx-postgres/src/type_checking.rs | 206 +++++++-------- sqlx-sqlite/src/column.rs | 2 +- sqlx-sqlite/src/connection/describe.rs | 2 +- sqlx-sqlite/src/statement/handle.rs | 33 +-- sqlx-sqlite/src/statement/virtual.rs | 1 + sqlx-sqlite/src/type_checking.rs | 45 ++-- src/lib.rs | 24 ++ 24 files changed, 800 insertions(+), 364 deletions(-) diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 7483375765..fddc048c4b 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -23,15 +23,17 @@ pub trait Column: 'static + Send + Sync + Debug { fn type_info(&self) -> &::TypeInfo; /// If this column comes from a table, return the table and original column name. - /// + /// /// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression /// or else the source table could not be determined. - /// + /// /// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information, /// or has not overridden this method. - // This method returns an owned value instead of a reference, + // This method returns an owned value instead of a reference, // to give the implementor more flexibility. - fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown } + fn origin(&self) -> ColumnOrigin { + ColumnOrigin::Unknown + } } /// A [`Column`] that originates from a table. @@ -44,20 +46,20 @@ pub struct TableColumn { pub name: Arc, } -/// The possible statuses for our knowledge of the origin of a [`Column`]. +/// The possible statuses for our knowledge of the origin of a [`Column`]. 
#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub enum ColumnOrigin { - /// The column is known to originate from a table. - /// - /// Included is the table name and original column name. + /// The column is known to originate from a table. + /// + /// Included is the table name and original column name. Table(TableColumn), /// The column originates from an expression, or else its origin could not be determined. Expression, /// The database driver does not know the column origin at this time. - /// + /// /// This may happen if: - /// * The connection is in the middle of executing a query, + /// * The connection is in the middle of executing a query, /// and cannot query the catalog to fetch this information. /// * The connection does not have access to the database catalog. /// * The implementation of [`Column`] did not override [`Column::origin()`]. diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs index c09ed80d7f..d2bf639e5f 100644 --- a/sqlx-core/src/config/common.rs +++ b/sqlx-core/src/config/common.rs @@ -44,6 +44,6 @@ pub struct Config { impl Config { pub fn database_url_var(&self) -> &str { - self.database_url_var.as_deref().unwrap_or("DATABASE_URL") + self.database_url_var.as_deref().unwrap_or("DATABASE_URL") } -} \ No newline at end of file +} diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs index 9f4cf4524f..19e5f42fa0 100644 --- a/sqlx-core/src/config/macros.rs +++ b/sqlx-core/src/config/macros.rs @@ -3,13 +3,13 @@ use std::collections::BTreeMap; /// Configuration for the `query!()` family of macros. #[derive(Debug, Default)] #[cfg_attr( - feature = "sqlx-toml", - derive(serde::Deserialize), + feature = "sqlx-toml", + derive(serde::Deserialize), serde(default, rename_all = "kebab-case") )] pub struct Config { /// Specify which crates' types to use when types from multiple crates apply. 
- /// + /// /// See [`PreferredCrates`] for details. pub preferred_crates: PreferredCrates, @@ -18,6 +18,12 @@ pub struct Config { /// Default type mappings are defined by the database driver. /// Refer to the `sqlx::types` module for details. /// + /// ## Note: Case-Sensitive + /// Currently, the case of the type name MUST match the name SQLx knows it by. + /// Built-in types are spelled in all-uppercase to match SQL convention. + /// + /// However, user-created types in Postgres are all-lowercase unless quoted. + /// /// ## Note: Orthogonal to Nullability /// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>` /// or not. They only override the inner type used. @@ -63,7 +69,7 @@ pub struct Config { /// ```toml /// [macros.type-overrides] /// # Override a built-in type - /// 'uuid' = "crate::types::MyUuid" + /// 'UUID' = "crate::types::MyUuid" /// /// # Support an external or custom wrapper type (e.g. from the `isn` Postgres extension) /// # (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING) @@ -132,6 +138,8 @@ pub struct Config { /// ``` /// /// (See `Note` section above for details.) + // TODO: allow specifying different types for input vs output + // e.g. to accept `&[T]` on input but output `Vec` pub type_overrides: BTreeMap, /// Specify per-table and per-column overrides for mapping SQL types to Rust types. @@ -221,7 +229,7 @@ pub struct Config { #[cfg_attr( feature = "sqlx-toml", derive(serde::Deserialize), - serde(rename_all = "kebab-case") + serde(default, rename_all = "kebab-case") )] pub struct PreferredCrates { /// Specify the crate to use for mapping date/time types to Rust. @@ -360,6 +368,7 @@ pub type RustType = Box; impl Config { /// Get the override for a given type name (optionally schema-qualified). 
pub fn type_override(&self, type_name: &str) -> Option<&str> { + // TODO: make this case-insensitive self.type_overrides.get(type_name).map(|s| &**s) } @@ -378,6 +387,15 @@ impl DateTimeCrate { pub fn is_inferred(&self) -> bool { *self == Self::Inferred } + + #[inline(always)] + pub fn crate_name(&self) -> Option<&str> { + match self { + Self::Inferred => None, + Self::Chrono => Some("chrono"), + Self::Time => Some("time"), + } + } } impl NumericCrate { @@ -386,4 +404,13 @@ impl NumericCrate { pub fn is_inferred(&self) -> bool { *self == Self::Inferred } -} \ No newline at end of file + + #[inline(always)] + pub fn crate_name(&self) -> Option<&str> { + match self { + Self::Inferred => None, + Self::BigDecimal => Some("bigdecimal"), + Self::RustDecimal => Some("rust_decimal"), + } + } +} diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index d0e55b35d8..64529f9f02 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -14,8 +14,8 @@ use std::collections::BTreeSet; /// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_. #[derive(Debug, Default)] #[cfg_attr( - feature = "sqlx-toml", - derive(serde::Deserialize), + feature = "sqlx-toml", + derive(serde::Deserialize), serde(default, rename_all = "kebab-case") )] pub struct Config { diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index 696752a51b..b3afd9ea1b 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -86,9 +86,7 @@ pub enum ConfigError { /// No configuration file was found. Not necessarily fatal. #[error("config file {path:?} not found")] - NotFound { - path: PathBuf, - }, + NotFound { path: PathBuf }, /// An I/O error occurred while attempting to read the config file at `path`. /// @@ -103,7 +101,7 @@ pub enum ConfigError { /// An error in the TOML was encountered while parsing the config file at `path`. 
/// /// The error gives line numbers and context when printed with `Display`/`ToString`. - /// + /// /// Only returned if the `sqlx-toml` feature is enabled. #[error("error parsing config file {path:?}")] Parse { @@ -115,14 +113,12 @@ pub enum ConfigError { /// A `sqlx.toml` file was found or specified, but the `sqlx-toml` feature is not enabled. #[error("SQLx found config file at {path:?} but the `sqlx-toml` feature was not enabled")] - ParseDisabled { - path: PathBuf - }, + ParseDisabled { path: PathBuf }, } impl ConfigError { /// Create a [`ConfigError`] from a [`std::io::Error`]. - /// + /// /// Maps to either `NotFound` or `Io`. pub fn from_io(path: PathBuf, error: io::Error) -> Self { if error.kind() == io::ErrorKind::NotFound { @@ -131,7 +127,7 @@ impl ConfigError { Self::Io { path, error } } } - + /// If this error means the file was not found, return the path that was attempted. pub fn not_found_path(&self) -> Option<&Path> { if let Self::NotFound { path } = self { @@ -227,15 +223,18 @@ impl Config { // Motivation: https://github.com/toml-rs/toml/issues/761 tracing::debug!("read config TOML from {path:?}:\n{toml_s}"); - toml::from_str(&toml_s).map_err(|error| ConfigError::Parse { path, error: Box::new(error) }) + toml::from_str(&toml_s).map_err(|error| ConfigError::Parse { + path, + error: Box::new(error), + }) } - + #[cfg(not(feature = "sqlx-toml"))] fn read_from(path: PathBuf) -> Result { match path.try_exists() { Ok(true) => Err(ConfigError::ParseDisabled { path }), Ok(false) => Err(ConfigError::NotFound { path }), - Err(e) => Err(ConfigError::from_io(path, e)) + Err(e) => Err(ConfigError::from_io(path, e)), } } } diff --git a/sqlx-core/src/config/reference.toml b/sqlx-core/src/config/reference.toml index e042824c72..77833fb5a8 100644 --- a/sqlx-core/src/config/reference.toml +++ b/sqlx-core/src/config/reference.toml @@ -30,7 +30,14 @@ date-time = "chrono" # in case new date/time crates are added in the future: # date-time = "time" +# Force the 
macros to use the `rust_decimal` crate for `NUMERIC`, even if `bigdecimal` is enabled. +# +# Defaults to "inferred": use whichever crate is enabled (`bigdecimal` takes precedence over `rust_decimal`). +numeric = "rust_decimal" +# Or, ensure the macros always prefer `bigdecimal` +# in case new decimal crates are added in the future: +# numeric = "bigdecimal" # Set global overrides for mapping SQL types to Rust types. # @@ -44,7 +51,9 @@ date-time = "chrono" # or not. They only override the inner type used. [macros.type-overrides] # Override a built-in type (map all `UUID` columns to `crate::types::MyUuid`) -'uuid' = "crate::types::MyUuid" +# Note: currently, the case of the type name MUST match. +# Built-in types are spelled in all-uppercase to match SQL convention. +'UUID' = "crate::types::MyUuid" # Support an external or custom wrapper type (e.g. from the `isn` Postgres extension) # (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING) diff --git a/sqlx-core/src/config/tests.rs b/sqlx-core/src/config/tests.rs index 6c2883d58b..0b0b590919 100644 --- a/sqlx-core/src/config/tests.rs +++ b/sqlx-core/src/config/tests.rs @@ -20,9 +20,12 @@ fn assert_macros_config(config: &config::macros::Config) { use config::macros::*; assert_eq!(config.preferred_crates.date_time, DateTimeCrate::Chrono); + assert_eq!(config.preferred_crates.numeric, NumericCrate::RustDecimal); // Type overrides // Don't need to cover everything, just some important canaries. 
+ assert_eq!(config.type_override("UUID"), Some("crate::types::MyUuid")); + assert_eq!(config.type_override("foo"), Some("crate::types::Foo")); assert_eq!(config.type_override(r#""Bar""#), Some("crate::types::Bar"),); @@ -79,6 +82,12 @@ fn assert_migrate_config(config: &config::migrate::Config) { assert_eq!(config.ignored_chars, ignored_chars); - assert_eq!(config.defaults.migration_type, DefaultMigrationType::Reversible); - assert_eq!(config.defaults.migration_versioning, DefaultVersioning::Sequential); + assert_eq!( + config.defaults.migration_type, + DefaultMigrationType::Reversible + ); + assert_eq!( + config.defaults.migration_versioning, + DefaultVersioning::Sequential + ); } diff --git a/sqlx-core/src/type_checking.rs b/sqlx-core/src/type_checking.rs index 384d15f42c..d3d4a4c7af 100644 --- a/sqlx-core/src/type_checking.rs +++ b/sqlx-core/src/type_checking.rs @@ -1,3 +1,4 @@ +use crate::config::macros::PreferredCrates; use crate::database::Database; use crate::decode::Decode; use crate::type_info::TypeInfo; @@ -26,12 +27,18 @@ pub trait TypeChecking: Database { /// /// If the type has a borrowed equivalent suitable for query parameters, /// this is that borrowed type. - fn param_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>; + fn param_type_for_id( + id: &Self::TypeInfo, + preferred_crates: &PreferredCrates, + ) -> Result<&'static str, Error>; /// Get the full path of the Rust type that corresponds to the given `TypeInfo`, if applicable. /// /// Always returns the owned version of the type, suitable for decoding from `Row`. - fn return_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>; + fn return_type_for_id( + id: &Self::TypeInfo, + preferred_crates: &PreferredCrates, + ) -> Result<&'static str, Error>; /// Get the name of the Cargo feature gate that must be enabled to process the given `TypeInfo`, /// if applicable. 
@@ -43,6 +50,18 @@ pub trait TypeChecking: Database { fn fmt_value_debug(value: &::Value) -> FmtValue<'_, Self>; } +pub type Result = std::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("no built-in mapping found for SQL type; a type override may be required")] + NoMappingFound, + #[error("Cargo feature for configured `macros.preferred-crates.date-time` not enabled")] + DateTimeCrateFeatureNotEnabled, + #[error("Cargo feature for configured `macros.preferred-crates.numeric` not enabled")] + NumericCrateFeatureNotEnabled, +} + /// An adapter for [`Value`] which attempts to decode the value and format it when printed using [`Debug`]. pub struct FmtValue<'v, DB> where @@ -134,36 +153,256 @@ macro_rules! impl_type_checking { }, ParamChecking::$param_checking:ident, feature-types: $ty_info:ident => $get_gate:expr, + datetime-types: { + chrono: { + $($chrono_ty:ty $(| $chrono_input:ty)?),*$(,)? + }, + time: { + $($time_ty:ty $(| $time_input:ty)?),*$(,)? + }, + }, + numeric-types: { + bigdecimal: { + $($bigdecimal_ty:ty $(| $bigdecimal_input:ty)?),*$(,)? + }, + rust_decimal: { + $($rust_decimal_ty:ty $(| $rust_decimal_input:ty)?),*$(,)? + }, + }, ) => { impl $crate::type_checking::TypeChecking for $database { const PARAM_CHECKING: $crate::type_checking::ParamChecking = $crate::type_checking::ParamChecking::$param_checking; - fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> { - match () { + fn param_type_for_id( + info: &Self::TypeInfo, + preferred_crates: &$crate::config::macros::PreferredCrates, + ) -> Result<&'static str, $crate::type_checking::Error> { + use $crate::config::macros::{DateTimeCrate, NumericCrate}; + use $crate::type_checking::Error; + + // Check `macros.preferred-crates.date-time` + // + // Due to legacy reasons, `time` takes precedent over `chrono` if both are enabled. + // Any crates added later should be _lower_ priority than `chrono` to avoid breakages. 
+ // ---------------------------------------- + #[cfg(feature = "time")] + if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) { + $( + if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($time_ty $(, $time_input)?)); + } + )* + + $( + if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($time_ty $(, $time_input)?)); + } + )* + } + + #[cfg(not(feature = "time"))] + if preferred_crates.date_time == DateTimeCrate::Time { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + #[cfg(feature = "chrono")] + if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) { + $( + if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?)); + } + )* + $( - $(#[$meta])? - _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some($crate::select_input_type!($ty $(, $input)?)), + if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?)); + } )* + } + + #[cfg(not(feature = "chrono"))] + if preferred_crates.date_time == DateTimeCrate::Chrono { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + // Check `macros.preferred-crates.numeric` + // + // Due to legacy reasons, `bigdecimal` takes precedent over `rust_decimal` if + // both are enabled. + // ---------------------------------------- + #[cfg(feature = "bigdecimal")] + if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) { $( - $(#[$meta])? 
- _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some($crate::select_input_type!($ty $(, $input)?)), + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?)); + } + )* + + $( + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?)); + } )* - _ => None } + + #[cfg(not(feature = "bigdecimal"))] + if preferred_crates.numeric == NumericCrate::BigDecimal { + return Err(Error::NumericCrateFeatureNotEnabled); + } + + #[cfg(feature = "rust_decimal")] + if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) { + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } + )* + + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } + )* + } + + #[cfg(not(feature = "rust_decimal"))] + if preferred_crates.numeric == NumericCrate::RustDecimal { + return Err(Error::NumericCrateFeatureNotEnabled); + } + + // Check all other types + // --------------------- + $( + $(#[$meta])? + if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($ty $(, $input)?)); + } + )* + + $( + $(#[$meta])? 
+ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($ty $(, $input)?)); + } + )* + + Err(Error::NoMappingFound) } - fn return_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> { - match () { + fn return_type_for_id( + info: &Self::TypeInfo, + preferred_crates: &$crate::config::macros::PreferredCrates, + ) -> Result<&'static str, $crate::type_checking::Error> { + use $crate::config::macros::{DateTimeCrate, NumericCrate}; + use $crate::type_checking::Error; + + // Check `macros.preferred-crates.date-time` + // + // Due to legacy reasons, `time` takes precedent over `chrono` if both are enabled. + // Any crates added later should be _lower_ priority than `chrono` to avoid breakages. + // ---------------------------------------- + #[cfg(feature = "time")] + if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) { + $( + if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok(stringify!($time_ty)); + } + )* + + $( + if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok(stringify!($time_ty)); + } + )* + } + + #[cfg(not(feature = "time"))] + if preferred_crates.date_time == DateTimeCrate::Time { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + #[cfg(feature = "chrono")] + if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) { $( - $(#[$meta])? 
- _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(stringify!($ty)), + if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok(stringify!($chrono_ty)); + } )* + + $( + if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok(stringify!($chrono_ty)); + } + )* + } + + #[cfg(not(feature = "chrono"))] + if preferred_crates.date_time == DateTimeCrate::Chrono { + return Err(Error::DateTimeCrateFeatureNotEnabled); + } + + // Check `macros.preferred-crates.numeric` + // + // Due to legacy reasons, `bigdecimal` takes precedent over `rust_decimal` if + // both are enabled. + // ---------------------------------------- + #[cfg(feature = "bigdecimal")] + if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) { + $( + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok(stringify!($bigdecimal_ty)); + } + )* + $( - $(#[$meta])? - _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some(stringify!($ty)), + if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok(stringify!($bigdecimal_ty)); + } )* - _ => None } + + #[cfg(not(feature = "bigdecimal"))] + if preferred_crates.numeric == NumericCrate::BigDecimal { + return Err(Error::NumericCrateFeatureNotEnabled); + } + + #[cfg(feature = "rust_decimal")] + if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) { + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } + )* + + $( + if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?)); + } + )* + } + + #[cfg(not(feature = "rust_decimal"))] + if preferred_crates.numeric == NumericCrate::RustDecimal { + return 
Err(Error::NumericCrateFeatureNotEnabled); + } + + // Check all other types + // --------------------- + $( + $(#[$meta])? + if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info { + return Ok(stringify!($ty)); + } + )* + + $( + $(#[$meta])? + if <$ty as sqlx_core::types::Type<$database>>::compatible(info) { + return Ok(stringify!($ty)); + } + )* + + Err(Error::NoMappingFound) } fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> { @@ -175,13 +414,32 @@ macro_rules! impl_type_checking { let info = value.type_info(); - match () { + #[cfg(feature = "time")] + { + $( + if <$time_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$time_ty>(value); + } + )* + } + + #[cfg(feature = "chrono")] + { $( - $(#[$meta])? - _ if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) => $crate::type_checking::FmtValue::debug::<$ty>(value), + if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$chrono_ty>(value); + } )* - _ => $crate::type_checking::FmtValue::unknown(value), } + + $( + $(#[$meta])? 
+ if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) { + return $crate::type_checking::FmtValue::debug::<$ty>(value); + } + )* + + $crate::type_checking::FmtValue::unknown(value) } } }; diff --git a/sqlx-macros-core/src/query/args.rs b/sqlx-macros-core/src/query/args.rs index ec17aeff65..1ddc5e984c 100644 --- a/sqlx-macros-core/src/query/args.rs +++ b/sqlx-macros-core/src/query/args.rs @@ -3,7 +3,10 @@ use crate::query::QueryMacroInput; use either::Either; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; +use sqlx_core::config::Config; use sqlx_core::describe::Describe; +use sqlx_core::type_checking; +use sqlx_core::type_info::TypeInfo; use syn::spanned::Spanned; use syn::{Expr, ExprCast, ExprGroup, Type}; @@ -11,6 +14,7 @@ use syn::{Expr, ExprCast, ExprGroup, Type}; /// and binds them to `DB::Arguments` with the ident `query_args`. pub fn quote_args( input: &QueryMacroInput, + config: &Config, info: &Describe, ) -> crate::Result { let db_path = DB::db_path(); @@ -55,22 +59,7 @@ pub fn quote_args( return Ok(quote!()); } - let param_ty = - DB::param_type_for_id(param_ty) - .ok_or_else(|| { - if let Some(feature_gate) = DB::get_feature_gate(param_ty) { - format!( - "optional sqlx feature `{}` required for type {} of param #{}", - feature_gate, - param_ty, - i + 1, - ) - } else { - format!("unsupported type {} for param #{}", param_ty, i + 1) - } - })? 
- .parse::() - .map_err(|_| format!("Rust type mapping for {param_ty} not parsable"))?; + let param_ty = get_param_type::(param_ty, config, i)?; Ok(quote_spanned!(expr.span() => // this shouldn't actually run @@ -115,6 +104,63 @@ pub fn quote_args( }) } +fn get_param_type( + param_ty: &DB::TypeInfo, + config: &Config, + i: usize, +) -> crate::Result { + if let Some(type_override) = config.macros.type_override(param_ty.name()) { + return Ok(type_override.parse()?); + } + + let err = match DB::param_type_for_id(param_ty, &config.macros.preferred_crates) { + Ok(t) => return Ok(t.parse()?), + Err(e) => e, + }; + + let param_num = i + 1; + + let message = match err { + type_checking::Error::NoMappingFound => { + if let Some(feature_gate) = DB::get_feature_gate(param_ty) { + format!( + "optional sqlx feature `{feature_gate}` required for type {param_ty} of param #{param_num}", + ) + } else { + format!("unsupported type {param_ty} for param #{param_num}") + } + } + type_checking::Error::DateTimeCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .date_time + .crate_name() + .expect("BUG: got feature-not-enabled error for DateTimeCrate::Inferred"); + + format!( + "SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \ + (configured by `macros.preferred-crates.date-time` in sqlx.toml)", + ) + } + type_checking::Error::NumericCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .numeric + .crate_name() + .expect("BUG: got feature-not-enabled error for NumericCrate::Inferred"); + + format!( + "SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \ + (configured by `macros.preferred-crates.numeric` in sqlx.toml)", + ) + } + }; + + Err(message.into()) +} + fn get_type_override(expr: &Expr) -> Option<&Type> { match expr { Expr::Group(group) => get_type_override(&group.expr), diff --git a/sqlx-macros-core/src/query/mod.rs 
b/sqlx-macros-core/src/query/mod.rs index 190d272d14..37592d4f77 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -15,8 +15,8 @@ use crate::database::DatabaseExt; use crate::query::data::{hash_string, DynQueryData, QueryData}; use crate::query::input::RecordType; use either::Either; -use url::Url; use sqlx_core::config::Config; +use url::Url; mod args; mod data; @@ -139,11 +139,9 @@ static METADATA: Lazy = Lazy::new(|| { let offline = env("SQLX_OFFLINE") .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - - let var_name = Config::from_crate() - .common - .database_url_var(); - + + let var_name = Config::from_crate().common.database_url_var(); + let database_url = env(var_name).ok(); Metadata { @@ -251,6 +249,8 @@ fn expand_with_data( where Describe: DescribeExt, { + let config = Config::from_crate(); + // validate at the minimum that our args match the query's input parameters let num_parameters = match data.describe.parameters() { Some(Either::Left(params)) => Some(params.len()), @@ -267,7 +267,7 @@ where } } - let args_tokens = args::quote_args(&input, &data.describe)?; + let args_tokens = args::quote_args(&input, config, &data.describe)?; let query_args = format_ident!("query_args"); @@ -286,7 +286,7 @@ where } else { match input.record_type { RecordType::Generated => { - let columns = output::columns_to_rust::(&data.describe)?; + let columns = output::columns_to_rust::(&data.describe, config)?; let record_name: Type = syn::parse_str("Record").unwrap(); @@ -322,22 +322,40 @@ where record_tokens } RecordType::Given(ref out_ty) => { - let columns = output::columns_to_rust::(&data.describe)?; + let columns = output::columns_to_rust::(&data.describe, config)?; output::quote_query_as::(&input, out_ty, &query_args, &columns) } RecordType::Scalar => { - output::quote_query_scalar::(&input, &query_args, &data.describe)? + output::quote_query_scalar::(&input, config, &query_args, &data.describe)? 
} } }; + let mut warnings = TokenStream::new(); + + if config.macros.preferred_crates.date_time.is_inferred() { + // Warns if the date-time crate is inferred but both `chrono` and `time` are enabled + warnings.extend(quote! { + ::sqlx::warn_on_ambiguous_inferred_date_time_crate(); + }); + } + + if config.macros.preferred_crates.numeric.is_inferred() { + // Warns if the numeric crate is inferred but both `bigdecimal` and `rust_decimal` are enabled + warnings.extend(quote! { + ::sqlx::warn_on_ambiguous_inferred_numeric_crate(); + }); + } + let ret_tokens = quote! { { #[allow(clippy::all)] { use ::sqlx::Arguments as _; + #warnings + #args_tokens #output diff --git a/sqlx-macros-core/src/query/output.rs b/sqlx-macros-core/src/query/output.rs index d9dc79a366..1a145e3a75 100644 --- a/sqlx-macros-core/src/query/output.rs +++ b/sqlx-macros-core/src/query/output.rs @@ -8,12 +8,13 @@ use sqlx_core::describe::Describe; use crate::database::DatabaseExt; use crate::query::QueryMacroInput; +use sqlx_core::config::Config; +use sqlx_core::type_checking; use sqlx_core::type_checking::TypeChecking; +use sqlx_core::type_info::TypeInfo; use std::fmt::{self, Display, Formatter}; use syn::parse::{Parse, ParseStream}; use syn::Token; -use sqlx_core::config::Config; -use sqlx_core::type_info::TypeInfo; pub struct RustColumn { pub(super) ident: Ident, @@ -78,13 +79,20 @@ impl Display for DisplayColumn<'_> { } } -pub fn columns_to_rust(describe: &Describe) -> crate::Result> { +pub fn columns_to_rust( + describe: &Describe, + config: &Config, +) -> crate::Result> { (0..describe.columns().len()) - .map(|i| column_to_rust(describe, i)) + .map(|i| column_to_rust(describe, config, i)) .collect::>>() } -fn column_to_rust(describe: &Describe, i: usize) -> crate::Result { +fn column_to_rust( + describe: &Describe, + config: &Config, + i: usize, +) -> crate::Result { let column = &describe.columns()[i]; // add raw prefix to all identifiers @@ -108,7 +116,7 @@ fn column_to_rust(describe: &Describe, 
i: usize) -> crate:: (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, (ColumnTypeOverride::None, _) => { - let type_ = get_column_type::(i, column); + let type_ = get_column_type::(config, i, column); if !nullable { ColumnType::Exact(type_) } else { @@ -195,6 +203,7 @@ pub fn quote_query_as( pub fn quote_query_scalar( input: &QueryMacroInput, + config: &Config, bind_args: &Ident, describe: &Describe, ) -> crate::Result { @@ -209,10 +218,10 @@ pub fn quote_query_scalar( } // attempt to parse a column override, otherwise fall back to the inferred type of the column - let ty = if let Ok(rust_col) = column_to_rust(describe, 0) { + let ty = if let Ok(rust_col) = column_to_rust(describe, config, 0) { rust_col.type_.to_token_stream() } else if input.checked { - let ty = get_column_type::(0, &columns[0]); + let ty = get_column_type::(config, 0, &columns[0]); if describe.nullable(0).unwrap_or(true) { quote! { ::std::option::Option<#ty> } } else { @@ -230,52 +239,92 @@ pub fn quote_query_scalar( }) } -fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { +fn get_column_type(config: &Config, i: usize, column: &DB::Column) -> TokenStream { if let ColumnOrigin::Table(origin) = column.origin() { - if let Some(column_override) = Config::from_crate() - .macros - .column_override(&origin.table, &origin.name) - { + if let Some(column_override) = config.macros.column_override(&origin.table, &origin.name) { return column_override.parse().unwrap(); } } - + let type_info = column.type_info(); - if let Some(type_override) = Config::from_crate() - .macros - .type_override(type_info.name()) - { - return type_override.parse().unwrap(); + if let Some(type_override) = config.macros.type_override(type_info.name()) { + return type_override.parse().unwrap(); } - - ::return_type_for_id(type_info).map_or_else( - || { - let message = - if let Some(feature_gate) = ::get_feature_gate(type_info) { - format!( - "optional sqlx feature `{feat}` required for type {ty} of 
{col}", - ty = &type_info, - feat = feature_gate, - col = DisplayColumn { - idx: i, - name: column.name() - } - ) - } else { - format!( - "unsupported type {ty} of {col}", - ty = type_info, - col = DisplayColumn { - idx: i, - name: column.name() - } - ) - }; - syn::Error::new(Span::call_site(), message).to_compile_error() - }, - |t| t.parse().unwrap(), - ) + + let err = match ::return_type_for_id( + type_info, + &config.macros.preferred_crates, + ) { + Ok(t) => return t.parse().unwrap(), + Err(e) => e, + }; + + let message = match err { + type_checking::Error::NoMappingFound => { + if let Some(feature_gate) = ::get_feature_gate(type_info) { + format!( + "SQLx feature `{feat}` required for type {ty} of {col}", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } else { + format!( + "no built-in mapping found for type {ty} of {col}; \ + a type override may be required, see documentation for details", + ty = type_info, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } + } + type_checking::Error::DateTimeCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .date_time + .crate_name() + .expect("BUG: got feature-not-enabled error for DateTimeCrate::Inferred"); + + format!( + "SQLx feature `{feat}` required for type {ty} of {col} \ + (configured by `macros.preferred-crates.date-time` in sqlx.toml)", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) + } + type_checking::Error::NumericCrateFeatureNotEnabled => { + let feature_gate = config + .macros + .preferred_crates + .numeric + .crate_name() + .expect("BUG: got feature-not-enabled error for NumericCrate::Inferred"); + + format!( + "SQLx feature `{feat}` required for type {ty} of {col} \ + (configured by `macros.preferred-crates.numeric` in sqlx.toml)", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: column.name() + } + ) 
+ } + }; + + syn::Error::new(Span::call_site(), message).to_compile_error() } impl ColumnDecl { diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 6baad5ccab..d0d9cf18c3 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -22,8 +22,8 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; -use std::{borrow::Cow, sync::Arc}; use sqlx_core::column::{ColumnOrigin, TableColumn}; +use std::{borrow::Cow, sync::Arc}; impl MySqlConnection { async fn prepare_statement<'c>( @@ -391,7 +391,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result Result Result<&str, Error> { str::from_utf8(&self.table).map_err(Error::protocol) } - + pub(crate) fn name(&self) -> Result<&str, Error> { str::from_utf8(&self.name).map_err(Error::protocol) } diff --git a/sqlx-mysql/src/type_checking.rs b/sqlx-mysql/src/type_checking.rs index 3f3ce5833e..0bdc84d8c9 100644 --- a/sqlx-mysql/src/type_checking.rs +++ b/sqlx-mysql/src/type_checking.rs @@ -25,41 +25,39 @@ impl_type_checking!( // BINARY, VAR_BINARY, BLOB Vec, - // Types from third-party crates need to be referenced at a known path - // for the macros to work, but we don't want to require the user to add extra dependencies. 
- #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime, - - #[cfg(feature = "time")] - sqlx::types::time::Time, + #[cfg(feature = "json")] + sqlx::types::JsonValue, + }, + ParamChecking::Weak, + feature-types: info => info.__type_feature_gate(), + // The expansion of the macro automatically applies the correct feature name + // and checks `[macros.preferred-crates]` + datetime-types: { + chrono: { + sqlx::types::chrono::NaiveTime, - #[cfg(feature = "time")] - sqlx::types::time::Date, + sqlx::types::chrono::NaiveDate, - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, + sqlx::types::chrono::NaiveDateTime, - #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, + sqlx::types::chrono::DateTime, + }, + time: { + sqlx::types::time::Time, - #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, + sqlx::types::time::Date, - #[cfg(feature = "rust_decimal")] - sqlx::types::Decimal, + sqlx::types::time::PrimitiveDateTime, - #[cfg(feature = "json")] - sqlx::types::JsonValue, + sqlx::types::time::OffsetDateTime, + }, + }, + numeric-types: { + bigdecimal: { + sqlx::types::BigDecimal, + }, + rust_decimal: { + sqlx::types::Decimal, + }, }, - ParamChecking::Weak, - feature-types: info => info.__type_feature_gate(), ); diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index bd08e27db0..4dd3a1cbd2 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -1,8 +1,8 @@ use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; -pub(crate) use sqlx_core::column::{Column, ColumnIndex}; use sqlx_core::column::ColumnOrigin; +pub(crate) use sqlx_core::column::{Column, ColumnIndex}; #[derive(Debug, Clone)] 
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] @@ -13,7 +13,7 @@ pub struct PgColumn { #[cfg_attr(feature = "offline", serde(default))] pub(crate) origin: ColumnOrigin, - + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 53affe5dc3..4decdde5dd 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,4 +1,4 @@ -use std::collections::btree_map; +use crate::connection::TableColumns; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; @@ -12,11 +12,9 @@ use crate::types::Oid; use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; +use sqlx_core::column::{ColumnOrigin, TableColumn}; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; -use sqlx_core::column::{ColumnOrigin, TableColumn}; -use sqlx_core::hash_map; -use crate::connection::TableColumns; /// Describes the type of the `pg_type.typtype` column /// @@ -125,9 +123,12 @@ impl PgConnection { let type_info = self .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) .await?; - - let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) { - self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await? + + let origin = if let (Some(relation_oid), Some(attribute_no)) = + (field.relation_id, field.relation_attribute_no) + { + self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch) + .await? 
} else { ColumnOrigin::Expression }; @@ -200,52 +201,65 @@ impl PgConnection { Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) } } - + async fn maybe_fetch_column_origin( - &mut self, - relation_id: Oid, + &mut self, + relation_id: Oid, attribute_no: i16, should_fetch: bool, ) -> Result { - let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) { - hash_map::Entry::Occupied(table_columns) => { - table_columns.into_mut() - }, - hash_map::Entry::Vacant(vacant) => { - if !should_fetch { return Ok(ColumnOrigin::Unknown); } - - let table_name: String = query_scalar("SELECT $1::oid::regclass::text") - .bind(relation_id) - .fetch_one(&mut *self) - .await?; - - vacant.insert(TableColumns { - table_name: table_name.into(), - columns: Default::default(), + if let Some(origin) = + self.cache_table_to_column_names + .get(&relation_id) + .and_then(|table_columns| { + let column_name = table_columns.columns.get(&attribute_no).cloned()?; + + Some(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name, + })) }) - } - }; - - let column_name = match table_columns.columns.entry(attribute_no) { - btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()), - btree_map::Entry::Vacant(vacant) => { - if !should_fetch { return Ok(ColumnOrigin::Unknown); } - - let column_name: String = query_scalar( - "SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2" - ) - .bind(relation_id) - .bind(attribute_no) - .fetch_one(&mut *self) - .await?; - - Arc::clone(vacant.insert(column_name.into())) - } + { + return Ok(origin); + } + + if !should_fetch { + return Ok(ColumnOrigin::Unknown); + } + + // Looking up the table name _may_ end up being redundant, + // but the round-trip to the server is by far the most expensive part anyway. 
+ let Some((table_name, column_name)): Option<(String, String)> = query_as( + // language=PostgreSQL + "SELECT $1::oid::regclass::text, attname \ + FROM pg_catalog.pg_attribute \ + WHERE attrelid = $1 AND attnum = $2", + ) + .bind(relation_id) + .bind(attribute_no) + .fetch_optional(&mut *self) + .await? + else { + // The column/table doesn't exist anymore for whatever reason. + return Ok(ColumnOrigin::Unknown); }; - + + let table_columns = self + .cache_table_to_column_names + .entry(relation_id) + .or_insert_with(|| TableColumns { + table_name: table_name.into(), + columns: Default::default(), + }); + + let column_name = table_columns + .columns + .entry(attribute_no) + .or_insert(column_name.into()); + Ok(ColumnOrigin::Table(TableColumn { table: table_columns.table_name.clone(), - name: column_name + name: Arc::clone(column_name), })) } diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 1bc4172fbd..684bf26599 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -148,8 +148,8 @@ impl PgConnection { cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), cache_elem_type_to_array: HashMap::new(), - log_settings: options.log_settings.clone(), - }), + cache_table_to_column_names: HashMap::new(), + log_settings: options.log_settings.clone(),}), }) } } diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index eb18c5a999..41661a84bc 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -39,42 +39,6 @@ impl_type_checking!( #[cfg(feature = "uuid")] sqlx::types::Uuid, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature 
= "time")))] - sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgTimeTz, - - #[cfg(feature = "time")] - sqlx::types::time::Time, - - #[cfg(feature = "time")] - sqlx::types::time::Date, - - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, - - #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, - - #[cfg(feature = "time")] - sqlx::postgres::types::PgTimeTz, - - #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, - - #[cfg(feature = "rust_decimal")] - sqlx::types::Decimal, - #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, @@ -106,36 +70,6 @@ impl_type_checking!( #[cfg(feature = "uuid")] Vec | &[sqlx::types::Uuid], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveTime], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveDate], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec | &[sqlx::types::chrono::NaiveDateTime], - - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | &[sqlx::types::chrono::DateTime<_>], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Time], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Date], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::PrimitiveDateTime], - - #[cfg(feature = "time")] - Vec | &[sqlx::types::time::OffsetDateTime], - - #[cfg(feature = "bigdecimal")] - Vec | &[sqlx::types::BigDecimal], - - #[cfg(feature = "rust_decimal")] - Vec | &[sqlx::types::Decimal], - #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], @@ -152,72 +86,114 @@ impl_type_checking!( sqlx::postgres::types::PgRange, sqlx::postgres::types::PgRange, - #[cfg(feature = "bigdecimal")] - sqlx::postgres::types::PgRange, + // Range arrays - #[cfg(feature = "rust_decimal")] - sqlx::postgres::types::PgRange, + Vec> | &[sqlx::postgres::types::PgRange], + Vec> | 
&[sqlx::postgres::types::PgRange], + }, + ParamChecking::Strong, + feature-types: info => info.__type_feature_gate(), + // The expansion of the macro automatically applies the correct feature name + // and checks `[macros.preferred-crates]` + datetime-types: { + chrono: { + // Scalar types + sqlx::types::chrono::NaiveTime, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange, + sqlx::types::chrono::NaiveDate, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange, + sqlx::types::chrono::NaiveDateTime, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::postgres::types::PgRange> | - sqlx::postgres::types::PgRange>, + sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + sqlx::postgres::types::PgTimeTz, - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + // Array types + Vec | &[sqlx::types::chrono::NaiveTime], - #[cfg(feature = "time")] - sqlx::postgres::types::PgRange, + Vec | &[sqlx::types::chrono::NaiveDate], - // Range arrays + Vec | &[sqlx::types::chrono::NaiveDateTime], - Vec> | &[sqlx::postgres::types::PgRange], - Vec> | &[sqlx::postgres::types::PgRange], + Vec> | &[sqlx::types::chrono::DateTime<_>], + + // Range types + sqlx::postgres::types::PgRange, + + sqlx::postgres::types::PgRange, + + sqlx::postgres::types::PgRange> | + sqlx::postgres::types::PgRange>, + + // Arrays of ranges + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec>> | + &[sqlx::postgres::types::PgRange>], + }, + time: { + // Scalar types + sqlx::types::time::Time, + + sqlx::types::time::Date, + + sqlx::types::time::PrimitiveDateTime, - #[cfg(feature = "bigdecimal")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::types::time::OffsetDateTime, - #[cfg(feature = "rust_decimal")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgTimeTz, - 
#[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | - &[sqlx::postgres::types::PgRange], + // Array types + Vec | &[sqlx::types::time::Time], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec> | - &[sqlx::postgres::types::PgRange], + Vec | &[sqlx::types::time::Date], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec>> | - &[sqlx::postgres::types::PgRange>], + Vec | &[sqlx::types::time::PrimitiveDateTime], - #[cfg(all(feature = "chrono", not(feature = "time")))] - Vec>> | - &[sqlx::postgres::types::PgRange>], + Vec | &[sqlx::types::time::OffsetDateTime], - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + // Range types + sqlx::postgres::types::PgRange, - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgRange, - #[cfg(feature = "time")] - Vec> | - &[sqlx::postgres::types::PgRange], + sqlx::postgres::types::PgRange, + + // Arrays of ranges + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + + Vec> | + &[sqlx::postgres::types::PgRange], + }, + }, + numeric-types: { + bigdecimal: { + sqlx::types::BigDecimal, + + Vec | &[sqlx::types::BigDecimal], + + sqlx::postgres::types::PgRange, + + Vec> | + &[sqlx::postgres::types::PgRange], + }, + rust_decimal: { + sqlx::types::Decimal, + + Vec | &[sqlx::types::Decimal], + + sqlx::postgres::types::PgRange, + + Vec> | + &[sqlx::postgres::types::PgRange], + }, }, - ParamChecking::Strong, - feature-types: info => info.__type_feature_gate(), ); diff --git a/sqlx-sqlite/src/column.rs b/sqlx-sqlite/src/column.rs index 390f3687fb..d319bd46a8 100644 --- a/sqlx-sqlite/src/column.rs +++ b/sqlx-sqlite/src/column.rs @@ -11,7 +11,7 @@ pub struct SqliteColumn { pub(crate) type_info: SqliteTypeInfo, #[cfg_attr(feature = "offline", serde(default))] - pub(crate) origin: ColumnOrigin + pub(crate) origin: ColumnOrigin, } impl Column for SqliteColumn { diff --git 
a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 9ba9f8c3b1..6db81374aa 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -49,7 +49,7 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result ColumnOrigin { - if let Some((table, name)) = - self.column_table_name(index).zip(self.column_origin_name(index)) + if let Some((table, name)) = self + .column_table_name(index) + .zip(self.column_origin_name(index)) { let table: Arc = self .column_db_name(index) @@ -125,20 +126,20 @@ impl StatementHandle { // TODO: check that SQLite returns the names properly quoted if necessary |db| format!("{db}.{table}").into(), ); - + ColumnOrigin::Table(TableColumn { table, - name: name.into() + name: name.into(), }) } else { ColumnOrigin::Expression } } - + fn column_db_name(&self, index: usize) -> Option<&str> { unsafe { let db_name = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index)); - + if !db_name.is_null() { Some(from_utf8_unchecked(CStr::from_ptr(db_name).to_bytes())) } else { @@ -170,7 +171,7 @@ impl StatementHandle { } } } - + pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo { SqliteTypeInfo(DataType::from_code(self.column_type(index))) } diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 6be980c36a..345af307a7 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -104,6 +104,7 @@ impl VirtualStatement { ordinal: i, name: name.clone(), type_info, + origin: statement.column_origin(i), }); column_names.insert(name, i); diff --git a/sqlx-sqlite/src/type_checking.rs b/sqlx-sqlite/src/type_checking.rs index e1ac3bc753..97af601c86 100644 --- a/sqlx-sqlite/src/type_checking.rs +++ b/sqlx-sqlite/src/type_checking.rs @@ -1,8 +1,7 @@ +use crate::Sqlite; #[allow(unused_imports)] use sqlx_core as sqlx; -use crate::Sqlite; - // f32 is not included below as REAL 
represents a floating point value // stored as an 8-byte IEEE floating point number (i.e. an f64) // For more info see: https://www.sqlite.org/datatype3.html#storage_classes_and_datatypes @@ -20,24 +19,6 @@ impl_type_checking!( String, Vec, - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, - - #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, - - #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, - - #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, - - #[cfg(feature = "time")] - sqlx::types::time::Date, - #[cfg(feature = "uuid")] sqlx::types::Uuid, }, @@ -48,4 +29,28 @@ impl_type_checking!( // The type integrations simply allow the user to skip some intermediate representation, // which is usually TEXT. feature-types: _info => None, + + // The expansion of the macro automatically applies the correct feature name + // and checks `[macros.preferred-crates]` + datetime-types: { + chrono: { + sqlx::types::chrono::NaiveDate, + + sqlx::types::chrono::NaiveDateTime, + + sqlx::types::chrono::DateTime + | sqlx::types::chrono::DateTime<_>, + }, + time: { + sqlx::types::time::OffsetDateTime, + + sqlx::types::time::PrimitiveDateTime, + + sqlx::types::time::Date, + }, + }, + numeric-types: { + bigdecimal: { }, + rust_decimal: { }, + }, ); diff --git a/src/lib.rs b/src/lib.rs index a357753b96..7c10b85213 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,3 +168,27 @@ pub mod prelude { #[cfg(feature = "_unstable-doc")] pub use sqlx_core::config; + +#[doc(hidden)] +#[cfg_attr( + all(feature = "chrono", feature = "time"), + deprecated = "SQLx has both `chrono` and `time` features enabled, \ + which presents an ambiguity when the `query!()` macros are mapping date/time types. 
\ + The `query!()` macros prefer types from `time` by default, \ + but this behavior should not be relied upon; \ + to resolve the ambiguity, we recommend specifying the preferred crate in a `sqlx.toml` file: \ + https://docs.rs/sqlx/latest/sqlx/config/macros/PreferredCrates.html#field.date_time" +)] +pub fn warn_on_ambiguous_inferred_date_time_crate() {} + +#[doc(hidden)] +#[cfg_attr( + all(feature = "bigdecimal", feature = "rust_decimal"), + deprecated = "SQLx has both `bigdecimal` and `rust_decimal` features enabled, \ + which presents an ambiguity when the `query!()` macros are mapping `NUMERIC`. \ + The `query!()` macros prefer `bigdecimal::BigDecimal` by default, \ + but this behavior should not be relied upon; \ + to resolve the ambiguity, we recommend specifying the preferred crate in a `sqlx.toml` file: \ + https://docs.rs/sqlx/latest/sqlx/config/macros/PreferredCrates.html#field.numeric" +)] +pub fn warn_on_ambiguous_inferred_numeric_crate() {} From 65ef27f70c02411f0ba1bc03d24709e546e176fb Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 20 Sep 2024 00:46:43 -0700 Subject: [PATCH 09/30] feat: make `sqlx-cli` aware of `database-url-var` --- sqlx-cli/src/database.rs | 8 ++-- sqlx-cli/src/lib.rs | 74 +++++++++++++++++++++++++++++-------- sqlx-cli/src/opt.rs | 47 +++++++++++++++++++---- sqlx-core/src/config/mod.rs | 66 ++++++++++++++++++++------------- src/lib.rs | 6 +++ 5 files changed, 148 insertions(+), 53 deletions(-) diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 7a2056ab35..53834c111e 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -17,14 +17,14 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { std::sync::atomic::Ordering::Release, ); - Any::create_database(connect_opts.required_db_url()?).await?; + Any::create_database(connect_opts.expect_db_url()?).await?; } Ok(()) } pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> anyhow::Result<()> { - if 
confirm && !ask_to_continue_drop(connect_opts.required_db_url()?) { + if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?) { return Ok(()); } @@ -34,9 +34,9 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any if exists { if force { - Any::force_drop_database(connect_opts.required_db_url()?).await?; + Any::force_drop_database(connect_opts.expect_db_url()?).await?; } else { - Any::drop_database(connect_opts.required_db_url()?).await?; + Any::drop_database(connect_opts.expect_db_url()?).await?; } } diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index bfd71e4bc1..d08b949c78 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -1,7 +1,8 @@ use std::io; +use std::path::{Path, PathBuf}; use std::time::Duration; -use anyhow::Result; +use anyhow::{Context, Result}; use futures::{Future, TryFutureExt}; use sqlx::{AnyConnection, Connection}; @@ -20,7 +21,12 @@ mod prepare; pub use crate::opt::Opt; +pub use sqlx::_unstable::config; +use crate::config::Config; + pub async fn run(opt: Opt) -> Result<()> { + let config = config_from_current_dir().await?; + match opt.command { Command::Migrate(migrate) => match migrate.command { MigrateCommand::Add { @@ -34,9 +40,11 @@ pub async fn run(opt: Opt) -> Result<()> { source, dry_run, ignore_missing, - connect_opts, + mut connect_opts, target_version, } => { + connect_opts.populate_db_url(config)?; + migrate::run( &source, &connect_opts, @@ -50,9 +58,11 @@ pub async fn run(opt: Opt) -> Result<()> { source, dry_run, ignore_missing, - connect_opts, + mut connect_opts, target_version, } => { + connect_opts.populate_db_url(config)?; + migrate::revert( &source, &connect_opts, @@ -64,37 +74,56 @@ pub async fn run(opt: Opt) -> Result<()> { } MigrateCommand::Info { source, - connect_opts, - } => migrate::info(&source, &connect_opts).await?, + mut connect_opts, + } => { + connect_opts.populate_db_url(config)?; + + migrate::info(&source, &connect_opts).await? 
+ }, MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?, }, Command::Database(database) => match database.command { - DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?, + DatabaseCommand::Create { mut connect_opts } => { + connect_opts.populate_db_url(config)?; + database::create(&connect_opts).await? + }, DatabaseCommand::Drop { confirmation, - connect_opts, + mut connect_opts, force, - } => database::drop(&connect_opts, !confirmation.yes, force).await?, + } => { + connect_opts.populate_db_url(config)?; + database::drop(&connect_opts, !confirmation.yes, force).await? + }, DatabaseCommand::Reset { confirmation, source, - connect_opts, + mut connect_opts, force, - } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?, + } => { + connect_opts.populate_db_url(config)?; + database::reset(&source, &connect_opts, !confirmation.yes, force).await? + }, DatabaseCommand::Setup { source, - connect_opts, - } => database::setup(&source, &connect_opts).await?, + mut connect_opts, + } => { + connect_opts.populate_db_url(config)?; + database::setup(&source, &connect_opts).await? + }, }, Command::Prepare { check, all, workspace, - connect_opts, + mut connect_opts, args, - } => prepare::run(check, all, workspace, connect_opts, args).await?, + } => { + connect_opts.populate_db_url(config)?; + prepare::run(check, all, workspace, connect_opts, args).await? 
+ }, #[cfg(feature = "completions")] Command::Completions { shell } => completions::run(shell), @@ -122,7 +151,7 @@ where { sqlx::any::install_default_drivers(); - let db_url = opts.required_db_url()?; + let db_url = opts.expect_db_url()?; backoff::future::retry( backoff::ExponentialBackoffBuilder::new() @@ -147,3 +176,18 @@ where ) .await } + +async fn config_from_current_dir() -> anyhow::Result<&'static Config> { + // Tokio does file I/O on a background task anyway + tokio::task::spawn_blocking(|| { + let path = PathBuf::from("sqlx.toml"); + + if path.exists() { + eprintln!("Found `sqlx.toml` in current directory; reading..."); + } + + Config::read_with_or_default(move || Ok(path)) + }) + .await + .context("unexpected error loading config") +} diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index d5fe315234..a37fda181d 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,8 +1,10 @@ +use std::env; use std::ops::{Deref, Not}; - +use anyhow::Context; use clap::{Args, Parser}; #[cfg(feature = "completions")] use clap_complete::Shell; +use sqlx::config::Config; #[derive(Parser, Debug)] #[clap(version, about, author)] @@ -242,7 +244,7 @@ impl Deref for Source { #[derive(Args, Debug)] pub struct ConnectOpts { /// Location of the DB, by default will be read from the DATABASE_URL env var or `.env` files. - #[clap(long, short = 'D', env)] + #[clap(long, short = 'D')] pub database_url: Option, /// The maximum time, in seconds, to try connecting to the database server before @@ -266,12 +268,41 @@ pub struct ConnectOpts { impl ConnectOpts { /// Require a database URL to be provided, otherwise /// return an error. 
- pub fn required_db_url(&self) -> anyhow::Result<&str> { - self.database_url.as_deref().ok_or_else( - || anyhow::anyhow!( - "the `--database-url` option or the `DATABASE_URL` environment variable must be provided" - ) - ) + pub fn expect_db_url(&self) -> anyhow::Result<&str> { + self.database_url.as_deref().context("BUG: database_url not populated") + } + + /// Populate `database_url` from the environment, if not set. + pub fn populate_db_url(&mut self, config: &Config) -> anyhow::Result<()> { + if self.database_url.is_some() { + return Ok(()); + } + + let var = config.common.database_url_var(); + + let context = if var != "DATABASE_URL" { + " (`common.database-url-var` in `sqlx.toml`)" + } else { + "" + }; + + match env::var(var) { + Ok(url) => { + if !context.is_empty() { + eprintln!("Read database url from `{var}`{context}"); + } + + self.database_url = Some(url) + }, + Err(env::VarError::NotPresent) => { + anyhow::bail!("`--database-url` or `{var}`{context} must be set") + } + Err(env::VarError::NotUnicode(_)) => { + anyhow::bail!("`{var}`{context} is not valid UTF-8"); + } + } + + Ok(()) } } diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index b3afd9ea1b..02bde20f73 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -152,25 +152,7 @@ impl Config { /// ### Panics /// If the file exists but an unrecoverable error was encountered while parsing it. pub fn from_crate() -> &'static Self { - Self::try_from_crate().unwrap_or_else(|e| { - match e { - ConfigError::NotFound { path } => { - // Non-fatal - tracing::debug!("Not reading config, file {path:?} not found"); - CACHE.get_or_init(Config::default) - } - // FATAL ERRORS BELOW: - // In the case of migrations, - // we can't proceed with defaults as they may be completely wrong. - e @ ConfigError::ParseDisabled { .. } => { - // Only returned if the file exists but the feature is not enabled. 
- panic!("{e}") - } - e => { - panic!("failed to read sqlx config: {e}") - } - } - }) + Self::read_with_or_default(get_crate_path) } /// Get the cached config, or to read `$CARGO_MANIFEST_DIR/sqlx.toml`. @@ -179,11 +161,7 @@ impl Config { /// /// Errors if `CARGO_MANIFEST_DIR` is not set, or if the config file could not be read. pub fn try_from_crate() -> Result<&'static Self, ConfigError> { - Self::try_get_with(|| { - let mut path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR")?); - path.push("sqlx.toml"); - Ok(path) - }) + Self::try_read_with(get_crate_path) } /// Get the cached config, or attempt to read `sqlx.toml` from the current working directory. @@ -192,7 +170,7 @@ impl Config { /// /// Errors if the config file does not exist, or could not be read. pub fn try_from_current_dir() -> Result<&'static Self, ConfigError> { - Self::try_get_with(|| Ok("sqlx.toml".into())) + Self::try_read_with(|| Ok("sqlx.toml".into())) } /// Get the cached config, or attempt to read it from the path returned by the closure. @@ -200,7 +178,7 @@ impl Config { /// On success, the config is cached in a `static` and returned by future calls. /// /// Errors if the config file does not exist, or could not be read. - pub fn try_get_with( + pub fn try_read_with( make_path: impl FnOnce() -> Result, ) -> Result<&'static Self, ConfigError> { CACHE.get_or_try_init(|| { @@ -209,6 +187,36 @@ impl Config { }) } + /// Get the cached config, or attempt to read it from the path returned by the closure. + /// + /// On success, the config is cached in a `static` and returned by future calls. + /// + /// Returns `Config::default()` if the file does not exist. 
+ pub fn read_with_or_default( + make_path: impl FnOnce() -> Result, + ) -> &'static Self { + CACHE.get_or_init(|| { + match make_path().and_then(Self::read_from) { + Ok(config) => config, + Err(ConfigError::NotFound { path }) => { + // Non-fatal + tracing::debug!("Not reading config, file {path:?} not found"); + Config::default() + } + // FATAL ERRORS BELOW: + // In the case of migrations, + // we can't proceed with defaults as they may be completely wrong. + Err(e @ ConfigError::ParseDisabled { .. }) => { + // Only returned if the file exists but the feature is not enabled. + panic!("{e}") + } + Err(e) => { + panic!("failed to read sqlx config: {e}") + } + } + }) + } + #[cfg(feature = "sqlx-toml")] fn read_from(path: PathBuf) -> Result { // The `toml` crate doesn't provide an incremental reader. @@ -238,3 +246,9 @@ impl Config { } } } + +fn get_crate_path() -> Result { + let mut path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR")?); + path.push("sqlx.toml"); + Ok(path) +} diff --git a/src/lib.rs b/src/lib.rs index 7c10b85213..aaa0e81952 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,6 +169,12 @@ pub mod prelude { #[cfg(feature = "_unstable-doc")] pub use sqlx_core::config; +// NOTE: APIs exported in this module are SemVer-exempt. 
+#[doc(hidden)] +pub mod _unstable { + pub use sqlx_core::config; +} + #[doc(hidden)] #[cfg_attr( all(feature = "chrono", feature = "time"), From 9d1bc64cedd22115e60479b85c8db9b80086aba9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 23 Sep 2024 02:06:46 -0700 Subject: [PATCH 10/30] feat: teach macros about `migrate.table-name`, `migrations-dir` --- sqlx-macros-core/src/migrate.rs | 39 ++++++++++++++++++++++--------- sqlx-macros-core/src/test_attr.rs | 14 +++++++---- sqlx-macros/src/lib.rs | 2 +- src/macros/mod.rs | 2 +- 4 files changed, 39 insertions(+), 18 deletions(-) diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index c9cf5b8eb1..56ac61405f 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -3,12 +3,15 @@ extern crate proc_macro; use std::path::{Path, PathBuf}; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::LitStr; - +use syn::spanned::Spanned; +use sqlx_core::config::Config; use sqlx_core::migrate::{Migration, MigrationType}; +pub const DEFAULT_PATH: &str = "./migrations"; + pub struct QuoteMigrationType(MigrationType); impl ToTokens for QuoteMigrationType { @@ -81,20 +84,26 @@ impl ToTokens for QuoteMigration { } } -pub fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result { - expand_migrator_from_dir(&dir.value(), dir.span()) +pub fn default_path(config: &Config) -> &str { + config.migrate.migrations_dir + .as_deref() + .unwrap_or(DEFAULT_PATH) } -pub(crate) fn expand_migrator_from_dir( - dir: &str, - err_span: proc_macro2::Span, -) -> crate::Result { - let path = crate::common::resolve_path(dir, err_span)?; +pub fn expand(path_arg: Option) -> crate::Result { + let config = Config::from_crate(); - expand_migrator(&path) + let path = match path_arg { + Some(path_arg) => crate::common::resolve_path(path_arg.value(), path_arg.span())?, + None => { + crate::common::resolve_path(default_path(config), 
Span::call_site()) + }? + }; + + expand_with_path(config, &path) } -pub(crate) fn expand_migrator(path: &Path) -> crate::Result { +pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result { let path = path.canonicalize().map_err(|e| { format!( "error canonicalizing migration directory {}: {e}", @@ -119,11 +128,19 @@ pub(crate) fn expand_migrator(path: &Path) -> crate::Result { proc_macro::tracked_path::path(path); } + let table_name = config.migrate.table_name + .as_deref() + .map_or_else( + || quote! {}, + |name| quote! { table_name: Some(::std::borrow::Cow::Borrowed(#name)), } + ); + Ok(quote! { ::sqlx::migrate::Migrator { migrations: ::std::borrow::Cow::Borrowed(&[ #(#migrations),* ]), + #table_name ..::sqlx::migrate::Migrator::DEFAULT } }) diff --git a/sqlx-macros-core/src/test_attr.rs b/sqlx-macros-core/src/test_attr.rs index d7c6eb0486..403b6e7de3 100644 --- a/sqlx-macros-core/src/test_attr.rs +++ b/sqlx-macros-core/src/test_attr.rs @@ -77,6 +77,8 @@ fn expand_simple(input: syn::ItemFn) -> TokenStream { #[cfg(feature = "migrate")] fn expand_advanced(args: AttributeArgs, input: syn::ItemFn) -> crate::Result { + let config = sqlx_core::config::Config::from_crate(); + let ret = &input.sig.output; let name = &input.sig.ident; let inputs = &input.sig.inputs; @@ -143,15 +145,17 @@ fn expand_advanced(args: AttributeArgs, input: syn::ItemFn) -> crate::Result { - let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?; + let migrator = crate::migrate::expand(Some(path))?; quote! 
{ args.migrator(&#migrator); } } MigrationsOpt::InferredPath if !inputs.is_empty() => { - let migrations_path = - crate::common::resolve_path("./migrations", proc_macro2::Span::call_site())?; + let path = crate::migrate::default_path(config); + + let resolved_path = + crate::common::resolve_path(path, proc_macro2::Span::call_site())?; - if migrations_path.is_dir() { - let migrator = crate::migrate::expand_migrator(&migrations_path)?; + if resolved_path.is_dir() { + let migrator = crate::migrate::expand_with_path(config, &resolved_path)?; quote! { args.migrator(&#migrator); } } else { quote! {} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 987794acbc..f527f5d2fd 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -69,7 +69,7 @@ pub fn migrate(input: TokenStream) -> TokenStream { use syn::LitStr; let input = syn::parse_macro_input!(input as LitStr); - match migrate::expand_migrator_from_lit_dir(input) { + match migrate::expand(input) { Ok(ts) => ts.into(), Err(e) => { if let Some(parse_err) = e.downcast_ref::() { diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 7f8ff747f9..c9602b55c5 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -814,6 +814,6 @@ macro_rules! 
migrate { }}; () => {{ - $crate::sqlx_macros::migrate!("./migrations") + $crate::sqlx_macros::migrate!() }}; } From ba7740d8e5dc125028409d7cdc174ae87e47e8e0 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 23 Sep 2024 02:15:14 -0700 Subject: [PATCH 11/30] feat: teach macros about `migrate.ignored-chars` --- sqlx-macros-core/src/migrate.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index 56ac61405f..976cb181bb 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -8,7 +8,7 @@ use quote::{quote, ToTokens, TokenStreamExt}; use syn::LitStr; use syn::spanned::Spanned; use sqlx_core::config::Config; -use sqlx_core::migrate::{Migration, MigrationType}; +use sqlx_core::migrate::{Migration, MigrationType, ResolveConfig}; pub const DEFAULT_PATH: &str = "./migrations"; @@ -111,8 +111,11 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result Date: Sat, 5 Oct 2024 14:03:25 -0700 Subject: [PATCH 12/30] chore: delete unused source file `sqlx-cli/src/migration.rs` --- sqlx-cli/src/migration.rs | 187 -------------------------------------- 1 file changed, 187 deletions(-) delete mode 100644 sqlx-cli/src/migration.rs diff --git a/sqlx-cli/src/migration.rs b/sqlx-cli/src/migration.rs deleted file mode 100644 index 2ed8f94495..0000000000 --- a/sqlx-cli/src/migration.rs +++ /dev/null @@ -1,187 +0,0 @@ -use anyhow::{bail, Context}; -use console::style; -use std::fs::{self, File}; -use std::io::{Read, Write}; - -const MIGRATION_FOLDER: &str = "migrations"; - -pub struct Migration { - pub name: String, - pub sql: String, -} - -pub fn add_file(name: &str) -> anyhow::Result<()> { - use chrono::prelude::*; - use std::path::PathBuf; - - fs::create_dir_all(MIGRATION_FOLDER).context("Unable to create migrations directory")?; - - let dt = Utc::now(); - let mut file_name = dt.format("%Y-%m-%d_%H-%M-%S").to_string(); - file_name.push_str("_"); - 
file_name.push_str(name); - file_name.push_str(".sql"); - - let mut path = PathBuf::new(); - path.push(MIGRATION_FOLDER); - path.push(&file_name); - - let mut file = File::create(path).context("Failed to create file")?; - file.write_all(b"-- Add migration script here") - .context("Could not write to file")?; - - println!("Created migration: '{file_name}'"); - Ok(()) -} - -pub async fn run() -> anyhow::Result<()> { - let migrator = crate::migrator::get()?; - - if !migrator.can_migrate_database() { - bail!( - "Database migrations not supported for {}", - migrator.database_type() - ); - } - - migrator.create_migration_table().await?; - - let migrations = load_migrations()?; - - for mig in migrations.iter() { - let mut tx = migrator.begin_migration().await?; - - if tx.check_if_applied(&mig.name).await? { - println!("Already applied migration: '{}'", mig.name); - continue; - } - println!("Applying migration: '{}'", mig.name); - - tx.execute_migration(&mig.sql) - .await - .with_context(|| format!("Failed to run migration {:?}", &mig.name))?; - - tx.save_applied_migration(&mig.name) - .await - .context("Failed to insert migration")?; - - tx.commit().await.context("Failed")?; - } - - Ok(()) -} - -pub async fn list() -> anyhow::Result<()> { - let migrator = crate::migrator::get()?; - - if !migrator.can_migrate_database() { - bail!( - "Database migrations not supported for {}", - migrator.database_type() - ); - } - - let file_migrations = load_migrations()?; - - if migrator - .check_if_database_exists(&migrator.get_database_name()?) - .await? 
- { - let applied_migrations = migrator.get_migrations().await.unwrap_or_else(|_| { - println!("Could not retrieve data from migration table"); - Vec::new() - }); - - let mut width = 0; - for mig in file_migrations.iter() { - width = std::cmp::max(width, mig.name.len()); - } - for mig in file_migrations.iter() { - let status = if applied_migrations - .iter() - .find(|&m| mig.name == *m) - .is_some() - { - style("Applied").green() - } else { - style("Not Applied").yellow() - }; - - println!("{:width$}\t{}", mig.name, status, width = width); - } - - let orphans = check_for_orphans(file_migrations, applied_migrations); - - if let Some(orphans) = orphans { - println!("\nFound migrations applied in the database that does not have a corresponding migration file:"); - for name in orphans { - println!("{:width$}\t{}", name, style("Orphan").red(), width = width); - } - } - } else { - println!("No database found, listing migrations"); - - for mig in file_migrations { - println!("{}", mig.name); - } - } - - Ok(()) -} - -fn load_migrations() -> anyhow::Result> { - let entries = fs::read_dir(&MIGRATION_FOLDER).context("Could not find 'migrations' dir")?; - - let mut migrations = Vec::new(); - - for e in entries { - if let Ok(e) = e { - if let Ok(meta) = e.metadata() { - if !meta.is_file() { - continue; - } - - if let Some(ext) = e.path().extension() { - if ext != "sql" { - println!("Wrong ext: {ext:?}"); - continue; - } - } else { - continue; - } - - let mut file = File::open(e.path()) - .with_context(|| format!("Failed to open: '{:?}'", e.file_name()))?; - let mut contents = String::new(); - file.read_to_string(&mut contents) - .with_context(|| format!("Failed to read: '{:?}'", e.file_name()))?; - - migrations.push(Migration { - name: e.file_name().to_str().unwrap().to_string(), - sql: contents, - }); - } - } - } - - migrations.sort_by(|a, b| a.name.partial_cmp(&b.name).unwrap()); - - Ok(migrations) -} - -fn check_for_orphans( - file_migrations: Vec, - applied_migrations: Vec, 
-) -> Option> { - let orphans: Vec = applied_migrations - .iter() - .filter(|m| !file_migrations.iter().any(|fm| fm.name == **m)) - .cloned() - .collect(); - - if orphans.len() > 0 { - Some(orphans) - } else { - None - } -} From 367f2cca98d376adbd5ead2dbc9b85684e1ce067 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 5 Oct 2024 15:21:32 -0700 Subject: [PATCH 13/30] feat: teach `sqlx-cli` about `migrate.defaults` --- sqlx-cli/src/lib.rs | 15 +- sqlx-cli/src/migrate.rs | 110 +++----------- sqlx-cli/src/opt.rs | 182 +++++++++++++++++++++--- sqlx-core/src/migrate/migration_type.rs | 3 +- 4 files changed, 186 insertions(+), 124 deletions(-) diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index d08b949c78..5d1269f34e 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -1,5 +1,5 @@ use std::io; -use std::path::{Path, PathBuf}; +use std::path::{PathBuf}; use std::time::Duration; use anyhow::{Context, Result}; @@ -21,21 +21,14 @@ mod prepare; pub use crate::opt::Opt; -pub use sqlx::_unstable::config; -use crate::config::Config; +pub use sqlx::_unstable::config::{self, Config}; pub async fn run(opt: Opt) -> Result<()> { - let config = config_from_current_dir()?; + let config = config_from_current_dir().await?; match opt.command { Command::Migrate(migrate) => match migrate.command { - MigrateCommand::Add { - source, - description, - reversible, - sequential, - timestamp, - } => migrate::add(&source, &description, reversible, sequential, timestamp).await?, + MigrateCommand::Add(opts)=> migrate::add(config, opts).await?, MigrateCommand::Run { source, dry_run, diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index e00f6de651..76ad7dfb97 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -1,6 +1,5 @@ -use crate::opt::ConnectOpts; +use crate::opt::{AddMigrationOpts, ConnectOpts}; use anyhow::{bail, Context}; -use chrono::Utc; use console::style; use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, 
Migrator}; use sqlx::Connection; @@ -10,6 +9,7 @@ use std::fmt::Write; use std::fs::{self, File}; use std::path::Path; use std::time::Duration; +use crate::config::Config; fn create_file( migration_source: &str, @@ -37,116 +37,46 @@ fn create_file( Ok(()) } -enum MigrationOrdering { - Timestamp(String), - Sequential(String), -} - -impl MigrationOrdering { - fn timestamp() -> MigrationOrdering { - Self::Timestamp(Utc::now().format("%Y%m%d%H%M%S").to_string()) - } - - fn sequential(version: i64) -> MigrationOrdering { - Self::Sequential(format!("{version:04}")) - } - - fn file_prefix(&self) -> &str { - match self { - MigrationOrdering::Timestamp(prefix) => prefix, - MigrationOrdering::Sequential(prefix) => prefix, - } - } - - fn infer(sequential: bool, timestamp: bool, migrator: &Migrator) -> Self { - match (timestamp, sequential) { - (true, true) => panic!("Impossible to specify both timestamp and sequential mode"), - (true, false) => MigrationOrdering::timestamp(), - (false, true) => MigrationOrdering::sequential( - migrator - .iter() - .last() - .map_or(1, |last_migration| last_migration.version + 1), - ), - (false, false) => { - // inferring the naming scheme - let migrations = migrator - .iter() - .filter(|migration| migration.migration_type.is_up_migration()) - .rev() - .take(2) - .collect::>(); - if let [last, pre_last] = &migrations[..] { - // there are at least two migrations, compare the last twothere's only one existing migration - if last.version - pre_last.version == 1 { - // their version numbers differ by 1, infer sequential - MigrationOrdering::sequential(last.version + 1) - } else { - MigrationOrdering::timestamp() - } - } else if let [last] = &migrations[..] 
{ - // there is only one existing migration - if last.version == 0 || last.version == 1 { - // infer sequential if the version number is 0 or 1 - MigrationOrdering::sequential(last.version + 1) - } else { - MigrationOrdering::timestamp() - } - } else { - MigrationOrdering::timestamp() - } - } - } - } -} - pub async fn add( - migration_source: &str, - description: &str, - reversible: bool, - sequential: bool, - timestamp: bool, + config: &Config, + opts: AddMigrationOpts, ) -> anyhow::Result<()> { - fs::create_dir_all(migration_source).context("Unable to create migrations directory")?; + fs::create_dir_all(&opts.source).context("Unable to create migrations directory")?; - let migrator = Migrator::new(Path::new(migration_source)).await?; - // Type of newly created migration will be the same as the first one - // or reversible flag if this is the first migration - let migration_type = MigrationType::infer(&migrator, reversible); + let migrator = Migrator::new(opts.source.as_ref()).await?; - let ordering = MigrationOrdering::infer(sequential, timestamp, &migrator); - let file_prefix = ordering.file_prefix(); + let version_prefix = opts.version_prefix(config, &migrator); - if migration_type.is_reversible() { + if opts.reversible(config, &migrator) { create_file( - migration_source, - file_prefix, - description, + &opts.source, + &version_prefix, + &opts.description, MigrationType::ReversibleUp, )?; create_file( - migration_source, - file_prefix, - description, + &opts.source, + &version_prefix, + &opts.description, MigrationType::ReversibleDown, )?; } else { create_file( - migration_source, - file_prefix, - description, + &opts.source, + &version_prefix, + &opts.description, MigrationType::Simple, )?; } // if the migrations directory is empty - let has_existing_migrations = fs::read_dir(migration_source) + let has_existing_migrations = fs::read_dir(&opts.source) .map(|mut dir| dir.next().is_some()) .unwrap_or(false); if !has_existing_migrations { - let quoted_source = 
if migration_source != "migrations" { - format!("{migration_source:?}") + let quoted_source = if *opts.source != "migrations" { + format!("{:?}", *opts.source) } else { "".to_string() }; diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index a37fda181d..6200f4dbb2 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,10 +1,14 @@ use std::env; use std::ops::{Deref, Not}; +use std::path::Path; use anyhow::Context; +use chrono::Utc; use clap::{Args, Parser}; #[cfg(feature = "completions")] use clap_complete::Shell; -use sqlx::config::Config; +use crate::config::Config; +use sqlx::migrate::Migrator; +use crate::config::migrate::{DefaultMigrationType, DefaultVersioning}; #[derive(Parser, Debug)] #[clap(version, about, author)] @@ -125,8 +129,55 @@ pub struct MigrateOpt { pub enum MigrateCommand { /// Create a new migration with the given description. /// + /// -------------------------------- + /// + /// Migrations may either be simple, or reversible. + /// + /// Reversible migrations can be reverted with `sqlx migrate revert`, simple migrations cannot. + /// + /// Reversible migrations are created as a pair of two files with the same filename but + /// extensions `.up.sql` and `.down.sql` for the up-migration and down-migration, respectively. + /// + /// The up-migration should contain the commands to be used when applying the migration, + /// while the down-migration should contain the commands to reverse the changes made by the + /// up-migration. + /// + /// When writing down-migrations, care should be taken to ensure that they + /// do not leave the database in an inconsistent state. + /// + /// Simple migrations have just `.sql` for their extension and represent an up-migration only. + /// + /// Note that reverting a migration is **destructive** and will likely result in data loss. + /// Reverting a migration will not restore any data discarded by commands in the up-migration. 
+ /// + /// It is recommended to always back up the database before running migrations. + /// + /// -------------------------------- + /// + /// For convenience, this command attempts to detect if reversible migrations are in-use. + /// + /// If the latest existing migration is reversible, the new migration will also be reversible. + /// + /// Otherwise, a simple migration is created. + /// + /// This behavior can be overridden by `--simple` or `--reversible`, respectively. + /// + /// The default type to use can also be set in `sqlx.toml`. + /// + /// -------------------------------- + /// /// A version number will be automatically assigned to the migration. /// + /// Migrations are applied in ascending order by version number. + /// Version numbers do not need to be strictly consecutive. + /// + /// The migration process will abort if SQLx encounters a migration with a version number + /// less than _any_ previously applied migration. + /// + /// Migrations should only be created with increasing version number. + /// + /// -------------------------------- + /// /// For convenience, this command will attempt to detect if sequential versioning is in use, /// and if so, continue the sequence. /// @@ -136,28 +187,12 @@ pub enum MigrateCommand { /// /// * only one migration exists and its version number is either 0 or 1. /// - /// Otherwise timestamp versioning is assumed. + /// Otherwise, timestamp versioning (`YYYYMMDDHHMMSS`) is assumed. /// - /// This behavior can overridden by `--sequential` or `--timestamp`, respectively. - Add { - description: String, - - #[clap(flatten)] - source: Source, - - /// If true, creates a pair of up and down migration files with same version - /// else creates a single sql file - #[clap(short)] - reversible: bool, - - /// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`. - #[clap(short, long)] - timestamp: bool, - - /// If set, use sequential versioning for the new migration. 
Conflicts with `--timestamp`. - #[clap(short, long, conflicts_with = "timestamp")] - sequential: bool, - }, + /// This behavior can be overridden by `--timestamp` or `--sequential`, respectively. + /// + /// The default versioning to use can also be set in `sqlx.toml`. + Add(AddMigrationOpts), /// Run all pending migrations. Run { @@ -224,6 +259,34 @@ pub enum MigrateCommand { }, } +#[derive(Args, Debug)] +pub struct AddMigrationOpts { + pub description: String, + + #[clap(flatten)] + pub source: Source, + + /// If set, create an up-migration only. Conflicts with `--reversible`. + #[clap(long, conflicts_with = "reversible")] + simple: bool, + + /// If set, create a pair of up and down migration files with same version. + /// + /// Conflicts with `--simple`. + #[clap(short, long, conflicts_with = "simple")] + reversible: bool, + + /// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`. + /// + /// Timestamp format: `YYYYMMDDHHMMSS` + #[clap(short, long, conflicts_with = "sequential")] + timestamp: bool, + + /// If set, use sequential versioning for the new migration. Conflicts with `--timestamp`. + #[clap(short, long, conflicts_with = "timestamp")] + sequential: bool, +} + /// Argument for the migration scripts source. #[derive(Args, Debug)] pub struct Source { @@ -240,6 +303,12 @@ impl Deref for Source { } } +impl AsRef for Source { + fn as_ref(&self) -> &Path { + Path::new(&self.source) + } +} + /// Argument for the database URL. 
#[derive(Args, Debug)] pub struct ConnectOpts { @@ -338,3 +407,72 @@ impl Not for IgnoreMissing { !self.ignore_missing } } + +impl AddMigrationOpts { + pub fn reversible(&self, config: &Config, migrator: &Migrator) -> bool { + if self.reversible { return true; } + if self.simple { return false; } + + match config.migrate.defaults.migration_type { + DefaultMigrationType::Inferred => { + migrator + .iter() + .last() + .is_some_and(|m| m.migration_type.is_reversible()) + } + DefaultMigrationType::Simple => { + false + } + DefaultMigrationType::Reversible => { + true + } + } + } + + pub fn version_prefix(&self, config: &Config, migrator: &Migrator) -> String { + let default_versioning = &config.migrate.defaults.migration_versioning; + + if self.timestamp || matches!(default_versioning, DefaultVersioning::Timestamp) { + return next_timestamp(); + } + + if self.sequential || matches!(default_versioning, DefaultVersioning::Sequential) { + return next_sequential(migrator) + .unwrap_or_else(|| fmt_sequential(1)); + } + + next_sequential(migrator).unwrap_or_else(next_timestamp) + } +} + +fn next_timestamp() -> String { + Utc::now().format("%Y%m%d%H%M%S").to_string() +} + +fn next_sequential(migrator: &Migrator) -> Option { + let next_version = migrator + .migrations + .windows(2) + .last() + .and_then(|migrations| { + match migrations { + [previous, latest] => { + // If the latest two versions differ by 1, infer sequential. 
+ (latest.version - previous.version == 1) + .then_some(latest.version + 1) + }, + [latest] => { + // If only one migration exists and its version is 0 or 1, infer sequential + matches!(latest.version, 0 | 1) + .then_some(latest.version + 1) + } + _ => unreachable!(), + } + }); + + next_version.map(fmt_sequential) +} + +fn fmt_sequential(version: i64) -> String { + format!("{version:04}") +} diff --git a/sqlx-core/src/migrate/migration_type.rs b/sqlx-core/src/migrate/migration_type.rs index de2b019307..350ddb3f27 100644 --- a/sqlx-core/src/migrate/migration_type.rs +++ b/sqlx-core/src/migrate/migration_type.rs @@ -74,8 +74,9 @@ impl MigrationType { } } + #[deprecated = "unused"] pub fn infer(migrator: &Migrator, reversible: bool) -> MigrationType { - match migrator.iter().next() { + match migrator.iter().last() { Some(first_migration) => first_migration.migration_type, None => { if reversible { From 1ff6a8a950565c668f415ec8f8806b907f62c7f9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 15 Jan 2025 10:31:03 -0800 Subject: [PATCH 14/30] feat: teach `sqlx-cli` about `migrate.migrations-dir` --- Cargo.lock | 65 +++++++++++++--- sqlx-cli/src/database.rs | 13 ++-- sqlx-cli/src/lib.rs | 10 ++- sqlx-cli/src/migrate.rs | 99 +++++++++++++----------- sqlx-cli/src/opt.rs | 41 +++++----- sqlx-core/src/config/migrate.rs | 15 +++- sqlx-macros-core/src/migrate.rs | 3 +- sqlx-macros/src/lib.rs | 2 +- sqlx-postgres/src/connection/describe.rs | 4 +- 9 files changed, 162 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 68f3c84cd8..0e0b156bdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1990,12 +1990,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - 
"hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -2774,7 +2774,7 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284" dependencies = [ - "toml_edit", + "toml_edit 0.21.1", ] [[package]] @@ -3306,6 +3306,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -3559,7 +3568,7 @@ dependencies = [ "futures-util", "hashbrown 0.15.2", "hashlink", - "indexmap 2.2.5", + "indexmap 2.7.0", "ipnetwork", "log", "mac_address", @@ -3581,6 +3590,7 @@ dependencies = [ "time", "tokio", "tokio-stream", + "toml", "tracing", "url", "uuid", @@ -4186,11 +4196,26 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit 0.22.22", +] + [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] [[package]] name = "toml_edit" @@ -4198,9 +4223,22 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ - "indexmap 2.2.5", + "indexmap 2.7.0", "toml_datetime", - "winnow", + "winnow 0.5.40", +] + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap 2.7.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow 0.6.22", ] [[package]] @@ -4791,6 +4829,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "0.6.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" +dependencies = [ + "memchr", +] + [[package]] name = "write16" version = "1.0.0" diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 53834c111e..1fd8bcc534 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,5 +1,5 @@ -use crate::migrate; -use crate::opt::ConnectOpts; +use crate::{migrate, Config}; +use crate::opt::{ConnectOpts, MigrationSourceOpt}; use console::style; use promptly::{prompt, ReadlineError}; use sqlx::any::Any; @@ -44,18 +44,19 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any } pub async fn reset( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, confirm: bool, force: bool, ) -> anyhow::Result<()> { drop(connect_opts, confirm, force).await?; - setup(migration_source, connect_opts).await + setup(config, migration_source, connect_opts).await } -pub async fn setup(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> { +pub async fn setup(config: &Config, migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts) -> anyhow::Result<()> { create(connect_opts).await?; - migrate::run(migration_source, connect_opts, false, false, None).await + migrate::run(config, migration_source, connect_opts, false, false, None).await } fn ask_to_continue_drop(db_url: &str) -> bool { diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index 5d1269f34e..63257f541f 100644 --- a/sqlx-cli/src/lib.rs +++ 
b/sqlx-cli/src/lib.rs @@ -39,6 +39,7 @@ pub async fn run(opt: Opt) -> Result<()> { connect_opts.populate_db_url(config)?; migrate::run( + config, &source, &connect_opts, dry_run, @@ -57,6 +58,7 @@ pub async fn run(opt: Opt) -> Result<()> { connect_opts.populate_db_url(config)?; migrate::revert( + config, &source, &connect_opts, dry_run, @@ -71,9 +73,9 @@ pub async fn run(opt: Opt) -> Result<()> { } => { connect_opts.populate_db_url(config)?; - migrate::info(&source, &connect_opts).await? + migrate::info(config, &source, &connect_opts).await? }, - MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?, + MigrateCommand::BuildScript { source, force } => migrate::build_script(config, &source, force)?, }, Command::Database(database) => match database.command { @@ -96,14 +98,14 @@ pub async fn run(opt: Opt) -> Result<()> { force, } => { connect_opts.populate_db_url(config)?; - database::reset(&source, &connect_opts, !confirmation.yes, force).await? + database::reset(config, &source, &connect_opts, !confirmation.yes, force).await? }, DatabaseCommand::Setup { source, mut connect_opts, } => { connect_opts.populate_db_url(config)?; - database::setup(&source, &connect_opts).await? + database::setup(config, &source, &connect_opts).await? 
}, }, diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index 76ad7dfb97..aabee2928f 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -1,7 +1,7 @@ -use crate::opt::{AddMigrationOpts, ConnectOpts}; +use crate::opt::{AddMigrationOpts, ConnectOpts, MigrationSourceOpt}; use anyhow::{bail, Context}; use console::style; -use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator}; +use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator, ResolveWith}; use sqlx::Connection; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -11,58 +11,34 @@ use std::path::Path; use std::time::Duration; use crate::config::Config; -fn create_file( - migration_source: &str, - file_prefix: &str, - description: &str, - migration_type: MigrationType, -) -> anyhow::Result<()> { - use std::path::PathBuf; - - let mut file_name = file_prefix.to_string(); - file_name.push('_'); - file_name.push_str(&description.replace(' ', "_")); - file_name.push_str(migration_type.suffix()); - - let mut path = PathBuf::new(); - path.push(migration_source); - path.push(&file_name); - - println!("Creating {}", style(path.display()).cyan()); - - let mut file = File::create(&path).context("Failed to create migration file")?; - - std::io::Write::write_all(&mut file, migration_type.file_content().as_bytes())?; - - Ok(()) -} - pub async fn add( config: &Config, opts: AddMigrationOpts, ) -> anyhow::Result<()> { - fs::create_dir_all(&opts.source).context("Unable to create migrations directory")?; + let source = opts.source.resolve(config); + + fs::create_dir_all(source).context("Unable to create migrations directory")?; - let migrator = Migrator::new(opts.source.as_ref()).await?; + let migrator = Migrator::new(Path::new(source)).await?; let version_prefix = opts.version_prefix(config, &migrator); if opts.reversible(config, &migrator) { create_file( - &opts.source, + source, &version_prefix, &opts.description, 
MigrationType::ReversibleUp, )?; create_file( - &opts.source, + source, &version_prefix, &opts.description, MigrationType::ReversibleDown, )?; } else { create_file( - &opts.source, + source, &version_prefix, &opts.description, MigrationType::Simple, @@ -70,13 +46,13 @@ pub async fn add( } // if the migrations directory is empty - let has_existing_migrations = fs::read_dir(&opts.source) + let has_existing_migrations = fs::read_dir(source) .map(|mut dir| dir.next().is_some()) .unwrap_or(false); if !has_existing_migrations { - let quoted_source = if *opts.source != "migrations" { - format!("{:?}", *opts.source) + let quoted_source = if opts.source.source.is_some() { + format!("{source:?}") } else { "".to_string() }; @@ -114,6 +90,32 @@ See: https://docs.rs/sqlx/{version}/sqlx/macro.migrate.html Ok(()) } +fn create_file( + migration_source: &str, + file_prefix: &str, + description: &str, + migration_type: MigrationType, +) -> anyhow::Result<()> { + use std::path::PathBuf; + + let mut file_name = file_prefix.to_string(); + file_name.push('_'); + file_name.push_str(&description.replace(' ', "_")); + file_name.push_str(migration_type.suffix()); + + let mut path = PathBuf::new(); + path.push(migration_source); + path.push(&file_name); + + println!("Creating {}", style(path.display()).cyan()); + + let mut file = File::create(&path).context("Failed to create migration file")?; + + std::io::Write::write_all(&mut file, migration_type.file_content().as_bytes())?; + + Ok(()) +} + fn short_checksum(checksum: &[u8]) -> String { let mut s = String::with_capacity(checksum.len() * 2); for b in checksum { @@ -122,8 +124,10 @@ fn short_checksum(checksum: &[u8]) -> String { s } -pub async fn info(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; +pub async fn info(config: &Config, migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts) -> anyhow::Result<()> { + let source = 
migration_source.resolve(config); + + let migrator = Migrator::new(ResolveWith(Path::new(source), config.migrate.to_resolve_config())).await?; let mut conn = crate::connect(connect_opts).await?; conn.ensure_migrations_table().await?; @@ -202,13 +206,16 @@ fn validate_applied_migrations( } pub async fn run( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, dry_run: bool, ignore_missing: bool, target_version: Option, ) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; + let source = migration_source.resolve(config); + + let migrator = Migrator::new(Path::new(source)).await?; if let Some(target_version) = target_version { if !migrator.version_exists(target_version) { bail!(MigrateError::VersionNotPresent(target_version)); @@ -295,13 +302,15 @@ pub async fn run( } pub async fn revert( - migration_source: &str, + config: &Config, + migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts, dry_run: bool, ignore_missing: bool, target_version: Option, ) -> anyhow::Result<()> { - let migrator = Migrator::new(Path::new(migration_source)).await?; + let source = migration_source.resolve(config); + let migrator = Migrator::new(Path::new(source)).await?; if let Some(target_version) = target_version { if target_version != 0 && !migrator.version_exists(target_version) { bail!(MigrateError::VersionNotPresent(target_version)); @@ -388,7 +397,9 @@ pub async fn revert( Ok(()) } -pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> { +pub fn build_script(config: &Config, migration_source: &MigrationSourceOpt, force: bool) -> anyhow::Result<()> { + let source = migration_source.resolve(config); + anyhow::ensure!( Path::new("Cargo.toml").exists(), "must be run in a Cargo project root" @@ -403,7 +414,7 @@ pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> { r#"// generated by `sqlx migrate build-script` fn main() {{ // 
trigger recompilation when a new migration is added - println!("cargo:rerun-if-changed={migration_source}"); + println!("cargo:rerun-if-changed={source}"); }} "#, ); diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 6200f4dbb2..0b72af6594 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,6 +1,5 @@ use std::env; use std::ops::{Deref, Not}; -use std::path::Path; use anyhow::Context; use chrono::Utc; use clap::{Args, Parser}; @@ -98,7 +97,7 @@ pub enum DatabaseCommand { confirmation: Confirmation, #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -111,7 +110,7 @@ pub enum DatabaseCommand { /// Creates the database specified in your DATABASE_URL and runs any pending migrations. Setup { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -197,7 +196,7 @@ pub enum MigrateCommand { /// Run all pending migrations. Run { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, /// List all the migrations to be run without applying #[clap(long)] @@ -218,7 +217,7 @@ pub enum MigrateCommand { /// Revert the latest migration with a down file. Revert { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, /// List the migration to be reverted without applying #[clap(long)] @@ -240,7 +239,7 @@ pub enum MigrateCommand { /// List all available migrations. Info { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, #[clap(flatten)] connect_opts: ConnectOpts, @@ -251,7 +250,7 @@ pub enum MigrateCommand { /// Must be run in a Cargo project root. BuildScript { #[clap(flatten)] - source: Source, + source: MigrationSourceOpt, /// Overwrite the build script if it already exists. #[clap(long)] @@ -264,7 +263,7 @@ pub struct AddMigrationOpts { pub description: String, #[clap(flatten)] - pub source: Source, + pub source: MigrationSourceOpt, /// If set, create an up-migration only. Conflicts with `--reversible`. 
#[clap(long, conflicts_with = "reversible")] @@ -289,23 +288,21 @@ pub struct AddMigrationOpts { /// Argument for the migration scripts source. #[derive(Args, Debug)] -pub struct Source { +pub struct MigrationSourceOpt { /// Path to folder containing migrations. - #[clap(long, default_value = "migrations")] - source: String, -} - -impl Deref for Source { - type Target = String; - - fn deref(&self) -> &Self::Target { - &self.source - } + /// + /// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`. + #[clap(long)] + pub source: Option, } -impl AsRef for Source { - fn as_ref(&self) -> &Path { - Path::new(&self.source) +impl MigrationSourceOpt { + pub fn resolve<'a>(&'a self, config: &'a Config) -> &'a str { + if let Some(source) = &self.source { + return source; + } + + config.migrate.migrations_dir() } } diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index 64529f9f02..666ed5bf92 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -85,7 +85,7 @@ pub struct Config { /// To make your migrations amenable to reformatting, you may wish to tell SQLx to ignore /// _all_ whitespace characters in migrations. /// - /// ##### Warning: Beware Syntatically Significant Whitespace! + /// ##### Warning: Beware Syntactically Significant Whitespace! /// If your migrations use string literals or quoted identifiers which contain whitespace, /// this configuration will cause the migration machinery to ignore some changes to these. /// This may result in a mismatch between the development and production versions of @@ -179,3 +179,16 @@ pub enum DefaultVersioning { /// Use sequential integers for migration versions. 
Sequential, } + +#[cfg(feature = "migrate")] +impl Config { + pub fn migrations_dir(&self) -> &str { + self.migrations_dir.as_deref().unwrap_or("migrations") + } + + pub fn to_resolve_config(&self) -> crate::migrate::ResolveConfig { + let mut config = crate::migrate::ResolveConfig::new(); + config.ignore_chars(self.ignored_chars.iter().copied()); + config + } +} \ No newline at end of file diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index 976cb181bb..0ae2eaebda 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -111,8 +111,7 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result TokenStream { pub fn migrate(input: TokenStream) -> TokenStream { use syn::LitStr; - let input = syn::parse_macro_input!(input as LitStr); + let input = syn::parse_macro_input!(input as Option); match migrate::expand(input) { Ok(ts) => ts.into(), Err(e) => { diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 4decdde5dd..5b6a2aa09c 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -209,7 +209,8 @@ impl PgConnection { should_fetch: bool, ) -> Result { if let Some(origin) = - self.cache_table_to_column_names + self.inner + .cache_table_to_column_names .get(&relation_id) .and_then(|table_columns| { let column_name = table_columns.columns.get(&attribute_no).cloned()?; @@ -245,6 +246,7 @@ impl PgConnection { }; let table_columns = self + .inner .cache_table_to_column_names .entry(relation_id) .or_insert_with(|| TableColumns { From 45c0b85b4c478f3aa44b8146cafa64dbed7307b1 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 22 Jan 2025 14:24:18 -0800 Subject: [PATCH 15/30] feat: teach `sqlx-cli` about `migrate.table-name` --- sqlx-cli/src/migrate.rs | 20 ++++----- sqlx-cli/tests/common/mod.rs | 5 ++- sqlx-core/src/any/migrate.rs | 32 +++++++------- sqlx-core/src/config/migrate.rs | 4 ++ 
sqlx-core/src/migrate/migrate.rs | 22 +++++----- sqlx-core/src/migrate/migrator.rs | 72 ++++++++++++++++++------------- sqlx-core/src/testing/mod.rs | 2 +- sqlx-mysql/src/migrate.rs | 59 +++++++++++++------------ sqlx-postgres/src/migrate.rs | 56 ++++++++++++------------ sqlx-sqlite/src/migrate.rs | 51 +++++++++++----------- 10 files changed, 172 insertions(+), 151 deletions(-) diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index aabee2928f..9e0119682e 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -130,10 +130,10 @@ pub async fn info(config: &Config, migration_source: &MigrationSourceOpt, connec let migrator = Migrator::new(ResolveWith(Path::new(source), config.migrate.to_resolve_config())).await?; let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + conn.ensure_migrations_table(config.migrate.table_name()).await?; let applied_migrations: HashMap<_, _> = conn - .list_applied_migrations() + .list_applied_migrations(config.migrate.table_name()) .await? .into_iter() .map(|m| (m.version, m)) @@ -224,14 +224,14 @@ pub async fn run( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + conn.ensure_migrations_table(config.migrate.table_name()).await?; - let version = conn.dirty_version().await?; + let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations = conn.list_applied_migrations(config.migrate.table_name()).await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -269,7 +269,7 @@ pub async fn run( let elapsed = if dry_run || skip { Duration::new(0, 0) } else { - conn.apply(migration).await? + conn.apply(config.migrate.table_name(), migration).await? 
}; let text = if skip { "Skipped" @@ -319,14 +319,14 @@ pub async fn revert( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table().await?; + conn.ensure_migrations_table(config.migrate.table_name()).await?; - let version = conn.dirty_version().await?; + let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations = conn.list_applied_migrations(config.migrate.table_name()).await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -360,7 +360,7 @@ pub async fn revert( let elapsed = if dry_run || skip { Duration::new(0, 0) } else { - conn.revert(migration).await? + conn.revert(config.migrate.table_name(), migration).await? }; let text = if skip { "Skipped" diff --git a/sqlx-cli/tests/common/mod.rs b/sqlx-cli/tests/common/mod.rs index 43c0dbc1e1..bb58554f33 100644 --- a/sqlx-cli/tests/common/mod.rs +++ b/sqlx-cli/tests/common/mod.rs @@ -6,10 +6,12 @@ use std::{ fs::remove_file, path::{Path, PathBuf}, }; +use sqlx::_unstable::config::Config; pub struct TestDatabase { file_path: PathBuf, migrations: String, + config: &'static Config, } impl TestDatabase { @@ -19,6 +21,7 @@ impl TestDatabase { let ret = Self { file_path, migrations: String::from(migrations_path.to_str().unwrap()), + config: Config::from_crate(), }; Command::cargo_bin("cargo-sqlx") .unwrap() @@ -77,7 +80,7 @@ impl TestDatabase { let mut conn = SqliteConnection::connect(&self.connection_string()) .await .unwrap(); - conn.list_applied_migrations() + conn.list_applied_migrations(self.config.migrate.table_name()) .await .unwrap() .iter() diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index cb4f72c340..b287ec45e5 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -44,18 +44,16 @@ impl 
MigrateDatabase for Any { } impl Migrate for AnyConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { - Box::pin(async { self.get_migrate()?.ensure_migrations_table().await }) + fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async { self.get_migrate()?.ensure_migrations_table(table_name).await }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { - Box::pin(async { self.get_migrate()?.dirty_version().await }) + fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async { self.get_migrate()?.dirty_version(table_name).await }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { - Box::pin(async { self.get_migrate()?.list_applied_migrations().await }) + fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async { self.get_migrate()?.list_applied_migrations(table_name).await }) } fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { @@ -66,17 +64,19 @@ impl Migrate for AnyConnection { Box::pin(async { self.get_migrate()?.unlock().await }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { - Box::pin(async { self.get_migrate()?.apply(migration).await }) + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async { self.get_migrate()?.apply(table_name, migration).await }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { - Box::pin(async { self.get_migrate()?.revert(migration).await }) + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async { self.get_migrate()?.revert(table_name, migration).await }) } } diff --git a/sqlx-core/src/config/migrate.rs 
b/sqlx-core/src/config/migrate.rs index 666ed5bf92..a70938b209 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -186,6 +186,10 @@ impl Config { self.migrations_dir.as_deref().unwrap_or("migrations") } + pub fn table_name(&self) -> &str { + self.table_name.as_deref().unwrap_or("_sqlx_migrations") + } + pub fn to_resolve_config(&self) -> crate::migrate::ResolveConfig { let mut config = crate::migrate::ResolveConfig::new(); config.ignore_chars(self.ignored_chars.iter().copied()); diff --git a/sqlx-core/src/migrate/migrate.rs b/sqlx-core/src/migrate/migrate.rs index 0e4448a9bd..2258f06f04 100644 --- a/sqlx-core/src/migrate/migrate.rs +++ b/sqlx-core/src/migrate/migrate.rs @@ -27,16 +27,14 @@ pub trait MigrateDatabase { pub trait Migrate { // ensure migrations table exists // will create or migrate it if needed - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>>; + fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>>; // Return the version on which the database is dirty or None otherwise. // "dirty" means there is a partially applied migration that failed. - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>>; + fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>>; // Return the ordered list of applied migrations - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>>; + fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>>; // Should acquire a database lock so that only one migration process // can run at a time. 
[`Migrate`] will call this function before applying @@ -50,16 +48,18 @@ pub trait Migrate { // run SQL from migration in a DDL transaction // insert new row to [_migrations] table on completion (success or failure) // returns the time taking to run the migration SQL - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result>; + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result>; // run a revert SQL from migration in a DDL transaction // deletes the row in [_migrations] table with specified migration version on completion (success or failure) // returns the time taking to run the migration SQL - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result>; + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result>; } diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 42cc3095f8..aa737ad304 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -27,25 +27,6 @@ pub struct Migrator { pub table_name: Cow<'static, str>, } -fn validate_applied_migrations( - applied_migrations: &[AppliedMigration], - migrator: &Migrator, -) -> Result<(), MigrateError> { - if migrator.ignore_missing { - return Ok(()); - } - - let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect(); - - for applied_migration in applied_migrations { - if !migrations.contains(&applied_migration.version) { - return Err(MigrateError::VersionMissing(applied_migration.version)); - } - } - - Ok(()) -} - impl Migrator { #[doc(hidden)] pub const DEFAULT: Migrator = Migrator { @@ -156,12 +137,21 @@ impl Migrator { ::Target: Migrate, { let mut conn = migrator.acquire().await?; - self.run_direct(&mut *conn).await + self.run_direct(None, &mut *conn).await + } + + pub async fn run_to<'a, A>(&self, target: i64, migrator: A) -> Result<(), MigrateError> + where + A: Acquire<'a>, + 
::Target: Migrate, + { + let mut conn = migrator.acquire().await?; + self.run_direct(Some(target), &mut *conn).await } // Getting around the annoying "implementation of `Acquire` is not general enough" error #[doc(hidden)] - pub async fn run_direct(&self, conn: &mut C) -> Result<(), MigrateError> + pub async fn run_direct(&self, target: Option, conn: &mut C) -> Result<(), MigrateError> where C: Migrate, { @@ -172,14 +162,14 @@ impl Migrator { // creates [_migrations] table only if needed // eventually this will likely migrate previous versions of the table - conn.ensure_migrations_table().await?; + conn.ensure_migrations_table(&self.table_name).await?; - let version = conn.dirty_version().await?; + let version = conn.dirty_version(&self.table_name).await?; if let Some(version) = version { return Err(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations = conn.list_applied_migrations(&self.table_name).await?; validate_applied_migrations(&applied_migrations, self)?; let applied_migrations: HashMap<_, _> = applied_migrations @@ -188,6 +178,11 @@ impl Migrator { .collect(); for migration in self.iter() { + if target.is_some_and(|target| target < migration.version) { + // Target version reached + break; + } + if migration.migration_type.is_down_migration() { continue; } @@ -199,7 +194,7 @@ impl Migrator { } } None => { - conn.apply(migration).await?; + conn.apply(&self.table_name, migration).await?; } } } @@ -244,14 +239,14 @@ impl Migrator { // creates [_migrations] table only if needed // eventually this will likely migrate previous versions of the table - conn.ensure_migrations_table().await?; + conn.ensure_migrations_table(&self.table_name).await?; - let version = conn.dirty_version().await?; + let version = conn.dirty_version(&self.table_name).await?; if let Some(version) = version { return Err(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations().await?; + 
let applied_migrations = conn.list_applied_migrations(&self.table_name).await?; validate_applied_migrations(&applied_migrations, self)?; let applied_migrations: HashMap<_, _> = applied_migrations @@ -266,7 +261,7 @@ impl Migrator { .filter(|m| applied_migrations.contains_key(&m.version)) .filter(|m| m.version > target) { - conn.revert(migration).await?; + conn.revert(&self.table_name, migration).await?; } // unlock the migrator to allow other migrators to run @@ -278,3 +273,22 @@ impl Migrator { Ok(()) } } + +fn validate_applied_migrations( + applied_migrations: &[AppliedMigration], + migrator: &Migrator, +) -> Result<(), MigrateError> { + if migrator.ignore_missing { + return Ok(()); + } + + let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect(); + + for applied_migration in applied_migrations { + if !migrations.contains(&applied_migration.version) { + return Err(MigrateError::VersionMissing(applied_migration.version)); + } + } + + Ok(()) +} \ No newline at end of file diff --git a/sqlx-core/src/testing/mod.rs b/sqlx-core/src/testing/mod.rs index d82d1a3616..9db65e9de3 100644 --- a/sqlx-core/src/testing/mod.rs +++ b/sqlx-core/src/testing/mod.rs @@ -243,7 +243,7 @@ async fn setup_test_db( if let Some(migrator) = args.migrator { migrator - .run_direct(&mut conn) + .run_direct(None, &mut conn) .await .expect("failed to apply migrations"); } diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index 79b55ace3c..f0d0d6a029 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -4,7 +4,6 @@ use std::time::Instant; use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::*; - use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; @@ -75,12 +74,12 @@ impl MigrateDatabase for MySql { } impl Migrate for MySqlConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn ensure_migrations_table<'e>(&'e mut self, 
table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=MySQL self.execute( - r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( + &*format!(r#" +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -88,7 +87,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#, + "#), ) .await?; @@ -96,11 +95,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", + &format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), ) .fetch_optional(self) .await?; @@ -109,13 +108,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + query_as(&format!("SELECT version, checksum FROM {table_name} ORDER BY version")) .fetch_all(self) .await?; @@ -167,10 +164,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. 
See https://github.com/launchbadge/sqlx/issues/1966. @@ -188,10 +186,10 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // // language=MySQL let _ = query( - r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + &format!(r#" + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?, ?, FALSE, ?, -1 ) - "#, + "#), ) .bind(migration.version) .bind(&*migration.description) @@ -206,11 +204,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=MySQL let _ = query( - r#" - UPDATE _sqlx_migrations + &format!(r#" + UPDATE {table_name} SET success = TRUE WHERE version = ? - "#, + "#), ) .bind(migration.version) .execute(&mut *tx) @@ -226,11 +224,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( #[allow(clippy::cast_possible_truncation)] let _ = query( - r#" - UPDATE _sqlx_migrations + &format!(r#" + UPDATE {table_name} SET execution_time = ? WHERE version = ? - "#, + "#), ) .bind(elapsed.as_nanos() as i64) .bind(migration.version) @@ -241,10 +239,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -259,11 +258,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // // language=MySQL let _ = query( - r#" - UPDATE _sqlx_migrations + &format!(r#" + UPDATE {table_name} SET success = FALSE WHERE version = ? 
- "#, + "#), ) .bind(migration.version) .execute(&mut *tx) @@ -272,7 +271,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( tx.execute(&*migration.sql).await?; // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?"#) + let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?"#)) .bind(migration.version) .execute(&mut *tx) .await?; diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c37e92f4d6..2646466399 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -111,12 +111,12 @@ impl MigrateDatabase for Postgres { } impl Migrate for PgConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL self.execute( - r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( + &*format!(r#" +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMPTZ NOT NULL DEFAULT now(), @@ -124,7 +124,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BYTEA NOT NULL, execution_time BIGINT NOT NULL ); - "#, + "#), ) .await?; @@ -132,11 +132,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", + &*format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), ) .fetch_optional(self) .await?; @@ -145,13 +145,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>(&'e mut 
self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + query_as(&*format!("SELECT version, checksum FROM {table_name} ORDER BY version")) .fetch_all(self) .await?; @@ -203,16 +201,17 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let start = Instant::now(); // execute migration queries if migration.no_tx { - execute_migration(self, migration).await?; + execute_migration(self, table_name, migration).await?; } else { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -220,7 +219,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. 
let mut tx = self.begin().await?; - execute_migration(&mut tx, migration).await?; + execute_migration(&mut tx, table_name, migration).await?; tx.commit().await?; } @@ -232,11 +231,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=SQL #[allow(clippy::cast_possible_truncation)] let _ = query( - r#" - UPDATE _sqlx_migrations + &*format!(r#" + UPDATE {table_name} SET execution_time = $1 WHERE version = $2 - "#, + "#), ) .bind(elapsed.as_nanos() as i64) .bind(migration.version) @@ -247,21 +246,22 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let start = Instant::now(); // execute migration queries if migration.no_tx { - revert_migration(self, migration).await?; + revert_migration(self, table_name, migration).await?; } else { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. 
let mut tx = self.begin().await?; - revert_migration(&mut tx, migration).await?; + revert_migration(&mut tx, table_name, migration).await?; tx.commit().await?; } @@ -274,6 +274,7 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( async fn execute_migration( conn: &mut PgConnection, + table_name: &str, migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn @@ -283,10 +284,10 @@ async fn execute_migration( // language=SQL let _ = query( - r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + &*format!(r#" + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) - "#, + "#), ) .bind(migration.version) .bind(&*migration.description) @@ -299,6 +300,7 @@ async fn execute_migration( async fn revert_migration( conn: &mut PgConnection, + table_name: &str, migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn @@ -307,7 +309,7 @@ async fn revert_migration( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) + let _ = query(&*format!(r#"DELETE FROM {table_name} WHERE version = $1"#)) .bind(migration.version) .execute(conn) .await?; diff --git a/sqlx-sqlite/src/migrate.rs b/sqlx-sqlite/src/migrate.rs index b9ce22dccd..8b5c24744c 100644 --- a/sqlx-sqlite/src/migrate.rs +++ b/sqlx-sqlite/src/migrate.rs @@ -64,12 +64,11 @@ impl MigrateDatabase for Sqlite { } impl Migrate for SqliteConnection { - fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQLite - self.execute( - r#" -CREATE TABLE IF NOT EXISTS _sqlx_migrations ( + self.execute(&*format!(r#" +CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMP NOT NULL 
DEFAULT CURRENT_TIMESTAMP, @@ -77,19 +76,19 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#, + "#), ) - .await?; + .await?; Ok(()) }) } - fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite let row: Option<(i64,)> = query_as( - "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", + &format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), ) .fetch_optional(self) .await?; @@ -98,13 +97,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn list_applied_migrations( - &mut self, - ) -> BoxFuture<'_, Result, MigrateError>> { + fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite let rows: Vec<(i64, Vec)> = - query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + query_as(&format!("SELECT version, checksum FROM {table_name} ORDER BY version")) .fetch_all(self) .await?; @@ -128,10 +125,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( Box::pin(async move { Ok(()) }) } - fn apply<'e: 'm, 'm>( + fn apply<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { let mut tx = self.begin().await?; let start = Instant::now(); @@ -148,10 +146,10 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=SQL let _ = query( - r#" - INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) + &format!(r#" + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?1, ?2, TRUE, ?3, -1 ) - "#, + "#), ) .bind(migration.version) .bind(&*migration.description) @@ -170,11 +168,11 @@ 
CREATE TABLE IF NOT EXISTS _sqlx_migrations ( // language=SQL #[allow(clippy::cast_possible_truncation)] let _ = query( - r#" - UPDATE _sqlx_migrations + &format!(r#" + UPDATE {table_name} SET execution_time = ?1 WHERE version = ?2 - "#, + "#), ) .bind(elapsed.as_nanos() as i64) .bind(migration.version) @@ -185,10 +183,11 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn revert<'e: 'm, 'm>( + fn revert<'e>( &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result> { + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. @@ -197,8 +196,8 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( let _ = tx.execute(&*migration.sql).await?; - // language=SQL - let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?1"#) + // language=SQLite + let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?1"#)) .bind(migration.version) .execute(&mut *tx) .await?; From 3765f67aba1d078b31fc86c59e6f711e980dd94a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 22 Jan 2025 15:32:50 -0800 Subject: [PATCH 16/30] feat: introduce `migrate.create-schemas` --- sqlx-cli/src/database.rs | 8 ++- sqlx-cli/src/lib.rs | 24 ++++--- sqlx-cli/src/migrate.rs | 68 +++++++++++++----- sqlx-cli/src/opt.rs | 67 +++++++++-------- sqlx-cli/tests/common/mod.rs | 2 +- sqlx-core/src/any/migrate.rs | 38 ++++++++-- sqlx-core/src/config/migrate.rs | 20 +++++- sqlx-core/src/migrate/error.rs | 3 + sqlx-core/src/migrate/migrate.rs | 21 +++++- sqlx-core/src/migrate/migrator.rs | 25 ++++++- sqlx-macros-core/src/migrate.rs | 33 +++++---- sqlx-macros-core/src/test_attr.rs | 3 +- sqlx-mysql/src/migrate.rs | 87 ++++++++++++++--------- sqlx-postgres/src/connection/describe.rs | 24 +++---- sqlx-postgres/src/connection/establish.rs | 3 +- 
sqlx-postgres/src/migrate.rs | 67 +++++++++++------ sqlx-sqlite/src/migrate.rs | 77 ++++++++++++++------ 17 files changed, 384 insertions(+), 186 deletions(-) diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 1fd8bcc534..a0af55d64b 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,5 +1,5 @@ -use crate::{migrate, Config}; use crate::opt::{ConnectOpts, MigrationSourceOpt}; +use crate::{migrate, Config}; use console::style; use promptly::{prompt, ReadlineError}; use sqlx::any::Any; @@ -54,7 +54,11 @@ pub async fn reset( setup(config, migration_source, connect_opts).await } -pub async fn setup(config: &Config, migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts) -> anyhow::Result<()> { +pub async fn setup( + config: &Config, + migration_source: &MigrationSourceOpt, + connect_opts: &ConnectOpts, +) -> anyhow::Result<()> { create(connect_opts).await?; migrate::run(config, migration_source, connect_opts, false, false, None).await } diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index 63257f541f..43b301e4c5 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -1,5 +1,5 @@ use std::io; -use std::path::{PathBuf}; +use std::path::PathBuf; use std::time::Duration; use anyhow::{Context, Result}; @@ -28,7 +28,7 @@ pub async fn run(opt: Opt) -> Result<()> { match opt.command { Command::Migrate(migrate) => match migrate.command { - MigrateCommand::Add(opts)=> migrate::add(config, opts).await?, + MigrateCommand::Add(opts) => migrate::add(config, opts).await?, MigrateCommand::Run { source, dry_run, @@ -74,15 +74,17 @@ pub async fn run(opt: Opt) -> Result<()> { connect_opts.populate_db_url(config)?; migrate::info(config, &source, &connect_opts).await? - }, - MigrateCommand::BuildScript { source, force } => migrate::build_script(config, &source, force)?, + } + MigrateCommand::BuildScript { source, force } => { + migrate::build_script(config, &source, force)? 
+ } }, Command::Database(database) => match database.command { DatabaseCommand::Create { mut connect_opts } => { connect_opts.populate_db_url(config)?; database::create(&connect_opts).await? - }, + } DatabaseCommand::Drop { confirmation, mut connect_opts, @@ -90,7 +92,7 @@ pub async fn run(opt: Opt) -> Result<()> { } => { connect_opts.populate_db_url(config)?; database::drop(&connect_opts, !confirmation.yes, force).await? - }, + } DatabaseCommand::Reset { confirmation, source, @@ -99,14 +101,14 @@ pub async fn run(opt: Opt) -> Result<()> { } => { connect_opts.populate_db_url(config)?; database::reset(config, &source, &connect_opts, !confirmation.yes, force).await? - }, + } DatabaseCommand::Setup { source, mut connect_opts, } => { connect_opts.populate_db_url(config)?; database::setup(config, &source, &connect_opts).await? - }, + } }, Command::Prepare { @@ -118,7 +120,7 @@ pub async fn run(opt: Opt) -> Result<()> { } => { connect_opts.populate_db_url(config)?; prepare::run(check, all, workspace, connect_opts, args).await? 
- }, + } #[cfg(feature = "completions")] Command::Completions { shell } => completions::run(shell), @@ -183,6 +185,6 @@ async fn config_from_current_dir() -> anyhow::Result<&'static Config> { Config::read_with_or_default(move || Ok(path)) }) - .await - .context("unexpected error loading config") + .await + .context("unexpected error loading config") } diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index 9e0119682e..3618fbe7a3 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -1,7 +1,10 @@ +use crate::config::Config; use crate::opt::{AddMigrationOpts, ConnectOpts, MigrationSourceOpt}; use anyhow::{bail, Context}; use console::style; -use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator, ResolveWith}; +use sqlx::migrate::{ + AppliedMigration, Migrate, MigrateError, MigrationType, Migrator, ResolveWith, +}; use sqlx::Connection; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -9,14 +12,10 @@ use std::fmt::Write; use std::fs::{self, File}; use std::path::Path; use std::time::Duration; -use crate::config::Config; -pub async fn add( - config: &Config, - opts: AddMigrationOpts, -) -> anyhow::Result<()> { +pub async fn add(config: &Config, opts: AddMigrationOpts) -> anyhow::Result<()> { let source = opts.source.resolve(config); - + fs::create_dir_all(source).context("Unable to create migrations directory")?; let migrator = Migrator::new(Path::new(source)).await?; @@ -124,13 +123,27 @@ fn short_checksum(checksum: &[u8]) -> String { s } -pub async fn info(config: &Config, migration_source: &MigrationSourceOpt, connect_opts: &ConnectOpts) -> anyhow::Result<()> { +pub async fn info( + config: &Config, + migration_source: &MigrationSourceOpt, + connect_opts: &ConnectOpts, +) -> anyhow::Result<()> { let source = migration_source.resolve(config); - - let migrator = Migrator::new(ResolveWith(Path::new(source), config.migrate.to_resolve_config())).await?; + + let migrator = 
Migrator::new(ResolveWith( + Path::new(source), + config.migrate.to_resolve_config(), + )) + .await?; let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table(config.migrate.table_name()).await?; + // FIXME: we shouldn't actually be creating anything here + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; + } + + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; let applied_migrations: HashMap<_, _> = conn .list_applied_migrations(config.migrate.table_name()) @@ -214,7 +227,7 @@ pub async fn run( target_version: Option, ) -> anyhow::Result<()> { let source = migration_source.resolve(config); - + let migrator = Migrator::new(Path::new(source)).await?; if let Some(target_version) = target_version { if !migrator.version_exists(target_version) { @@ -224,14 +237,21 @@ pub async fn run( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table(config.migrate.table_name()).await?; + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; + } + + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations(config.migrate.table_name()).await?; + let applied_migrations = conn + .list_applied_migrations(config.migrate.table_name()) + .await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -319,14 +339,22 @@ pub async fn revert( let mut conn = crate::connect(connect_opts).await?; - conn.ensure_migrations_table(config.migrate.table_name()).await?; + // FIXME: we should not be creating anything here if it doesn't exist + for schema_name in &config.migrate.create_schemas { + conn.create_schema_if_not_exists(schema_name).await?; 
+ } + + conn.ensure_migrations_table(config.migrate.table_name()) + .await?; let version = conn.dirty_version(config.migrate.table_name()).await?; if let Some(version) = version { bail!(MigrateError::Dirty(version)); } - let applied_migrations = conn.list_applied_migrations(config.migrate.table_name()).await?; + let applied_migrations = conn + .list_applied_migrations(config.migrate.table_name()) + .await?; validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?; let latest_version = applied_migrations @@ -397,9 +425,13 @@ pub async fn revert( Ok(()) } -pub fn build_script(config: &Config, migration_source: &MigrationSourceOpt, force: bool) -> anyhow::Result<()> { +pub fn build_script( + config: &Config, + migration_source: &MigrationSourceOpt, + force: bool, +) -> anyhow::Result<()> { let source = migration_source.resolve(config); - + anyhow::ensure!( Path::new("Cargo.toml").exists(), "must be run in a Cargo project root" diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 0b72af6594..9716303cce 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -1,13 +1,13 @@ -use std::env; -use std::ops::{Deref, Not}; +use crate::config::migrate::{DefaultMigrationType, DefaultVersioning}; +use crate::config::Config; use anyhow::Context; use chrono::Utc; use clap::{Args, Parser}; #[cfg(feature = "completions")] use clap_complete::Shell; -use crate::config::Config; use sqlx::migrate::Migrator; -use crate::config::migrate::{DefaultMigrationType, DefaultVersioning}; +use std::env; +use std::ops::{Deref, Not}; #[derive(Parser, Debug)] #[clap(version, about, author)] @@ -129,7 +129,7 @@ pub enum MigrateCommand { /// Create a new migration with the given description. /// /// -------------------------------- - /// + /// /// Migrations may either be simple, or reversible. /// /// Reversible migrations can be reverted with `sqlx migrate revert`, simple migrations cannot. 
@@ -152,7 +152,7 @@ pub enum MigrateCommand { /// It is recommended to always back up the database before running migrations. /// /// -------------------------------- - /// + /// /// For convenience, this command attempts to detect if reversible migrations are in-use. /// /// If the latest existing migration is reversible, the new migration will also be reversible. @@ -164,7 +164,7 @@ pub enum MigrateCommand { /// The default type to use can also be set in `sqlx.toml`. /// /// -------------------------------- - /// + /// /// A version number will be automatically assigned to the migration. /// /// Migrations are applied in ascending order by version number. @@ -174,9 +174,9 @@ pub enum MigrateCommand { /// less than _any_ previously applied migration. /// /// Migrations should only be created with increasing version number. - /// + /// /// -------------------------------- - /// + /// /// For convenience, this command will attempt to detect if sequential versioning is in use, /// and if so, continue the sequence. /// @@ -290,7 +290,7 @@ pub struct AddMigrationOpts { #[derive(Args, Debug)] pub struct MigrationSourceOpt { /// Path to folder containing migrations. - /// + /// /// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`. #[clap(long)] pub source: Option, @@ -301,7 +301,7 @@ impl MigrationSourceOpt { if let Some(source) = &self.source { return source; } - + config.migrate.migrations_dir() } } @@ -335,7 +335,9 @@ impl ConnectOpts { /// Require a database URL to be provided, otherwise /// return an error. pub fn expect_db_url(&self) -> anyhow::Result<&str> { - self.database_url.as_deref().context("BUG: database_url not populated") + self.database_url + .as_deref() + .context("BUG: database_url not populated") } /// Populate `database_url` from the environment, if not set. 
@@ -359,7 +361,7 @@ impl ConnectOpts { } self.database_url = Some(url) - }, + } Err(env::VarError::NotPresent) => { anyhow::bail!("`--database-url` or `{var}`{context} must be set") } @@ -407,22 +409,20 @@ impl Not for IgnoreMissing { impl AddMigrationOpts { pub fn reversible(&self, config: &Config, migrator: &Migrator) -> bool { - if self.reversible { return true; } - if self.simple { return false; } + if self.reversible { + return true; + } + if self.simple { + return false; + } match config.migrate.defaults.migration_type { - DefaultMigrationType::Inferred => { - migrator - .iter() - .last() - .is_some_and(|m| m.migration_type.is_reversible()) - } - DefaultMigrationType::Simple => { - false - } - DefaultMigrationType::Reversible => { - true - } + DefaultMigrationType::Inferred => migrator + .iter() + .last() + .is_some_and(|m| m.migration_type.is_reversible()), + DefaultMigrationType::Simple => false, + DefaultMigrationType::Reversible => true, } } @@ -434,8 +434,7 @@ impl AddMigrationOpts { } if self.sequential || matches!(default_versioning, DefaultVersioning::Sequential) { - return next_sequential(migrator) - .unwrap_or_else(|| fmt_sequential(1)); + return next_sequential(migrator).unwrap_or_else(|| fmt_sequential(1)); } next_sequential(migrator).unwrap_or_else(next_timestamp) @@ -455,18 +454,16 @@ fn next_sequential(migrator: &Migrator) -> Option { match migrations { [previous, latest] => { // If the latest two versions differ by 1, infer sequential. 
- (latest.version - previous.version == 1) - .then_some(latest.version + 1) - }, + (latest.version - previous.version == 1).then_some(latest.version + 1) + } [latest] => { // If only one migration exists and its version is 0 or 1, infer sequential - matches!(latest.version, 0 | 1) - .then_some(latest.version + 1) + matches!(latest.version, 0 | 1).then_some(latest.version + 1) } _ => unreachable!(), } }); - + next_version.map(fmt_sequential) } diff --git a/sqlx-cli/tests/common/mod.rs b/sqlx-cli/tests/common/mod.rs index bb58554f33..26f041d68a 100644 --- a/sqlx-cli/tests/common/mod.rs +++ b/sqlx-cli/tests/common/mod.rs @@ -1,12 +1,12 @@ use assert_cmd::{assert::Assert, Command}; +use sqlx::_unstable::config::Config; use sqlx::{migrate::Migrate, Connection, SqliteConnection}; use std::{ env::temp_dir, fs::remove_file, path::{Path, PathBuf}, }; -use sqlx::_unstable::config::Config; pub struct TestDatabase { file_path: PathBuf, diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index b287ec45e5..69b5bf6ab6 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -44,16 +44,44 @@ impl MigrateDatabase for Any { } impl Migrate for AnyConnection { - fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { - Box::pin(async { self.get_migrate()?.ensure_migrations_table(table_name).await }) + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async { + self.get_migrate()? + .create_schema_if_not_exists(schema_name) + .await + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async { + self.get_migrate()? 
+ .ensure_migrations_table(table_name) + .await + }) } - fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async { self.get_migrate()?.dirty_version(table_name).await }) } - fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { - Box::pin(async { self.get_migrate()?.list_applied_migrations(table_name).await }) + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async { + self.get_migrate()? + .list_applied_migrations(table_name) + .await + }) } fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index a70938b209..4865e24c76 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -19,6 +19,20 @@ use std::collections::BTreeSet; serde(default, rename_all = "kebab-case") )] pub struct Config { + /// Specify the names of schemas to create if they don't already exist. + /// + /// This is done before checking the existence of the migrations table + /// (`_sqlx_migrations` or overridden `table_name` below) so that it may be placed in + /// one of these schemas. + /// + /// ### Example + /// `sqlx.toml`: + /// ```toml + /// [migrate] + /// create-schemas = ["foo"] + /// ``` + pub create_schemas: BTreeSet>, + /// Override the name of the table used to track executed migrations. /// /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`. 
@@ -185,14 +199,14 @@ impl Config { pub fn migrations_dir(&self) -> &str { self.migrations_dir.as_deref().unwrap_or("migrations") } - + pub fn table_name(&self) -> &str { self.table_name.as_deref().unwrap_or("_sqlx_migrations") } - + pub fn to_resolve_config(&self) -> crate::migrate::ResolveConfig { let mut config = crate::migrate::ResolveConfig::new(); config.ignore_chars(self.ignored_chars.iter().copied()); config } -} \ No newline at end of file +} diff --git a/sqlx-core/src/migrate/error.rs b/sqlx-core/src/migrate/error.rs index 608d55b18d..a04243963a 100644 --- a/sqlx-core/src/migrate/error.rs +++ b/sqlx-core/src/migrate/error.rs @@ -39,4 +39,7 @@ pub enum MigrateError { "migration {0} is partially applied; fix and remove row from `_sqlx_migrations` table" )] Dirty(i64), + + #[error("database driver does not support creation of schemas at migrate time: {0}")] + CreateSchemasNotSupported(String), } diff --git a/sqlx-core/src/migrate/migrate.rs b/sqlx-core/src/migrate/migrate.rs index 2258f06f04..841f775966 100644 --- a/sqlx-core/src/migrate/migrate.rs +++ b/sqlx-core/src/migrate/migrate.rs @@ -25,16 +25,31 @@ pub trait MigrateDatabase { // 'e = Executor pub trait Migrate { + /// Create a database schema with the given name if it does not already exist. + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>>; + // ensure migrations table exists // will create or migrate it if needed - fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>>; + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>>; // Return the version on which the database is dirty or None otherwise. // "dirty" means there is a partially applied migration that failed. 
- fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>>; + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>>; // Return the ordered list of applied migrations - fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>>; + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>>; // Should acquire a database lock so that only one migration process // can run at a time. [`Migrate`] will call this function before applying diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index aa737ad304..0f5cfb3fd7 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -25,6 +25,9 @@ pub struct Migrator { pub no_tx: bool, #[doc(hidden)] pub table_name: Cow<'static, str>, + + #[doc(hidden)] + pub create_schemas: Cow<'static, [Cow<'static, str>]>, } impl Migrator { @@ -35,6 +38,7 @@ impl Migrator { no_tx: false, locking: true, table_name: Cow::Borrowed("_sqlx_migrations"), + create_schemas: Cow::Borrowed(&[]), }; /// Creates a new instance with the given source. @@ -84,6 +88,19 @@ impl Migrator { self } + /// Add a schema name to be created if it does not already exist. + /// + /// May be used with [`Self::dangerous_set_table_name()`] to place the migrations table + /// in a new schema without requiring it to exist first. + /// + /// ### Note: Support Depends on Database + /// SQLite cannot create new schemas without attaching them to a database file, + /// the path of which must be specified separately in an [`ATTACH DATABASE`](https://www.sqlite.org/lang_attach.html) command. + pub fn create_schema(&mut self, schema_name: impl Into>) -> &Self { + self.create_schemas.to_mut().push(schema_name.into()); + self + } + /// Specify whether applied migrations that are missing from the resolved migrations should be ignored. 
pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self { self.ignore_missing = ignore_missing; @@ -160,6 +177,10 @@ impl Migrator { conn.lock().await?; } + for schema_name in self.create_schemas.iter() { + conn.create_schema_if_not_exists(schema_name).await?; + } + // creates [_migrations] table only if needed // eventually this will likely migrate previous versions of the table conn.ensure_migrations_table(&self.table_name).await?; @@ -182,7 +203,7 @@ impl Migrator { // Target version reached break; } - + if migration.migration_type.is_down_migration() { continue; } @@ -291,4 +312,4 @@ fn validate_applied_migrations( } Ok(()) -} \ No newline at end of file +} diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index 0ae2eaebda..2f0e92bc88 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -5,10 +5,10 @@ use std::path::{Path, PathBuf}; use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; -use syn::LitStr; -use syn::spanned::Spanned; use sqlx_core::config::Config; use sqlx_core::migrate::{Migration, MigrationType, ResolveConfig}; +use syn::spanned::Spanned; +use syn::LitStr; pub const DEFAULT_PATH: &str = "./migrations"; @@ -85,7 +85,9 @@ impl ToTokens for QuoteMigration { } pub fn default_path(config: &Config) -> &str { - config.migrate.migrations_dir + config + .migrate + .migrations_dir .as_deref() .unwrap_or(DEFAULT_PATH) } @@ -93,12 +95,10 @@ pub fn default_path(config: &Config) -> &str { pub fn expand(path_arg: Option) -> crate::Result { let config = Config::from_crate(); - let path = match path_arg { - Some(path_arg) => crate::common::resolve_path(path_arg.value(), path_arg.span())?, - None => { - crate::common::resolve_path(default_path(config), Span::call_site()) - }? 
- }; + let path = match path_arg { + Some(path_arg) => crate::common::resolve_path(path_arg.value(), path_arg.span())?, + None => { crate::common::resolve_path(default_path(config), Span::call_site()) }?, + }; expand_with_path(config, &path) } @@ -130,18 +130,21 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result crate::Result { let path = crate::migrate::default_path(config); - let resolved_path = - crate::common::resolve_path(path, proc_macro2::Span::call_site())?; + let resolved_path = crate::common::resolve_path(path, proc_macro2::Span::call_site())?; if resolved_path.is_dir() { let migrator = crate::migrate::expand_with_path(config, &resolved_path)?; diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index f0d0d6a029..45ca7d98ef 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -2,8 +2,6 @@ use std::str::FromStr; use std::time::Duration; use std::time::Instant; -use futures_core::future::BoxFuture; -pub(crate) use sqlx_core::migrate::*; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; @@ -11,6 +9,8 @@ use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::{MySql, MySqlConnectOptions, MySqlConnection}; +use futures_core::future::BoxFuture; +pub(crate) use sqlx_core::migrate::*; fn parse_for_maintenance(url: &str) -> Result<(MySqlConnectOptions, String), Error> { let mut options = MySqlConnectOptions::from_str(url)?; @@ -74,11 +74,27 @@ impl MigrateDatabase for MySql { } impl Migrate for MySqlConnection { - fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // language=SQL + self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) + .await?; + + Ok(()) + }) + } + + fn 
ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=MySQL - self.execute( - &*format!(r#" + self.execute(&*format!( + r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, @@ -87,20 +103,23 @@ CREATE TABLE IF NOT EXISTS {table_name} ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#), - ) + "# + )) .await?; Ok(()) }) } - fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as( - &format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), - ) + let row: Option<(i64,)> = query_as(&format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -108,13 +127,17 @@ CREATE TABLE IF NOT EXISTS {table_name} ( }) } - fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = - query_as(&format!("SELECT version, checksum FROM {table_name} ORDER BY version")) - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -185,12 +208,12 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // `success=FALSE` and later modify the flag. 
// // language=MySQL - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?, ?, FALSE, ?, -1 ) - "#), - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -203,13 +226,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=MySQL - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" UPDATE {table_name} SET success = TRUE WHERE version = ? - "#), - ) + "# + )) .bind(migration.version) .execute(&mut *tx) .await?; @@ -223,13 +246,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( let elapsed = start.elapsed(); #[allow(clippy::cast_possible_truncation)] - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" UPDATE {table_name} SET execution_time = ? WHERE version = ? - "#), - ) + "# + )) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -257,13 +280,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // `success=FALSE` and later remove the migration altogether. // // language=MySQL - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" UPDATE {table_name} SET success = FALSE WHERE version = ? 
- "#), - ) + "# + )) .bind(migration.version) .execute(&mut *tx) .await?; diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 5b6a2aa09c..8119e2e97b 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -208,18 +208,18 @@ impl PgConnection { attribute_no: i16, should_fetch: bool, ) -> Result { - if let Some(origin) = - self.inner - .cache_table_to_column_names - .get(&relation_id) - .and_then(|table_columns| { - let column_name = table_columns.columns.get(&attribute_no).cloned()?; - - Some(ColumnOrigin::Table(TableColumn { - table: table_columns.table_name.clone(), - name: column_name, - })) - }) + if let Some(origin) = self + .inner + .cache_table_to_column_names + .get(&relation_id) + .and_then(|table_columns| { + let column_name = table_columns.columns.get(&attribute_no).cloned()?; + + Some(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name, + })) + }) { return Ok(origin); } diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 684bf26599..634b71de4b 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -149,7 +149,8 @@ impl PgConnection { cache_type_info: HashMap::new(), cache_elem_type_to_array: HashMap::new(), cache_table_to_column_names: HashMap::new(), - log_settings: options.log_settings.clone(),}), + log_settings: options.log_settings.clone(), + }), }) } } diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index 2646466399..90ebd49a73 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -111,11 +111,27 @@ impl MigrateDatabase for Postgres { } impl Migrate for PgConnection { - fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, 
+ ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // language=SQL + self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) + .await?; + + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL - self.execute( - &*format!(r#" + self.execute(&*format!( + r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, @@ -124,20 +140,23 @@ CREATE TABLE IF NOT EXISTS {table_name} ( checksum BYTEA NOT NULL, execution_time BIGINT NOT NULL ); - "#), - ) + "# + )) .await?; Ok(()) }) } - fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as( - &*format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), - ) + let row: Option<(i64,)> = query_as(&*format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -145,13 +164,17 @@ CREATE TABLE IF NOT EXISTS {table_name} ( }) } - fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = - query_as(&*format!("SELECT version, checksum FROM {table_name} ORDER BY version")) - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&*format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -230,13 +253,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // language=SQL 
#[allow(clippy::cast_possible_truncation)] - let _ = query( - &*format!(r#" + let _ = query(&*format!( + r#" UPDATE {table_name} SET execution_time = $1 WHERE version = $2 - "#), - ) + "# + )) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -283,12 +306,12 @@ async fn execute_migration( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query( - &*format!(r#" + let _ = query(&*format!( + r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) - "#), - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) diff --git a/sqlx-sqlite/src/migrate.rs b/sqlx-sqlite/src/migrate.rs index 8b5c24744c..e475f70308 100644 --- a/sqlx-sqlite/src/migrate.rs +++ b/sqlx-sqlite/src/migrate.rs @@ -15,6 +15,7 @@ use std::time::Duration; use std::time::Instant; pub(crate) use sqlx_core::migrate::*; +use sqlx_core::query_scalar::query_scalar; impl MigrateDatabase for Sqlite { fn create_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { @@ -64,10 +65,35 @@ impl MigrateDatabase for Sqlite { } impl Migrate for SqliteConnection { - fn ensure_migrations_table<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result<(), MigrateError>> { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + // Check if the schema already exists; if so, don't error. 
+ let schema_version: Option = + query_scalar(&format!("PRAGMA {schema_name}.schema_version")) + .fetch_optional(&mut *self) + .await?; + + if schema_version.is_some() { + return Ok(()); + } + + Err(MigrateError::CreateSchemasNotSupported( + format!("cannot create new schema {schema_name}; creation of additional schemas in SQLite requires attaching extra database files"), + )) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQLite - self.execute(&*format!(r#" + self.execute(&*format!( + r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, @@ -76,20 +102,23 @@ CREATE TABLE IF NOT EXISTS {table_name} ( checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); - "#), - ) - .await?; + "# + )) + .await?; Ok(()) }) } - fn dirty_version<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let row: Option<(i64,)> = query_as( - &format!("SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"), - ) + let row: Option<(i64,)> = query_as(&format!( + "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" + )) .fetch_optional(self) .await?; @@ -97,13 +126,17 @@ CREATE TABLE IF NOT EXISTS {table_name} ( }) } - fn list_applied_migrations<'e>(&'e mut self, table_name: &'e str) -> BoxFuture<'e, Result, MigrateError>> { + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let rows: Vec<(i64, Vec)> = - query_as(&format!("SELECT version, checksum FROM {table_name} ORDER BY version")) - .fetch_all(self) - .await?; + let rows: Vec<(i64, Vec)> = query_as(&format!( + "SELECT version, 
checksum FROM {table_name} ORDER BY version" + )) + .fetch_all(self) + .await?; let migrations = rows .into_iter() @@ -145,12 +178,12 @@ CREATE TABLE IF NOT EXISTS {table_name} ( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?1, ?2, TRUE, ?3, -1 ) - "#), - ) + "# + )) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -167,13 +200,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query( - &format!(r#" + let _ = query(&format!( + r#" UPDATE {table_name} SET execution_time = ?1 WHERE version = ?2 - "#), - ) + "# + )) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) From 28b64509bd1ceba28f584e2ab70b6d5e5900765b Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sun, 26 Jan 2025 01:00:46 -0800 Subject: [PATCH 17/30] WIP feat: create multi-tenant database example --- Cargo.lock | 65 ++++++++- Cargo.toml | 1 + .../postgres/axum-multi-tenant/Cargo.toml | 18 +++ examples/postgres/axum-multi-tenant/README.md | 11 ++ .../axum-multi-tenant/accounts/Cargo.toml | 13 ++ .../accounts/migrations/01_setup.sql | 0 .../accounts/migrations/02_account.sql | 8 ++ .../axum-multi-tenant/accounts/sqlx.toml | 6 + .../axum-multi-tenant/accounts/src/lib.rs | 133 ++++++++++++++++++ .../axum-multi-tenant/payments/Cargo.toml | 7 + .../axum-multi-tenant/payments/src/lib.rs | 14 ++ .../postgres/axum-multi-tenant/src/main.rs | 3 + 12 files changed, 272 insertions(+), 7 deletions(-) create mode 100644 examples/postgres/axum-multi-tenant/Cargo.toml create mode 100644 examples/postgres/axum-multi-tenant/README.md create mode 100644 examples/postgres/axum-multi-tenant/accounts/Cargo.toml create mode 100644 examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql create 
mode 100644 examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql create mode 100644 examples/postgres/axum-multi-tenant/accounts/sqlx.toml create mode 100644 examples/postgres/axum-multi-tenant/accounts/src/lib.rs create mode 100644 examples/postgres/axum-multi-tenant/payments/Cargo.toml create mode 100644 examples/postgres/axum-multi-tenant/payments/src/lib.rs create mode 100644 examples/postgres/axum-multi-tenant/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 30dd3be30d..776676c312 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,17 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "accounts" +version = "0.1.0" +dependencies = [ + "argon2 0.5.3", + "sqlx", + "thiserror 1.0.69", + "tokio", + "uuid", +] + [[package]] name = "addr2line" version = "0.24.2" @@ -127,7 +138,19 @@ checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73" dependencies = [ "base64ct", "blake2", - "password-hash", + "password-hash 0.4.2", +] + +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash 0.5.0", ] [[package]] @@ -1261,7 +1284,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2102,7 +2125,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -2301,6 +2324,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "multi-tenant" +version = "0.8.3" +dependencies = [ + "accounts", + "payments", + "sqlx", +] + 
[[package]] name = "native-tls" version = "0.2.12" @@ -2547,12 +2579,30 @@ dependencies = [ "subtle", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "payments" +version = "0.1.0" +dependencies = [ + "sqlx", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -3113,7 +3163,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3600,6 +3650,7 @@ dependencies = [ "time", "tokio", "tokio-stream", + "toml", "tracing", "url", "uuid", @@ -3622,7 +3673,7 @@ name = "sqlx-example-postgres-axum-social" version = "0.1.0" dependencies = [ "anyhow", - "argon2", + "argon2 0.4.1", "axum", "dotenvy", "once_cell", @@ -4108,7 +4159,7 @@ dependencies = [ "getrandom", "once_cell", "rustix 0.38.43", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4784,7 +4835,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index cf3352cedd..b75b3c3bd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "sqlx-postgres", "sqlx-sqlite", "examples/mysql/todos", + "examples/postgres/axum-multi-tenant", "examples/postgres/axum-social-with-tests", "examples/postgres/chat", "examples/postgres/files", diff --git a/examples/postgres/axum-multi-tenant/Cargo.toml b/examples/postgres/axum-multi-tenant/Cargo.toml new file mode 100644 index 
0000000000..1be607c5b8 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "multi-tenant" +version.workspace = true +license.workspace = true +edition.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors.workspace = true + +[dependencies] +accounts = { path = "accounts" } +payments = { path = "payments" } + +sqlx = { path = "../../..", version = "0.8.3", features = ["runtime-tokio", "postgres"] } + +[lints] +workspace = true diff --git a/examples/postgres/axum-multi-tenant/README.md b/examples/postgres/axum-multi-tenant/README.md new file mode 100644 index 0000000000..d38f7f3ea5 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/README.md @@ -0,0 +1,11 @@ +# Axum App with Multi-tenant Database + +This example project involves three crates, each owning a different schema in one database, +with their own set of migrations. + +* The main crate, an Axum app. + * Owns the `public` schema (tables are referenced unqualified). +* `accounts`: a subcrate simulating a reusable account-management crate. + * Owns schema `accounts`. +* `payments`: a subcrate simulating a wrapper for a payments API. + * Owns schema `payments`. 
diff --git a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml new file mode 100644 index 0000000000..485ba8eb73 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "accounts" +version = "0.1.0" +edition = "2021" + +[dependencies] +sqlx = { workspace = true, features = ["postgres", "time"] } +argon2 = { version = "0.5.3", features = ["password-hash"] } +tokio = { version = "1", features = ["rt", "sync"] } + +uuid = "1" +thiserror = "1" +rand = "0.8" diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql b/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql new file mode 100644 index 0000000000..91b9cf82e0 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql @@ -0,0 +1,8 @@ +create table account +( + account_id uuid primary key default gen_random_uuid(), + email text unique not null, + password_hash text not null, + created_at timestamptz not null default now(), + updated_at timestamptz +); diff --git a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml new file mode 100644 index 0000000000..45042f1333 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml @@ -0,0 +1,6 @@ +[migrate] +create-schemas = ["accounts"] +migrations-table = "accounts._sqlx_migrations" + +[macros.table-overrides.'accounts.account'] +'account_id' = "crate::AccountId" diff --git a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs new file mode 100644 index 0000000000..f015af3d40 --- /dev/null +++ 
b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs @@ -0,0 +1,133 @@ +use std::error::Error; +use argon2::{password_hash, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; + +use password_hash::PasswordHashString; + +use sqlx::{PgConnection, PgTransaction}; +use sqlx::types::Text; + +use uuid::Uuid; + +use tokio::sync::Semaphore; + +#[derive(sqlx::Type)] +#[sqlx(transparent)] +pub struct AccountId(pub Uuid); + + +pub struct AccountsManager { + hashing_semaphore: Semaphore, +} + +#[derive(Debug, thiserror::Error)] +pub enum CreateError { + #[error("email in-use")] + EmailInUse, + General(#[source] + #[from] GeneralError), +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthenticateError { + #[error("unknown email")] + UnknownEmail, + #[error("invalid password")] + InvalidPassword, + General(#[source] + #[from] GeneralError), +} + +#[derive(Debug, thiserror::Error)] +pub enum GeneralError { + Sqlx(#[source] + #[from] sqlx::Error), + PasswordHash(#[source] #[from] argon2::password_hash::Error), + Task(#[source] + #[from] tokio::task::JoinError), +} + +impl AccountsManager { + pub async fn new(conn: &mut PgConnection, max_hashing_threads: usize) -> Result { + sqlx::migrate!().run(conn).await?; + + AccountsManager { + hashing_semaphore: Semaphore::new(max_hashing_threads) + } + } + + async fn hash_password(&self, password: String) -> Result { + let guard = self.hashing_semaphore.acquire().await + .expect("BUG: this semaphore should not be closed"); + + // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn + // excess threads. + let (_guard, res) = tokio::task::spawn_blocking(move || { + let salt = argon2::password_hash::SaltString::generate(rand::thread_rng()); + (guard, Argon2::default().hash_password(password.as_bytes(), &salt)) + }) + .await?; + + Ok(res?) 
+ } + + async fn verify_password(&self, password: String, hash: PasswordHashString) -> Result<(), AuthenticateError> { + let guard = self.hashing_semaphore.acquire().await + .expect("BUG: this semaphore should not be closed"); + + let (_guard, res) = tokio::task::spawn_blocking(move || { + (guard, Argon2::default().verify_password(password.as_bytes(), &hash.password_hash())) + }).await.map_err(GeneralError::from)?; + + if let Err(password_hash::Error::Password) = res { + return Err(AuthenticateError::InvalidPassword); + } + + res.map_err(GeneralError::from)?; + + Ok(()) + } + + pub async fn create(&self, txn: &mut PgTransaction, email: &str, password: String) -> Result { + // Hash password whether the account exists or not to make it harder + // to tell the difference in the timing. + let hash = self.hash_password(password).await?; + + // language=PostgreSQL + sqlx::query!( + "insert into accounts.account(email, password_hash) \ + values ($1, $2) \ + returning account_id", + email, + Text(hash) as Text>, + ) + .fetch_one(&mut *txn) + .await + .map_err(|e| if e.constraint() == Some("account_account_id_key") { + CreateError::EmailInUse + } else { + GeneralError::from(e).into() + }) + } + + pub async fn authenticate(&self, conn: &mut PgConnection, email: &str, password: String) -> Result { + let maybe_account = sqlx::query!( + "select account_id, password_hash as \"password_hash: Text\" \ + from accounts.account \ + where email_id = $1", + email + ) + .fetch_optional(&mut *conn) + .await + .map_err(GeneralError::from)?; + + let Some(account) = maybe_account else { + // Hash the password whether the account exists or not to hide the difference in timing. 
+ self.hash_password(password).await.map_err(GeneralError::from)?; + return Err(AuthenticateError::UnknownEmail); + }; + + self.verify_password(password, account.password_hash.into())?; + + Ok(account.account_id) + } +} diff --git a/examples/postgres/axum-multi-tenant/payments/Cargo.toml b/examples/postgres/axum-multi-tenant/payments/Cargo.toml new file mode 100644 index 0000000000..0a2485955b --- /dev/null +++ b/examples/postgres/axum-multi-tenant/payments/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "payments" +version = "0.1.0" +edition = "2021" + +[dependencies] +sqlx = { workspace = true, features = ["postgres", "time"] } diff --git a/examples/postgres/axum-multi-tenant/payments/src/lib.rs b/examples/postgres/axum-multi-tenant/payments/src/lib.rs new file mode 100644 index 0000000000..7d12d9af81 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/payments/src/lib.rs @@ -0,0 +1,14 @@ +pub fn add(left: usize, right: usize) -> usize { + left + right +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn it_works() { + let result = add(2, 2); + assert_eq!(result, 4); + } +} diff --git a/examples/postgres/axum-multi-tenant/src/main.rs b/examples/postgres/axum-multi-tenant/src/main.rs new file mode 100644 index 0000000000..e7a11a969c --- /dev/null +++ b/examples/postgres/axum-multi-tenant/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} From 7d646a92510fba12ce053b7be6a142f14c02277a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 1 Feb 2025 23:42:51 -0800 Subject: [PATCH 18/30] fix(postgres): don't fetch `ColumnOrigin` for transparently-prepared statements --- sqlx-postgres/src/connection/describe.rs | 7 ++++--- sqlx-postgres/src/connection/executor.rs | 14 ++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 8119e2e97b..0334357a6c 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ 
b/sqlx-postgres/src/connection/describe.rs @@ -102,7 +102,8 @@ impl PgConnection { pub(super) async fn handle_row_description( &mut self, desc: Option, - should_fetch: bool, + fetch_type_info: bool, + fetch_column_description: bool, ) -> Result<(Vec, HashMap), Error> { let mut columns = Vec::new(); let mut column_names = HashMap::new(); @@ -121,13 +122,13 @@ impl PgConnection { let name = UStr::from(field.name); let type_info = self - .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) + .maybe_fetch_type_info_by_oid(field.data_type_id, fetch_type_info) .await?; let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) { - self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch) + self.maybe_fetch_column_origin(relation_oid, attribute_no, fetch_column_description) .await? } else { ColumnOrigin::Expression diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 97503a5004..9b5fd2a3f7 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -25,6 +25,7 @@ async fn prepare( sql: &str, parameters: &[PgTypeInfo], metadata: Option>, + fetch_column_origin: bool, ) -> Result<(StatementId, Arc), Error> { let id = conn.inner.next_statement_id; conn.inner.next_statement_id = id.next(); @@ -79,7 +80,7 @@ async fn prepare( let parameters = conn.handle_parameter_description(parameters).await?; - let (columns, column_names) = conn.handle_row_description(rows, true).await?; + let (columns, column_names) = conn.handle_row_description(rows, true, fetch_column_origin).await?; // ensure that if we did fetch custom data, we wait until we are fully ready before // continuing @@ -168,12 +169,13 @@ impl PgConnection { // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, + fetch_column_origin: bool, ) -> Result<(StatementId, Arc), Error> { if let 
Some(statement) = self.inner.cache_statement.get_mut(sql) { return Ok((*statement).clone()); } - let statement = prepare(self, sql, parameters, metadata).await?; + let statement = prepare(self, sql, parameters, metadata, fetch_column_origin).await?; if store_to_cache && self.inner.cache_statement.is_enabled() { if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) { @@ -222,7 +224,7 @@ impl PgConnection { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self - .get_or_prepare(query, &arguments.types, persistent, metadata_opt) + .get_or_prepare(query, &arguments.types, persistent, metadata_opt, false) .await?; metadata = metadata_; @@ -327,7 +329,7 @@ impl PgConnection { BackendMessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned let (columns, column_names) = self - .handle_row_description(Some(message.decode()?), false) + .handle_row_description(Some(message.decode()?), false, false) .await?; metadata = Arc::new(PgStatementMetadata { @@ -449,7 +451,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; + let (_, metadata) = self.get_or_prepare(sql, parameters, true, None, true).await?; Ok(PgStatement { sql: Cow::Borrowed(sql), @@ -468,7 +470,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; + let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None, true).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; From c2b9f87e3152d91e7ae2552766dac35a36b17b07 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 1 Feb 2025 23:53:23 -0800 Subject: [PATCH 19/30] feat: progress on axum-multi-tenant example --- Cargo.lock 
| 20 +-- .../postgres/axum-multi-tenant/Cargo.toml | 2 +- .../axum-multi-tenant/accounts/Cargo.toml | 6 +- .../accounts/migrations/02_account.sql | 2 +- .../axum-multi-tenant/accounts/sqlx.toml | 1 + .../axum-multi-tenant/accounts/src/lib.rs | 151 +++++++++++++----- .../axum-multi-tenant/payments/Cargo.toml | 2 +- 7 files changed, 128 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 776676c312..fa0eda5741 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,8 @@ name = "accounts" version = "0.1.0" dependencies = [ "argon2 0.5.3", + "password-hash 0.5.0", + "rand", "sqlx", "thiserror 1.0.69", "tokio", @@ -445,6 +447,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "axum-multi-tenant" +version = "0.8.3" +dependencies = [ + "accounts", + "payments", + "sqlx", +] + [[package]] name = "backoff" version = "0.4.0" @@ -2324,15 +2335,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "multi-tenant" -version = "0.8.3" -dependencies = [ - "accounts", - "payments", - "sqlx", -] - [[package]] name = "native-tls" version = "0.2.12" diff --git a/examples/postgres/axum-multi-tenant/Cargo.toml b/examples/postgres/axum-multi-tenant/Cargo.toml index 1be607c5b8..5d3b7167c3 100644 --- a/examples/postgres/axum-multi-tenant/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "multi-tenant" +name = "axum-multi-tenant" version.workspace = true license.workspace = true edition.workspace = true diff --git a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml index 485ba8eb73..bc414e0b33 100644 --- a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml @@ -4,10 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] -sqlx = { workspace = true, features = ["postgres", "time"] } -argon2 = { version = "0.5.3", features = ["password-hash"] } +sqlx = { workspace = true, features = 
["postgres", "time", "uuid"] } tokio = { version = "1", features = ["rt", "sync"] } +argon2 = { version = "0.5.3", features = ["password-hash"] } +password-hash = { version = "0.5", features = ["std"] } + uuid = "1" thiserror = "1" rand = "0.8" diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql index 91b9cf82e0..ea9b8b9531 100644 --- a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql +++ b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql @@ -1,4 +1,4 @@ -create table account +create table accounts.account ( account_id uuid primary key default gen_random_uuid(), email text unique not null, diff --git a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml index 45042f1333..8ce3f3f5e0 100644 --- a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml +++ b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml @@ -4,3 +4,4 @@ migrations-table = "accounts._sqlx_migrations" [macros.table-overrides.'accounts.account'] 'account_id' = "crate::AccountId" +'password_hash' = "sqlx::types::Text" diff --git a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs index f015af3d40..5535564e0c 100644 --- a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs +++ b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs @@ -1,5 +1,6 @@ +use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier}; use std::error::Error; -use argon2::{password_hash, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; +use std::sync::Arc; use password_hash::PasswordHashString; @@ -10,21 +11,24 @@ use uuid::Uuid; use tokio::sync::Semaphore; -#[derive(sqlx::Type)] +#[derive(sqlx::Type, Debug)] #[sqlx(transparent)] pub struct AccountId(pub Uuid); - pub struct AccountsManager { - hashing_semaphore: Semaphore, + 
hashing_semaphore: Arc, } #[derive(Debug, thiserror::Error)] pub enum CreateError { - #[error("email in-use")] + #[error("error creating account: email in-use")] EmailInUse, - General(#[source] - #[from] GeneralError), + #[error("error creating account")] + General( + #[source] + #[from] + GeneralError, + ), } #[derive(Debug, thiserror::Error)] @@ -33,50 +37,95 @@ pub enum AuthenticateError { UnknownEmail, #[error("invalid password")] InvalidPassword, - General(#[source] - #[from] GeneralError), + #[error("authentication error")] + General( + #[source] + #[from] + GeneralError, + ), } #[derive(Debug, thiserror::Error)] pub enum GeneralError { - Sqlx(#[source] - #[from] sqlx::Error), - PasswordHash(#[source] #[from] argon2::password_hash::Error), - Task(#[source] - #[from] tokio::task::JoinError), + #[error("database error")] + Sqlx( + #[source] + #[from] + sqlx::Error, + ), + #[error("error hashing password")] + PasswordHash( + #[source] + #[from] + argon2::password_hash::Error, + ), + #[error("task panicked")] + Task( + #[source] + #[from] + tokio::task::JoinError, + ), } impl AccountsManager { - pub async fn new(conn: &mut PgConnection, max_hashing_threads: usize) -> Result { - sqlx::migrate!().run(conn).await?; + pub async fn new( + conn: &mut PgConnection, + max_hashing_threads: usize, + ) -> Result { + sqlx::migrate!() + .run(conn) + .await + .map_err(sqlx::Error::from)?; - AccountsManager { - hashing_semaphore: Semaphore::new(max_hashing_threads) - } + Ok(AccountsManager { + hashing_semaphore: Semaphore::new(max_hashing_threads).into(), + }) } - async fn hash_password(&self, password: String) -> Result { - let guard = self.hashing_semaphore.acquire().await + async fn hash_password(&self, password: String) -> Result { + let guard = self + .hashing_semaphore + .clone() + .acquire_owned() + .await .expect("BUG: this semaphore should not be closed"); // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn // excess threads. 
let (_guard, res) = tokio::task::spawn_blocking(move || { let salt = argon2::password_hash::SaltString::generate(rand::thread_rng()); - (guard, Argon2::default().hash_password(password.as_bytes(), &salt)) + ( + guard, + Argon2::default() + .hash_password(password.as_bytes(), &salt) + .map(|hash| hash.serialize()), + ) }) - .await?; + .await?; Ok(res?) } - async fn verify_password(&self, password: String, hash: PasswordHashString) -> Result<(), AuthenticateError> { - let guard = self.hashing_semaphore.acquire().await + async fn verify_password( + &self, + password: String, + hash: PasswordHashString, + ) -> Result<(), AuthenticateError> { + let guard = self + .hashing_semaphore + .clone() + .acquire_owned() + .await .expect("BUG: this semaphore should not be closed"); let (_guard, res) = tokio::task::spawn_blocking(move || { - (guard, Argon2::default().verify_password(password.as_bytes(), &hash.password_hash())) - }).await.map_err(GeneralError::from)?; + ( + guard, + Argon2::default().verify_password(password.as_bytes(), &hash.password_hash()), + ) + }) + .await + .map_err(GeneralError::from)?; if let Err(password_hash::Error::Password) = res { return Err(AuthenticateError::InvalidPassword); @@ -87,46 +136,64 @@ impl AccountsManager { Ok(()) } - pub async fn create(&self, txn: &mut PgTransaction, email: &str, password: String) -> Result { + pub async fn create( + &self, + txn: &mut PgTransaction<'_>, + email: &str, + password: String, + ) -> Result { // Hash password whether the account exists or not to make it harder // to tell the difference in the timing. 
let hash = self.hash_password(password).await?; + // Thanks to `sqlx.toml`, `account_id` maps to `AccountId` // language=PostgreSQL - sqlx::query!( + sqlx::query_scalar!( "insert into accounts.account(email, password_hash) \ values ($1, $2) \ returning account_id", email, - Text(hash) as Text>, + hash.as_str(), ) - .fetch_one(&mut *txn) - .await - .map_err(|e| if e.constraint() == Some("account_account_id_key") { + .fetch_one(&mut **txn) + .await + .map_err(|e| { + if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_account_id_key") { CreateError::EmailInUse } else { GeneralError::from(e).into() - }) + } + }) } - pub async fn authenticate(&self, conn: &mut PgConnection, email: &str, password: String) -> Result { + pub async fn authenticate( + &self, + conn: &mut PgConnection, + email: &str, + password: String, + ) -> Result { + // Thanks to `sqlx.toml`: + // * `account_id` maps to `AccountId` + // * `password_hash` maps to `Text` let maybe_account = sqlx::query!( - "select account_id, password_hash as \"password_hash: Text\" \ + "select account_id, password_hash \ from accounts.account \ - where email_id = $1", + where email = $1", email ) - .fetch_optional(&mut *conn) - .await - .map_err(GeneralError::from)?; + .fetch_optional(&mut *conn) + .await + .map_err(GeneralError::from)?; let Some(account) = maybe_account else { // Hash the password whether the account exists or not to hide the difference in timing. 
- self.hash_password(password).await.map_err(GeneralError::from)?; + self.hash_password(password) + .await + .map_err(GeneralError::from)?; return Err(AuthenticateError::UnknownEmail); }; - self.verify_password(password, account.password_hash.into())?; + self.verify_password(password, account.password_hash.into_inner()).await?; Ok(account.account_id) } diff --git a/examples/postgres/axum-multi-tenant/payments/Cargo.toml b/examples/postgres/axum-multi-tenant/payments/Cargo.toml index 0a2485955b..d7dc430553 100644 --- a/examples/postgres/axum-multi-tenant/payments/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/payments/Cargo.toml @@ -4,4 +4,4 @@ version = "0.1.0" edition = "2021" [dependencies] -sqlx = { workspace = true, features = ["postgres", "time"] } +sqlx = { workspace = true, features = ["postgres", "time", "uuid"] } From d9fc4899ad633489d8f4e398aa39fc0dadc5b6fa Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 21 Feb 2025 15:52:27 -0800 Subject: [PATCH 20/30] feat(config): better errors for mislabeled fields --- sqlx-core/src/config/common.rs | 2 +- sqlx-core/src/config/macros.rs | 2 +- sqlx-core/src/config/migrate.rs | 2 +- sqlx-core/src/config/mod.rs | 5 ++++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/config/common.rs b/sqlx-core/src/config/common.rs index d2bf639e5f..2d5342d5b8 100644 --- a/sqlx-core/src/config/common.rs +++ b/sqlx-core/src/config/common.rs @@ -3,7 +3,7 @@ #[cfg_attr( feature = "sqlx-toml", derive(serde::Deserialize), - serde(default, rename_all = "kebab-case") + serde(default, rename_all = "kebab-case", deny_unknown_fields) )] pub struct Config { /// Override the database URL environment variable. 
diff --git a/sqlx-core/src/config/macros.rs b/sqlx-core/src/config/macros.rs index 19e5f42fa0..9acabf2d6a 100644 --- a/sqlx-core/src/config/macros.rs +++ b/sqlx-core/src/config/macros.rs @@ -5,7 +5,7 @@ use std::collections::BTreeMap; #[cfg_attr( feature = "sqlx-toml", derive(serde::Deserialize), - serde(default, rename_all = "kebab-case") + serde(default, rename_all = "kebab-case", deny_unknown_fields) )] pub struct Config { /// Specify which crates' types to use when types from multiple crates apply. diff --git a/sqlx-core/src/config/migrate.rs b/sqlx-core/src/config/migrate.rs index 4865e24c76..0dd6cc2257 100644 --- a/sqlx-core/src/config/migrate.rs +++ b/sqlx-core/src/config/migrate.rs @@ -16,7 +16,7 @@ use std::collections::BTreeSet; #[cfg_attr( feature = "sqlx-toml", derive(serde::Deserialize), - serde(default, rename_all = "kebab-case") + serde(default, rename_all = "kebab-case", deny_unknown_fields) )] pub struct Config { /// Specify the names of schemas to create if they don't already exist. diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index 02bde20f73..5801af888c 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -48,7 +48,7 @@ mod tests; #[cfg_attr( feature = "sqlx-toml", derive(serde::Deserialize), - serde(default, rename_all = "kebab-case") + serde(default, rename_all = "kebab-case", deny_unknown_fields) )] pub struct Config { /// Configuration shared by multiple components. @@ -210,6 +210,9 @@ impl Config { // Only returned if the file exists but the feature is not enabled. 
panic!("{e}") } + Err(ConfigError::Parse { error, path }) => { + panic!("error parsing sqlx config {path:?}: {error}") + } Err(e) => { panic!("failed to read sqlx config: {e}") } From 0b79b51d743b4e3249a9f10dfcf50d3847d3a68a Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 26 Feb 2025 13:36:53 -0800 Subject: [PATCH 21/30] WIP feat: filling out axum-multi-tenant example --- Cargo.lock | 378 ++++++++++++++++-- .../postgres/axum-multi-tenant/Cargo.toml | 10 + examples/postgres/axum-multi-tenant/README.md | 19 +- .../axum-multi-tenant/accounts/Cargo.toml | 7 +- .../accounts/migrations/01_setup.sql | 30 ++ .../accounts/migrations/02_account.sql | 12 +- .../axum-multi-tenant/accounts/sqlx.toml | 2 +- .../axum-multi-tenant/accounts/src/lib.rs | 53 ++- .../axum-multi-tenant/payments/Cargo.toml | 12 +- .../payments/migrations/01_setup.sql | 30 ++ .../payments/migrations/02_payment.sql | 58 +++ .../axum-multi-tenant/payments/sqlx.toml | 10 + .../axum-multi-tenant/payments/src/lib.rs | 40 +- .../axum-multi-tenant/src/http/mod.rs | 7 + .../postgres/axum-multi-tenant/src/main.rs | 65 ++- sqlx-macros-core/src/migrate.rs | 17 +- 16 files changed, 664 insertions(+), 86 deletions(-) create mode 100644 examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql create mode 100644 examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql create mode 100644 examples/postgres/axum-multi-tenant/payments/sqlx.toml create mode 100644 examples/postgres/axum-multi-tenant/src/http/mod.rs diff --git a/Cargo.lock b/Cargo.lock index fa0eda5741..6aec9a98f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,24 +11,25 @@ dependencies = [ "rand", "sqlx", "thiserror 1.0.69", + "time", "tokio", "uuid", ] [[package]] name = "addr2line" -version = "0.24.2" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +checksum = 
"8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] [[package]] -name = "adler2" -version = "2.0.0" +name = "adler" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" @@ -394,16 +395,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.2.9", "axum-macros", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", "itoa", - "matchit", + "matchit 0.5.0", "memchr", "mime", "percent-encoding", @@ -411,14 +412,48 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "tokio", - "tower", + "tower 0.4.13", "tower-http", "tower-layer", "tower-service", ] +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.2.9" @@ -428,13 +463,33 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "mime", "tower-layer", 
"tower-service", ] +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-macros" version = "0.2.3" @@ -452,8 +507,14 @@ name = "axum-multi-tenant" version = "0.8.3" dependencies = [ "accounts", + "axum 0.8.1", + "clap", + "color-eyre", + "dotenvy", "payments", "sqlx", + "tokio", + "tracing-subscriber", ] [[package]] @@ -472,17 +533,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", + "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", - "windows-targets 0.52.6", ] [[package]] @@ -821,9 +882,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" dependencies = [ "clap_builder", "clap_derive", @@ -831,9 +892,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" dependencies = [ "anstream", "anstyle", @@ -852,9 +913,9 @@ dependencies = [ [[package]] 
name = "clap_derive" -version = "4.5.24" +version = "4.5.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" +checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -888,6 +949,33 @@ dependencies = [ "cc", ] +[[package]] +name = "color-eyre" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -1346,6 +1434,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -1609,9 +1707,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.1" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1738,6 +1836,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1745,7 +1854,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.2.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1783,8 +1915,8 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1796,6 +1928,41 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "hyper 1.6.0", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = 
"iana-time-zone" version = "0.1.61" @@ -1980,6 +2147,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -2230,6 +2403,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -2278,11 +2457,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ - "adler2", + "adler", ] [[package]] @@ -2403,6 +2582,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -2468,9 +2657,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.7" +version = "0.32.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -2541,6 +2730,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking" version = "2.2.1" @@ -2602,7 +2803,11 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" name = "payments" version = "0.1.0" dependencies = [ + "accounts", + "rust_decimal", "sqlx", + "time", + "uuid", ] [[package]] @@ -3374,6 +3579,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -3445,6 +3660,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -3676,7 +3900,7 @@ version = "0.1.0" dependencies = [ "anyhow", "argon2 0.4.1", - "axum", + "axum 0.5.17", "dotenvy", "once_cell", "rand", @@ -3688,7 +3912,7 @@ dependencies = [ "thiserror 2.0.11", "time", "tokio", - "tower", + "tower 0.4.13", "tracing", "uuid", "validator", @@ -4127,6 +4351,12 @@ version = "0.1.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.1" @@ -4219,6 +4449,16 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.37" @@ -4375,6 +4615,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.3.5" @@ -4385,11 +4641,11 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "http-range-header", "pin-project-lite", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", ] @@ -4436,6 +4692,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-error" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" +dependencies = [ + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "tracing-log" +version = 
"0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -4564,9 +4856,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.1" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ "serde", ] @@ -4613,6 +4905,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "value-bag" version = "1.10.0" diff --git a/examples/postgres/axum-multi-tenant/Cargo.toml b/examples/postgres/axum-multi-tenant/Cargo.toml index 5d3b7167c3..7ea32bbc43 100644 --- a/examples/postgres/axum-multi-tenant/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/Cargo.toml @@ -12,7 +12,17 @@ authors.workspace = true accounts = { path = "accounts" } payments = { path = "payments" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } + sqlx = { path = "../../..", version = "0.8.3", features = ["runtime-tokio", "postgres"] } +axum = "0.8.1" + +clap = { version = "4.5.30", features = ["derive", "env"] } +color-eyre = "0.6.3" +dotenvy = "0.15.7" +tracing-subscriber = "0.3.19" + + [lints] workspace = true 
diff --git a/examples/postgres/axum-multi-tenant/README.md b/examples/postgres/axum-multi-tenant/README.md index d38f7f3ea5..aae3a6f1fe 100644 --- a/examples/postgres/axum-multi-tenant/README.md +++ b/examples/postgres/axum-multi-tenant/README.md @@ -3,9 +3,20 @@ This example project involves three crates, each owning a different schema in one database, with their own set of migrations. -* The main crate, an Axum app. - * Owns the `public` schema (tables are referenced unqualified). +* The main crate, an Axum app. + * Owns the `public` schema (tables are referenced unqualified). * `accounts`: a subcrate simulating a reusable account-management crate. - * Owns schema `accounts`. + * Owns schema `accounts`. * `payments`: a subcrate simulating a wrapper for a payments API. - * Owns schema `payments`. + * Owns schema `payments`. + +## Note: Schema-Qualified Names + +This example uses schema-qualified names everywhere for clarity. + +It can be tempting to change the `search_path` of the connection (MySQL, Postgres) to eliminate the need for schema +prefixes, but this can cause some really confusing issues when names conflict. + +This example will generate a `_sqlx_migrations` table in three different schemas, and if `search_path` is set +to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified, +it would throw an error. 
diff --git a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml index bc414e0b33..dd95a890af 100644 --- a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/accounts/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -sqlx = { workspace = true, features = ["postgres", "time", "uuid"] } +sqlx = { workspace = true, features = ["postgres", "time", "uuid", "macros", "sqlx-toml"] } tokio = { version = "1", features = ["rt", "sync"] } argon2 = { version = "0.5.3", features = ["password-hash"] } @@ -13,3 +13,8 @@ password-hash = { version = "0.5", features = ["std"] } uuid = "1" thiserror = "1" rand = "0.8" + +time = "0.3.37" + +[dev-dependencies] +sqlx = { workspace = true, features = ["runtime-tokio"] } diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql b/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql index e69de29bb2..5aa8fa23cf 100644 --- a/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql +++ b/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select accounts.trigger_updated_at(''); +-- +-- after a `CREATE TABLE`. 
+create or replace function accounts.set_updated_at() + returns trigger as +$$ +begin + NEW.updated_at = now(); +return NEW; +end; +$$ language plpgsql; + +create or replace function accounts.trigger_updated_at(tablename regclass) + returns void as +$$ +begin +execute format('CREATE TRIGGER set_updated_at + BEFORE UPDATE + ON %s + FOR EACH ROW + WHEN (OLD is distinct from NEW) + EXECUTE FUNCTION accounts.set_updated_at();', tablename); +end; +$$ language plpgsql; diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql index ea9b8b9531..a75814bd09 100644 --- a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql +++ b/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql @@ -1,8 +1,10 @@ create table accounts.account ( - account_id uuid primary key default gen_random_uuid(), - email text unique not null, - password_hash text not null, - created_at timestamptz not null default now(), - updated_at timestamptz + account_id uuid primary key default gen_random_uuid(), + email text unique not null, + password_hash text not null, + created_at timestamptz not null default now(), + updated_at timestamptz ); + +select accounts.trigger_updated_at('accounts.account'); diff --git a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml index 8ce3f3f5e0..1d02130c2d 100644 --- a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml +++ b/examples/postgres/axum-multi-tenant/accounts/sqlx.toml @@ -1,6 +1,6 @@ [migrate] create-schemas = ["accounts"] -migrations-table = "accounts._sqlx_migrations" +table-name = "accounts._sqlx_migrations" [macros.table-overrides.'accounts.account'] 'account_id' = "crate::AccountId" diff --git a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs index 5535564e0c..3037463e4c 100644 --- 
a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs +++ b/examples/postgres/axum-multi-tenant/accounts/src/lib.rs @@ -1,11 +1,9 @@ use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier}; -use std::error::Error; use std::sync::Arc; use password_hash::PasswordHashString; -use sqlx::{PgConnection, PgTransaction}; -use sqlx::types::Text; +use sqlx::{PgConnection, PgPool, PgTransaction}; use uuid::Uuid; @@ -16,6 +14,37 @@ use tokio::sync::Semaphore; pub struct AccountId(pub Uuid); pub struct AccountsManager { + /// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing. + /// + /// ### Motivation + /// Tokio blocking tasks are generally not designed for CPU-bound work. + /// + /// If no threads are idle, Tokio will automatically spawn new ones to handle + /// new blocking tasks up to a very high limit--512 by default. + /// + /// This is because blocking tasks are expected to spend their time *blocked*, e.g. on + /// blocking I/O, and thus not consume CPU resources or require a lot of context switching. + /// + /// This strategy is not the most efficient way to use threads for CPU-bound work, which + /// should schedule work to a fixed number of threads to minimize context switching + /// and memory usage (each new thread needs significant space allocated for its stack). + /// + /// We can work around this by using a purpose-designed thread-pool, like Rayon, + /// but we still have the problem that those APIs usually are not designed to support `async`, + /// so we end up needing blocking tasks anyway, or implementing our own work queue using + /// channels. Rayon also does not shut down idle worker threads. + /// + /// `block_in_place` is not a silver bullet, either, as it simply uses `spawn_blocking` + /// internally to take over from the current thread while it is executing blocking work. + /// This also prevents futures from being polled concurrently in the current task. 
+ /// + /// We can lower the limit for blocking threads when creating the runtime, but this risks + /// starving other blocking tasks that are being created by the application or the Tokio + /// runtime itself + /// (which are used for `tokio::fs`, stdio, resolving of hostnames by `ToSocketAddrs`, etc.). + /// + /// Instead, we can just use a Semaphore to limit how many blocking tasks are spawned at once, + /// emulating the behavior of a thread pool like Rayon without needing any additional crates. hashing_semaphore: Arc, } @@ -57,7 +86,7 @@ pub enum GeneralError { PasswordHash( #[source] #[from] - argon2::password_hash::Error, + password_hash::Error, ), #[error("task panicked")] Task( @@ -68,12 +97,9 @@ pub enum GeneralError { } impl AccountsManager { - pub async fn new( - conn: &mut PgConnection, - max_hashing_threads: usize, - ) -> Result { + pub async fn setup(pool: &PgPool, max_hashing_threads: usize) -> Result { sqlx::migrate!() - .run(conn) + .run(pool) .await .map_err(sqlx::Error::from)?; @@ -147,8 +173,8 @@ impl AccountsManager { let hash = self.hash_password(password).await?; // Thanks to `sqlx.toml`, `account_id` maps to `AccountId` - // language=PostgreSQL sqlx::query_scalar!( + // language=PostgreSQL "insert into accounts.account(email, password_hash) \ values ($1, $2) \ returning account_id", @@ -158,7 +184,9 @@ impl AccountsManager { .fetch_one(&mut **txn) .await .map_err(|e| { - if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_account_id_key") { + if e.as_database_error().and_then(|dbe| dbe.constraint()) + == Some("account_account_id_key") + { CreateError::EmailInUse } else { GeneralError::from(e).into() @@ -193,7 +221,8 @@ impl AccountsManager { return Err(AuthenticateError::UnknownEmail); }; - self.verify_password(password, account.password_hash.into_inner()).await?; + self.verify_password(password, account.password_hash.into_inner()) + .await?; Ok(account.account_id) } diff --git 
a/examples/postgres/axum-multi-tenant/payments/Cargo.toml b/examples/postgres/axum-multi-tenant/payments/Cargo.toml index d7dc430553..6a0e4d2672 100644 --- a/examples/postgres/axum-multi-tenant/payments/Cargo.toml +++ b/examples/postgres/axum-multi-tenant/payments/Cargo.toml @@ -4,4 +4,14 @@ version = "0.1.0" edition = "2021" [dependencies] -sqlx = { workspace = true, features = ["postgres", "time", "uuid"] } +accounts = { path = "../accounts" } + +sqlx = { workspace = true, features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml"] } + +rust_decimal = "1.36.0" + +time = "0.3.37" +uuid = "1.12.1" + +[dev-dependencies] +sqlx = { workspace = true, features = ["runtime-tokio"] } diff --git a/examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql b/examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql new file mode 100644 index 0000000000..4935a63705 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select payments.trigger_updated_at('
'); +-- +-- after a `CREATE TABLE`. +create or replace function payments.set_updated_at() + returns trigger as +$$ +begin + NEW.updated_at = now(); +return NEW; +end; +$$ language plpgsql; + +create or replace function payments.trigger_updated_at(tablename regclass) + returns void as +$$ +begin +execute format('CREATE TRIGGER set_updated_at + BEFORE UPDATE + ON %s + FOR EACH ROW + WHEN (OLD is distinct from NEW) + EXECUTE FUNCTION payments.set_updated_at();', tablename); +end; +$$ language plpgsql; diff --git a/examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql b/examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql new file mode 100644 index 0000000000..cc372f01b7 --- /dev/null +++ b/examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql @@ -0,0 +1,58 @@ +-- `payments::PaymentStatus` +-- +-- Historically at LaunchBadge we preferred not to define enums on the database side because it can be annoying +-- and error-prone to keep them in-sync with the application. +-- Instead, we let the application define the enum and just have the database store a compact representation of it. +-- This is mostly a matter of taste, however. +-- +-- For the purposes of this example, we're using an in-database enum because this is a common use-case +-- for needing type overrides. +create type payments.payment_status as enum ( + 'pending', + 'success', + 'failed' + ); + +create table payments.payment +( + payment_id uuid primary key default gen_random_uuid(), + -- This cross-schema reference means migrations for the `accounts` crate should be run first. 
+ account_id uuid not null references accounts.account (account_id), + + status payments.payment_status NOT NULL, + + -- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes) + -- + -- This *could* be an ENUM of currency codes, but constraining this to a set of known values in the database + -- would be annoying to keep up to date as support for more currencies is added. + -- + -- Consider also if support for cryptocurrencies is desired; those are not covered by ISO 4217. + -- + -- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)` + -- all use the same storage format in Postgres. Any constraint against the length of this field + -- would purely be a sanity check. + currency text NOT NULL, + -- There's an endless debate about what type should be used to represent currency amounts. + -- + -- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly + -- optimized for storing USD, or other currencies with a minimum fraction of 1 cent. + -- + -- NEVER use `FLOAT` or `DOUBLE`. IEEE-754 floating point has round-off and precision errors that make it wholly + -- unsuitable for representing real money amounts. + -- + -- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency, + -- and so is what we've chosen here. + amount NUMERIC NOT NULL, + + -- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.), + -- so imagine this is an identifier string for this payment in such a vendor's systems. + -- + -- For privacy and security reasons, payment and personally-identifying information + -- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor + -- unless there is a good reason otherwise. 
+ external_payment_id TEXT NOT NULL UNIQUE, + created_at timestamptz default now(), + updated_at timestamptz +); + +select payments.trigger_updated_at('payments.payment'); diff --git a/examples/postgres/axum-multi-tenant/payments/sqlx.toml b/examples/postgres/axum-multi-tenant/payments/sqlx.toml new file mode 100644 index 0000000000..1a4a27dc6a --- /dev/null +++ b/examples/postgres/axum-multi-tenant/payments/sqlx.toml @@ -0,0 +1,10 @@ +[migrate] +create-schemas = ["payments"] +table-name = "payments._sqlx_migrations" + +[macros.table-overrides.'payments.payment'] +'payment_id' = "crate::PaymentId" +'account_id' = "accounts::AccountId" + +[macros.type-overrides] +'payments.payment_status' = "crate::PaymentStatus" diff --git a/examples/postgres/axum-multi-tenant/payments/src/lib.rs b/examples/postgres/axum-multi-tenant/payments/src/lib.rs index 7d12d9af81..b0efcfe17f 100644 --- a/examples/postgres/axum-multi-tenant/payments/src/lib.rs +++ b/examples/postgres/axum-multi-tenant/payments/src/lib.rs @@ -1,14 +1,34 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right +use accounts::AccountId; +use sqlx::PgPool; +use time::OffsetDateTime; +use uuid::Uuid; + +#[derive(sqlx::Type, Debug)] +#[sqlx(transparent)] +pub struct PaymentId(pub Uuid); + +#[derive(sqlx::Type, Debug)] +#[sqlx(type_name = "payments.payment_status")] +#[sqlx(rename_all = "snake_case")] +pub enum PaymentStatus { + Pending, + Successful, } -#[cfg(test)] -mod tests { - use super::*; +#[derive(Debug)] +pub struct Payment { + pub payment_id: PaymentId, + pub account_id: AccountId, + pub status: PaymentStatus, + pub currency: String, + // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money. 
+ pub amount: rust_decimal::Decimal, + pub external_payment_id: String, + pub created_at: OffsetDateTime, + pub updated_at: Option, +} - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } +pub async fn migrate(pool: &PgPool) -> sqlx::Result<()> { + sqlx::migrate!().run(pool).await?; + Ok(()) } diff --git a/examples/postgres/axum-multi-tenant/src/http/mod.rs b/examples/postgres/axum-multi-tenant/src/http/mod.rs new file mode 100644 index 0000000000..9197a2042f --- /dev/null +++ b/examples/postgres/axum-multi-tenant/src/http/mod.rs @@ -0,0 +1,7 @@ +use accounts::AccountsManager; +use color_eyre::eyre; +use sqlx::PgPool; + +pub async fn run(pool: PgPool, accounts: AccountsManager) -> eyre::Result<()> { + axum::serve +} diff --git a/examples/postgres/axum-multi-tenant/src/main.rs b/examples/postgres/axum-multi-tenant/src/main.rs index e7a11a969c..3d4b0cba64 100644 --- a/examples/postgres/axum-multi-tenant/src/main.rs +++ b/examples/postgres/axum-multi-tenant/src/main.rs @@ -1,3 +1,64 @@ -fn main() { - println!("Hello, world!"); +mod http; + +use accounts::AccountsManager; +use color_eyre::eyre; +use color_eyre::eyre::Context; + +#[derive(clap::Parser)] +struct Args { + #[clap(long, env)] + database_url: String, + + #[clap(long, env, default_value_t = 0)] + max_hashing_threads: usize, +} + +#[tokio::main] +async fn main() -> eyre::Result<()> { + color_eyre::install()?; + let _ = dotenvy::dotenv(); + + // (@abonander) I prefer to keep `clap::Parser` fully qualified here because it makes it clear + // what crate the derive macro is coming from. Otherwise, it requires contextual knowledge + // to understand that this is parsing CLI arguments. + let args: Args = clap::Parser::parse(); + + tracing_subscriber::fmt::init(); + + let pool = sqlx::PgPool::connect( + // `env::var()` doesn't include the variable name for context like it should. 
+ &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?, + ) + .await + .wrap_err("could not connect to database")?; + + let max_hashing_threads = if args.max_hashing_threads == 0 { + std::thread::available_parallelism() + // We could just default to 1 but that would be a silent pessimization, + // which would be hard to debug. + .wrap_err("unable to determine number of available CPU cores; set `--max-hashing-threads` to a nonzero amount")? + .get() + } else { + args.max_hashing_threads + }; + + // Runs migration for `accounts` internally. + let accounts = AccountsManager::setup(&pool, max_hashing_threads) + .await + .wrap_err("error initializing AccountsManager")?; + + payments::migrate(&pool) + .await + .wrap_err("error running payments migrations")?; + + // `main()` doesn't actually run from a Tokio worker thread, + // so spawned tasks hit the global injection queue first and communication with the driver + // core is always cross-thread. + // + // The recommendation is to spawn the `axum::serve` future as a task so it executes directly + // on a worker thread. + + let http_task = tokio::spawn(http::run(pool, accounts)); + + Ok(()) } diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index 2f0e92bc88..729d61ce91 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -118,6 +118,12 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result crate::Result Date: Thu, 27 Feb 2025 16:20:09 -0800 Subject: [PATCH 22/30] feat: multi-tenant example No longer Axum-based because filling out the request routes would have distracted from the purpose of the example. 
--- Cargo.lock | 26 ++++- Cargo.toml | 2 +- .../axum-multi-tenant/payments/src/lib.rs | 34 ------ .../axum-multi-tenant/src/http/mod.rs | 7 -- .../postgres/axum-multi-tenant/src/main.rs | 64 ---------- .../Cargo.toml | 6 +- .../README.md | 26 +++++ .../accounts/Cargo.toml | 6 +- .../accounts/migrations/01_setup.sql | 4 +- .../accounts/migrations/02_account.sql | 0 .../accounts/migrations/03_session.sql | 6 + .../accounts/sqlx.toml | 4 + .../accounts/src/lib.rs | 103 ++++++++++++---- .../payments/Cargo.toml | 0 .../payments/migrations/01_setup.sql | 0 .../payments/migrations/02_payment.sql | 13 ++- .../payments/sqlx.toml | 0 .../postgres/multi-tenant/payments/src/lib.rs | 110 ++++++++++++++++++ examples/postgres/multi-tenant/sqlx.toml | 3 + examples/postgres/multi-tenant/src/main.rs | 105 +++++++++++++++++ .../multi-tenant/src/migrations/01_setup.sql | 30 +++++ .../src/migrations/02_purchase.sql | 11 ++ 22 files changed, 412 insertions(+), 148 deletions(-) delete mode 100644 examples/postgres/axum-multi-tenant/payments/src/lib.rs delete mode 100644 examples/postgres/axum-multi-tenant/src/http/mod.rs delete mode 100644 examples/postgres/axum-multi-tenant/src/main.rs rename examples/postgres/{axum-multi-tenant => multi-tenant}/Cargo.toml (85%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/README.md (51%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/accounts/Cargo.toml (74%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/accounts/migrations/01_setup.sql (92%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/accounts/migrations/02_account.sql (100%) create mode 100644 examples/postgres/multi-tenant/accounts/migrations/03_session.sql rename examples/postgres/{axum-multi-tenant => multi-tenant}/accounts/sqlx.toml (66%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/accounts/src/lib.rs (69%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/payments/Cargo.toml (100%) rename 
examples/postgres/{axum-multi-tenant => multi-tenant}/payments/migrations/01_setup.sql (100%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/payments/migrations/02_payment.sql (87%) rename examples/postgres/{axum-multi-tenant => multi-tenant}/payments/sqlx.toml (100%) create mode 100644 examples/postgres/multi-tenant/payments/src/lib.rs create mode 100644 examples/postgres/multi-tenant/sqlx.toml create mode 100644 examples/postgres/multi-tenant/src/main.rs create mode 100644 examples/postgres/multi-tenant/src/migrations/01_setup.sql create mode 100644 examples/postgres/multi-tenant/src/migrations/02_purchase.sql diff --git a/Cargo.lock b/Cargo.lock index 6aec9a98f1..571a7dbc78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ dependencies = [ "argon2 0.5.3", "password-hash 0.5.0", "rand", + "serde", "sqlx", "thiserror 1.0.69", "time", @@ -396,7 +397,7 @@ checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", "axum-core 0.2.9", - "axum-macros", + "axum-macros 0.2.3", "bitflags 1.3.2", "bytes", "futures-util", @@ -427,6 +428,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ "axum-core 0.5.0", + "axum-macros 0.5.0", "bytes", "form_urlencoded", "futures-util", @@ -502,16 +504,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "axum-multi-tenant" version = "0.8.3" dependencies = [ "accounts", "axum 0.8.1", - "clap", "color-eyre", "dotenvy", "payments", + "rand", + "rust_decimal", "sqlx", "tokio", "tracing-subscriber", @@ -3540,18 +3554,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.217" 
+version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index b75b3c3bd7..3b4d263834 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "sqlx-postgres", "sqlx-sqlite", "examples/mysql/todos", - "examples/postgres/axum-multi-tenant", + "examples/postgres/multi-tenant", "examples/postgres/axum-social-with-tests", "examples/postgres/chat", "examples/postgres/files", diff --git a/examples/postgres/axum-multi-tenant/payments/src/lib.rs b/examples/postgres/axum-multi-tenant/payments/src/lib.rs deleted file mode 100644 index b0efcfe17f..0000000000 --- a/examples/postgres/axum-multi-tenant/payments/src/lib.rs +++ /dev/null @@ -1,34 +0,0 @@ -use accounts::AccountId; -use sqlx::PgPool; -use time::OffsetDateTime; -use uuid::Uuid; - -#[derive(sqlx::Type, Debug)] -#[sqlx(transparent)] -pub struct PaymentId(pub Uuid); - -#[derive(sqlx::Type, Debug)] -#[sqlx(type_name = "payments.payment_status")] -#[sqlx(rename_all = "snake_case")] -pub enum PaymentStatus { - Pending, - Successful, -} - -#[derive(Debug)] -pub struct Payment { - pub payment_id: PaymentId, - pub account_id: AccountId, - pub status: PaymentStatus, - pub currency: String, - // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money. 
- pub amount: rust_decimal::Decimal, - pub external_payment_id: String, - pub created_at: OffsetDateTime, - pub updated_at: Option, -} - -pub async fn migrate(pool: &PgPool) -> sqlx::Result<()> { - sqlx::migrate!().run(pool).await?; - Ok(()) -} diff --git a/examples/postgres/axum-multi-tenant/src/http/mod.rs b/examples/postgres/axum-multi-tenant/src/http/mod.rs deleted file mode 100644 index 9197a2042f..0000000000 --- a/examples/postgres/axum-multi-tenant/src/http/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -use accounts::AccountsManager; -use color_eyre::eyre; -use sqlx::PgPool; - -pub async fn run(pool: PgPool, accounts: AccountsManager) -> eyre::Result<()> { - axum::serve -} diff --git a/examples/postgres/axum-multi-tenant/src/main.rs b/examples/postgres/axum-multi-tenant/src/main.rs deleted file mode 100644 index 3d4b0cba64..0000000000 --- a/examples/postgres/axum-multi-tenant/src/main.rs +++ /dev/null @@ -1,64 +0,0 @@ -mod http; - -use accounts::AccountsManager; -use color_eyre::eyre; -use color_eyre::eyre::Context; - -#[derive(clap::Parser)] -struct Args { - #[clap(long, env)] - database_url: String, - - #[clap(long, env, default_value_t = 0)] - max_hashing_threads: usize, -} - -#[tokio::main] -async fn main() -> eyre::Result<()> { - color_eyre::install()?; - let _ = dotenvy::dotenv(); - - // (@abonander) I prefer to keep `clap::Parser` fully qualified here because it makes it clear - // what crate the derive macro is coming from. Otherwise, it requires contextual knowledge - // to understand that this is parsing CLI arguments. - let args: Args = clap::Parser::parse(); - - tracing_subscriber::fmt::init(); - - let pool = sqlx::PgPool::connect( - // `env::var()` doesn't include the variable name for context like it should. 
- &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?, - ) - .await - .wrap_err("could not connect to database")?; - - let max_hashing_threads = if args.max_hashing_threads == 0 { - std::thread::available_parallelism() - // We could just default to 1 but that would be a silent pessimization, - // which would be hard to debug. - .wrap_err("unable to determine number of available CPU cores; set `--max-hashing-threads` to a nonzero amount")? - .get() - } else { - args.max_hashing_threads - }; - - // Runs migration for `accounts` internally. - let accounts = AccountsManager::setup(&pool, max_hashing_threads) - .await - .wrap_err("error initializing AccountsManager")?; - - payments::migrate(&pool) - .await - .wrap_err("error running payments migrations")?; - - // `main()` doesn't actually run from a Tokio worker thread, - // so spawned tasks hit the global injection queue first and communication with the driver - // core is always cross-thread. - // - // The recommendation is to spawn the `axum::serve` future as a task so it executes directly - // on a worker thread. 
- - let http_task = tokio::spawn(http::run(pool, accounts)); - - Ok(()) -} diff --git a/examples/postgres/axum-multi-tenant/Cargo.toml b/examples/postgres/multi-tenant/Cargo.toml similarity index 85% rename from examples/postgres/axum-multi-tenant/Cargo.toml rename to examples/postgres/multi-tenant/Cargo.toml index 7ea32bbc43..f7dca28855 100644 --- a/examples/postgres/axum-multi-tenant/Cargo.toml +++ b/examples/postgres/multi-tenant/Cargo.toml @@ -16,13 +16,15 @@ tokio = { version = "1", features = ["rt-multi-thread", "macros"] } sqlx = { path = "../../..", version = "0.8.3", features = ["runtime-tokio", "postgres"] } -axum = "0.8.1" +axum = { version = "0.8.1", features = ["macros"] } -clap = { version = "4.5.30", features = ["derive", "env"] } color-eyre = "0.6.3" dotenvy = "0.15.7" tracing-subscriber = "0.3.19" +rust_decimal = "1.36.0" + +rand = "0.8.5" [lints] workspace = true diff --git a/examples/postgres/axum-multi-tenant/README.md b/examples/postgres/multi-tenant/README.md similarity index 51% rename from examples/postgres/axum-multi-tenant/README.md rename to examples/postgres/multi-tenant/README.md index aae3a6f1fe..9f96ff72f1 100644 --- a/examples/postgres/axum-multi-tenant/README.md +++ b/examples/postgres/multi-tenant/README.md @@ -5,6 +5,8 @@ with their own set of migrations. * The main crate, an Axum app. * Owns the `public` schema (tables are referenced unqualified). + * Migrations are moved to `src/migrations` using config key `migrate.migrations-dir` + to visually separate them from the subcrate folders. * `accounts`: a subcrate simulating a reusable account-management crate. * Owns schema `accounts`. * `payments`: a subcrate simulating a wrapper for a payments API. @@ -20,3 +22,27 @@ prefixes, but this can cause some really confusing issues when names conflict. 
This example will generate a `_sqlx_migrations` table in three different schemas, and if `search_path` is set to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified, it would throw an error. + +# Setup + +This example requires running three different sets of migrations. + +Ensure `sqlx-cli` is installed with Postgres support. + +Start a Postgres server. + +Create `.env` with `DATABASE_URL` or set it in your shell environment. + +Run the following commands: + +``` +(cd accounts && sqlx db setup) +(cd payments && sqlx migrate run) +sqlx migrate run +``` + +It is an open question how to make this more convenient; `sqlx-cli` could gain a `--recursive` flag that checks +subdirectories for `sqlx.toml` files, but that would only work for crates within the same workspace. If the `accounts` +and `payments` crates were instead crates.io dependencies, we would need Cargo's help to resolve that information. + +An issue has been opened for discussion: diff --git a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml b/examples/postgres/multi-tenant/accounts/Cargo.toml similarity index 74% rename from examples/postgres/axum-multi-tenant/accounts/Cargo.toml rename to examples/postgres/multi-tenant/accounts/Cargo.toml index dd95a890af..0295dcec8a 100644 --- a/examples/postgres/axum-multi-tenant/accounts/Cargo.toml +++ b/examples/postgres/multi-tenant/accounts/Cargo.toml @@ -10,11 +10,13 @@ tokio = { version = "1", features = ["rt", "sync"] } argon2 = { version = "0.5.3", features = ["password-hash"] } password-hash = { version = "0.5", features = ["std"] } -uuid = "1" +uuid = { version = "1", features = ["serde"] } thiserror = "1" rand = "0.8" -time = "0.3.37" +time = { version = "0.3.37", features = ["serde"] } + +serde = { version = "1.0.218", features = ["derive"] } [dev-dependencies] sqlx = { workspace = true, features = ["runtime-tokio"] } diff --git 
a/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql b/examples/postgres/multi-tenant/accounts/migrations/01_setup.sql similarity index 92% rename from examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql rename to examples/postgres/multi-tenant/accounts/migrations/01_setup.sql index 5aa8fa23cf..007e202ec9 100644 --- a/examples/postgres/axum-multi-tenant/accounts/migrations/01_setup.sql +++ b/examples/postgres/multi-tenant/accounts/migrations/01_setup.sql @@ -12,7 +12,7 @@ create or replace function accounts.set_updated_at() $$ begin NEW.updated_at = now(); -return NEW; + return NEW; end; $$ language plpgsql; @@ -20,7 +20,7 @@ create or replace function accounts.trigger_updated_at(tablename regclass) returns void as $$ begin -execute format('CREATE TRIGGER set_updated_at + execute format('CREATE TRIGGER set_updated_at BEFORE UPDATE ON %s FOR EACH ROW diff --git a/examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql b/examples/postgres/multi-tenant/accounts/migrations/02_account.sql similarity index 100% rename from examples/postgres/axum-multi-tenant/accounts/migrations/02_account.sql rename to examples/postgres/multi-tenant/accounts/migrations/02_account.sql diff --git a/examples/postgres/multi-tenant/accounts/migrations/03_session.sql b/examples/postgres/multi-tenant/accounts/migrations/03_session.sql new file mode 100644 index 0000000000..585f425874 --- /dev/null +++ b/examples/postgres/multi-tenant/accounts/migrations/03_session.sql @@ -0,0 +1,6 @@ +create table accounts.session +( + session_token text primary key, -- random alphanumeric string + account_id uuid not null references accounts.account (account_id), + created_at timestamptz not null default now() +); diff --git a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml b/examples/postgres/multi-tenant/accounts/sqlx.toml similarity index 66% rename from examples/postgres/axum-multi-tenant/accounts/sqlx.toml rename to 
examples/postgres/multi-tenant/accounts/sqlx.toml index 1d02130c2d..024f6395e5 100644 --- a/examples/postgres/axum-multi-tenant/accounts/sqlx.toml +++ b/examples/postgres/multi-tenant/accounts/sqlx.toml @@ -5,3 +5,7 @@ table-name = "accounts._sqlx_migrations" [macros.table-overrides.'accounts.account'] 'account_id' = "crate::AccountId" 'password_hash' = "sqlx::types::Text" + +[macros.table-overrides.'accounts.session'] +'session_token' = "crate::SessionToken" +'account_id' = "crate::AccountId" diff --git a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs b/examples/postgres/multi-tenant/accounts/src/lib.rs similarity index 69% rename from examples/postgres/axum-multi-tenant/accounts/src/lib.rs rename to examples/postgres/multi-tenant/accounts/src/lib.rs index 3037463e4c..ad33735165 100644 --- a/examples/postgres/axum-multi-tenant/accounts/src/lib.rs +++ b/examples/postgres/multi-tenant/accounts/src/lib.rs @@ -1,18 +1,25 @@ use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier}; -use std::sync::Arc; - use password_hash::PasswordHashString; - -use sqlx::{PgConnection, PgPool, PgTransaction}; - +use rand::distributions::{Alphanumeric, DistString}; +use sqlx::{Acquire, Executor, PgTransaction, Postgres}; +use std::sync::Arc; use uuid::Uuid; use tokio::sync::Semaphore; -#[derive(sqlx::Type, Debug)] +#[derive(sqlx::Type, Copy, Clone, Debug, serde::Deserialize, serde::Serialize)] #[sqlx(transparent)] pub struct AccountId(pub Uuid); +#[derive(sqlx::Type, Clone, Debug, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +pub struct SessionToken(pub String); + +pub struct Session { + pub account_id: AccountId, + pub session_token: SessionToken, +} + pub struct AccountsManager { /// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing. 
/// @@ -49,7 +56,7 @@ pub struct AccountsManager { } #[derive(Debug, thiserror::Error)] -pub enum CreateError { +pub enum CreateAccountError { #[error("error creating account: email in-use")] EmailInUse, #[error("error creating account")] @@ -61,7 +68,7 @@ pub enum CreateError { } #[derive(Debug, thiserror::Error)] -pub enum AuthenticateError { +pub enum CreateSessionError { #[error("unknown email")] UnknownEmail, #[error("invalid password")] @@ -97,7 +104,10 @@ pub enum GeneralError { } impl AccountsManager { - pub async fn setup(pool: &PgPool, max_hashing_threads: usize) -> Result { + pub async fn setup( + pool: impl Acquire<'_, Database = Postgres>, + max_hashing_threads: usize, + ) -> Result { sqlx::migrate!() .run(pool) .await @@ -119,7 +129,7 @@ impl AccountsManager { // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn // excess threads. let (_guard, res) = tokio::task::spawn_blocking(move || { - let salt = argon2::password_hash::SaltString::generate(rand::thread_rng()); + let salt = password_hash::SaltString::generate(rand::thread_rng()); ( guard, Argon2::default() @@ -136,7 +146,7 @@ impl AccountsManager { &self, password: String, hash: PasswordHashString, - ) -> Result<(), AuthenticateError> { + ) -> Result<(), CreateSessionError> { let guard = self .hashing_semaphore .clone() @@ -154,7 +164,7 @@ impl AccountsManager { .map_err(GeneralError::from)?; if let Err(password_hash::Error::Password) = res { - return Err(AuthenticateError::InvalidPassword); + return Err(CreateSessionError::InvalidPassword); } res.map_err(GeneralError::from)?; @@ -167,7 +177,7 @@ impl AccountsManager { txn: &mut PgTransaction<'_>, email: &str, password: String, - ) -> Result { + ) -> Result { // Hash password whether the account exists or not to make it harder // to tell the difference in the timing. 
let hash = self.hash_password(password).await?; @@ -187,29 +197,47 @@ impl AccountsManager { if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_account_id_key") { - CreateError::EmailInUse + CreateAccountError::EmailInUse } else { GeneralError::from(e).into() } }) } - pub async fn authenticate( + pub async fn create_session( &self, - conn: &mut PgConnection, + db: impl Acquire<'_, Database = Postgres>, email: &str, password: String, - ) -> Result { + ) -> Result { + let mut txn = db.begin().await.map_err(GeneralError::from)?; + + // To save a round-trip to the database, we'll speculatively insert the session token + // at the same time as we're looking up the password hash. + // + // This does nothing until the transaction is actually committed. + let session_token = SessionToken::generate(); + // Thanks to `sqlx.toml`: // * `account_id` maps to `AccountId` // * `password_hash` maps to `Text` + // * `session_token` maps to `SessionToken` let maybe_account = sqlx::query!( - "select account_id, password_hash \ - from accounts.account \ - where email = $1", - email + // language=PostgreSQL + "with account as ( + select account_id, password_hash \ + from accounts.account \ + where email = $1 + ), session as ( + insert into accounts.session(session_token, account_id) + select $2, account_id + from account + ) + select account.account_id, account.password_hash from account", + email, + session_token.0 ) - .fetch_optional(&mut *conn) + .fetch_optional(&mut *txn) .await .map_err(GeneralError::from)?; @@ -218,12 +246,39 @@ impl AccountsManager { self.hash_password(password) .await .map_err(GeneralError::from)?; - return Err(AuthenticateError::UnknownEmail); + return Err(CreateSessionError::UnknownEmail); }; self.verify_password(password, account.password_hash.into_inner()) .await?; - Ok(account.account_id) + txn.commit().await.map_err(GeneralError::from)?; + + Ok(Session { + account_id: account.account_id, + session_token, + }) + } + + pub async fn 
auth_session( + &self, + db: impl Executor<'_, Database = Postgres>, + session_token: &str, + ) -> Result, GeneralError> { + sqlx::query_scalar!( + "select account_id from accounts.session where session_token = $1", + session_token + ) + .fetch_optional(db) + .await + .map_err(GeneralError::from) + } +} + +impl SessionToken { + const LEN: usize = 32; + + fn generate() -> Self { + SessionToken(Alphanumeric.sample_string(&mut rand::thread_rng(), Self::LEN)) } } diff --git a/examples/postgres/axum-multi-tenant/payments/Cargo.toml b/examples/postgres/multi-tenant/payments/Cargo.toml similarity index 100% rename from examples/postgres/axum-multi-tenant/payments/Cargo.toml rename to examples/postgres/multi-tenant/payments/Cargo.toml diff --git a/examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql b/examples/postgres/multi-tenant/payments/migrations/01_setup.sql similarity index 100% rename from examples/postgres/axum-multi-tenant/payments/migrations/01_setup.sql rename to examples/postgres/multi-tenant/payments/migrations/01_setup.sql diff --git a/examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql b/examples/postgres/multi-tenant/payments/migrations/02_payment.sql similarity index 87% rename from examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql rename to examples/postgres/multi-tenant/payments/migrations/02_payment.sql index cc372f01b7..ee88fa18c0 100644 --- a/examples/postgres/axum-multi-tenant/payments/migrations/02_payment.sql +++ b/examples/postgres/multi-tenant/payments/migrations/02_payment.sql @@ -9,17 +9,18 @@ -- for needing type overrides. create type payments.payment_status as enum ( 'pending', + 'created', 'success', 'failed' ); create table payments.payment ( - payment_id uuid primary key default gen_random_uuid(), + payment_id uuid primary key default gen_random_uuid(), -- This cross-schema reference means migrations for the `accounts` crate should be run first. 
account_id uuid not null references accounts.account (account_id), - status payments.payment_status NOT NULL, + status payments.payment_status not null, -- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes) -- @@ -31,7 +32,7 @@ create table payments.payment -- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)` -- all use the same storage format in Postgres. Any constraint against the length of this field -- would purely be a sanity check. - currency text NOT NULL, + currency text not null, -- There's an endless debate about what type should be used to represent currency amounts. -- -- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly @@ -42,7 +43,7 @@ create table payments.payment -- -- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency, -- and so is what we've chosen here. - amount NUMERIC NOT NULL, + amount NUMERIC not null, -- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.), -- so imagine this is an identifier string for this payment in such a vendor's systems. @@ -50,8 +51,8 @@ create table payments.payment -- For privacy and security reasons, payment and personally-identifying information -- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor -- unless there is a good reason otherwise. 
- external_payment_id TEXT NOT NULL UNIQUE, - created_at timestamptz default now(), + external_payment_id text, + created_at timestamptz not null default now(), updated_at timestamptz ); diff --git a/examples/postgres/axum-multi-tenant/payments/sqlx.toml b/examples/postgres/multi-tenant/payments/sqlx.toml similarity index 100% rename from examples/postgres/axum-multi-tenant/payments/sqlx.toml rename to examples/postgres/multi-tenant/payments/sqlx.toml diff --git a/examples/postgres/multi-tenant/payments/src/lib.rs b/examples/postgres/multi-tenant/payments/src/lib.rs new file mode 100644 index 0000000000..6a1efe05ee --- /dev/null +++ b/examples/postgres/multi-tenant/payments/src/lib.rs @@ -0,0 +1,110 @@ +use accounts::AccountId; +use sqlx::{Acquire, PgConnection, Postgres}; +use time::OffsetDateTime; +use uuid::Uuid; + +#[derive(sqlx::Type, Copy, Clone, Debug)] +#[sqlx(transparent)] +pub struct PaymentId(pub Uuid); + +#[derive(sqlx::Type, Copy, Clone, Debug)] +#[sqlx(type_name = "payments.payment_status")] +#[sqlx(rename_all = "snake_case")] +pub enum PaymentStatus { + Pending, + Created, + Success, + Failed, +} + +// Users often assume that they need `#[derive(FromRow)]` to use `query_as!()`, +// then are surprised when the derive's control attributes have no effect. +// The macros currently do *not* use the `FromRow` trait at all. +// Support for `FromRow` is planned, but would require significant changes to the macros. +// See https://github.com/launchbadge/sqlx/issues/514 for details. +#[derive(Clone, Debug)] +pub struct Payment { + pub payment_id: PaymentId, + pub account_id: AccountId, + pub status: PaymentStatus, + pub currency: String, + // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money. 
+ pub amount: rust_decimal::Decimal, + pub external_payment_id: Option, + pub created_at: OffsetDateTime, + pub updated_at: Option, +} + +// Accepting `impl Acquire` allows this function to be generic over `Pool`, `Connection` and `Transaction`. +pub async fn migrate(db: impl Acquire<'_, Database = Postgres>) -> sqlx::Result<()> { + sqlx::migrate!().run(db).await?; + Ok(()) +} + +pub async fn create( + conn: &mut PgConnection, + account_id: AccountId, + currency: &str, + amount: rust_decimal::Decimal, +) -> sqlx::Result { + // Imagine this method does more than just create a record in the database; + // maybe it actually initiates the payment with a third-party vendor, like Stripe. + // + // We need to ensure that we can link the payment in the vendor's systems back to a record + // in ours, even if any of the following happens: + // * The application dies before storing the external payment ID in the database + // * We lose the connection to the database while trying to commit a transaction + // * The database server dies while committing the transaction + // + // Thus, we create the payment in three atomic phases: + // * We create the payment record in our system and commit it. + // * We create the payment in the vendor's system with our payment ID attached. + // * We update our payment record with the vendor's payment ID. + let payment_id = sqlx::query_scalar!( + "insert into payments.payment(account_id, status, currency, amount) \ + values ($1, $2, $3, $4) \ + returning payment_id", + // The database doesn't give us enough information to correctly typecheck `AccountId` here. + // We have to insert the UUID directly. + account_id.0, + PaymentStatus::Pending, + currency, + amount, + ) + .fetch_one(&mut *conn) + .await?; + + // We then create the record with the payment vendor... + let external_payment_id = "foobar1234"; + + // Then we store the external payment ID and update the payment status. 
+ // + // NOTE: use caution with `select *` or `returning *`; + // the order of columns gets baked into the binary, so if it changes between compile time and + // run-time, you may run into errors. + let payment = sqlx::query_as!( + Payment, + "update payments.payment \ + set status = $1, external_payment_id = $2 \ + where payment_id = $3 \ + returning *", + PaymentStatus::Created, + external_payment_id, + payment_id.0, + ) + .fetch_one(&mut *conn) + .await?; + + Ok(payment) +} + +pub async fn get(db: &mut PgConnection, payment_id: PaymentId) -> sqlx::Result> { + sqlx::query_as!( + Payment, + // see note above about `select *` + "select * from payments.payment where payment_id = $1", + payment_id.0 + ) + .fetch_optional(db) + .await +} diff --git a/examples/postgres/multi-tenant/sqlx.toml b/examples/postgres/multi-tenant/sqlx.toml new file mode 100644 index 0000000000..7a557cf4ba --- /dev/null +++ b/examples/postgres/multi-tenant/sqlx.toml @@ -0,0 +1,3 @@ +[migrate] +# Move `migrations/` to under `src/` to separate it from subcrates. +migrations-dir = "src/migrations" \ No newline at end of file diff --git a/examples/postgres/multi-tenant/src/main.rs b/examples/postgres/multi-tenant/src/main.rs new file mode 100644 index 0000000000..4aa1b9c5a8 --- /dev/null +++ b/examples/postgres/multi-tenant/src/main.rs @@ -0,0 +1,105 @@ +use accounts::AccountsManager; +use color_eyre::eyre; +use color_eyre::eyre::{Context, OptionExt}; +use rand::distributions::{Alphanumeric, DistString}; +use sqlx::Connection; + +#[tokio::main] +async fn main() -> eyre::Result<()> { + color_eyre::install()?; + let _ = dotenvy::dotenv(); + tracing_subscriber::fmt::init(); + + let mut conn = sqlx::PgConnection::connect( + // `env::var()` doesn't include the variable name in the error. + &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?, + ) + .await + .wrap_err("could not connect to database")?; + + // Runs migration for `accounts` internally. 
+ let accounts = AccountsManager::setup(&mut conn, 1) + .await + .wrap_err("error initializing AccountsManager")?; + + payments::migrate(&mut conn) + .await + .wrap_err("error running payments migrations")?; + + // For simplicity's sake, imagine each of these might be invoked by different request routes + // in a web application. + + // POST /account + let user_email = format!("user{}@example.com", rand::random::()); + let user_password = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + // Requires an externally managed transaction in case any application-specific records + // should be created after the actual account record. + let mut txn = conn.begin().await?; + + let account_id = accounts + // Takes ownership of the password string because it's sent to another thread for hashing. + .create(&mut txn, &user_email, user_password.clone()) + .await + .wrap_err("error creating account")?; + + txn.commit().await?; + + println!("created account ID: {}, email: {user_email:?}, password: {user_password:?}", account_id.0); + + // POST /session + // Log the user in. + let session = accounts + .create_session(&mut conn, &user_email, user_password.clone()) + .await + .wrap_err("error creating session")?; + + // After this, session.session_token should then be returned to the client, + // either in the response body or a `Set-Cookie` header. + println!("created session token: {}", session.session_token.0); + + // POST /purchase + // The client would then pass the session token to authenticated routes. + // In this route, they're making some kind of purchase. + + // First, we need to ensure the session is valid. + // `session.session_token` would be passed by the client in whatever way is appropriate. + // + // For a pure REST API, consider an `Authorization: Bearer` header instead of the request body. + // With Axum, you can create a reusable extractor that reads the header and validates the session + // by implementing `FromRequestParts`. 
+ // + // For APIs where the browser is intended to be the primary client, using a session cookie + // may be easier for the frontend. By setting the cookie with `HttpOnly: true`, + // it's impossible for malicious Javascript on the client to access and steal the session token. + let account_id = accounts + .auth_session(&mut conn, &session.session_token.0) + .await + .wrap_err("error authenticating session")? + .ok_or_eyre("session does not exist")?; + + let purchase_amount: rust_decimal::Decimal = "12.34".parse().unwrap(); + + // Then, because the user is making a purchase, we record a payment. + let payment = payments::create(&mut conn, account_id, "USD", purchase_amount) + .await + .wrap_err("error creating payment")?; + + println!("created payment: {payment:?}"); + + let purchase_id = sqlx::query_scalar!( + "insert into purchase(account_id, payment_id, amount) values ($1, $2, $3) returning purchase_id", + account_id.0, + payment.payment_id.0, + purchase_amount + ) + .fetch_one(&mut conn) + .await + .wrap_err("error creating purchase")?; + + println!("created purchase: {purchase_id}"); + + conn.close().await?; + + Ok(()) +} diff --git a/examples/postgres/multi-tenant/src/migrations/01_setup.sql b/examples/postgres/multi-tenant/src/migrations/01_setup.sql new file mode 100644 index 0000000000..0f275f7e89 --- /dev/null +++ b/examples/postgres/multi-tenant/src/migrations/01_setup.sql @@ -0,0 +1,30 @@ +-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging +-- and auditing. +-- +-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which +-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do +-- +-- select trigger_updated_at('
'); +-- +-- after a `CREATE TABLE`. +create or replace function set_updated_at() + returns trigger as +$$ +begin + NEW.updated_at = now(); + return NEW; +end; +$$ language plpgsql; + +create or replace function trigger_updated_at(tablename regclass) + returns void as +$$ +begin + execute format('CREATE TRIGGER set_updated_at + BEFORE UPDATE + ON %s + FOR EACH ROW + WHEN (OLD is distinct from NEW) + EXECUTE FUNCTION set_updated_at();', tablename); +end; +$$ language plpgsql; diff --git a/examples/postgres/multi-tenant/src/migrations/02_purchase.sql b/examples/postgres/multi-tenant/src/migrations/02_purchase.sql new file mode 100644 index 0000000000..3eebd64eb0 --- /dev/null +++ b/examples/postgres/multi-tenant/src/migrations/02_purchase.sql @@ -0,0 +1,11 @@ +create table purchase +( + purchase_id uuid primary key default gen_random_uuid(), + account_id uuid not null references accounts.account (account_id), + payment_id uuid not null references payments.payment (payment_id), + amount numeric not null, + created_at timestamptz not null default now(), + updated_at timestamptz +); + +select trigger_updated_at('purchase'); From 1b0c64a9e99938c91b4a939b1cf6fc3b06ef81a9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 17:00:37 -0800 Subject: [PATCH 23/30] chore(ci): test multi-tenant example --- .github/workflows/examples.yml | 28 ++++++++++++---- Cargo.lock | 32 +++++++++---------- Cargo.toml | 4 +-- examples/postgres/multi-tenant/Cargo.toml | 13 +++++--- .../postgres/multi-tenant/accounts/Cargo.toml | 2 +- .../postgres/multi-tenant/payments/Cargo.toml | 2 +- examples/postgres/multi-tenant/src/main.rs | 5 ++- 7 files changed, 54 insertions(+), 32 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 280d1fc4f3..4873a8c5cd 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -22,12 +22,12 @@ jobs: key: sqlx-cli - run: > - cargo build - -p sqlx-cli - --bin sqlx - --release - 
--no-default-features - --features mysql,postgres,sqlite + cargo build + -p sqlx-cli + --bin sqlx + --release + --no-default-features + --features mysql,postgres,sqlite - uses: actions/upload-artifact@v4 with: @@ -98,7 +98,7 @@ name: sqlx-cli path: /home/runner/.local/bin - - run: | + - run: | ls -R /home/runner/.local/bin chmod +x $HOME/.local/bin/sqlx echo $HOME/.local/bin >> $GITHUB_PATH @@ -175,6 +175,20 @@ DATABASE_URL: postgres://postgres:password@localhost:5432/mockable-todos run: cargo run -p sqlx-example-postgres-mockable-todos + - name: Multi-Tenant (Setup) + working-directory: examples/postgres/multi-tenant + env: + DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant + run: | + (cd accounts && sqlx db setup) + (cd payments && sqlx migrate run) + sqlx migrate run + + - name: Multi-Tenant (Run) + env: + DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant + run: cargo run -p sqlx-example-postgres-multi-tenant + - name: TODOs (Setup) working-directory: examples/postgres/todos env: diff --git a/Cargo.lock b/Cargo.lock index 571a7dbc78..1f4674abb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -515,22 +515,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "axum-multi-tenant" -version = "0.8.3" -dependencies = [ - "accounts", - "axum 0.8.1", - "color-eyre", - "dotenvy", - "payments", - "rand", - "rust_decimal", - "sqlx", - "tokio", - "tracing-subscriber", -] - [[package]] name = "backoff" version = "0.4.0" @@ -3991,6 +3975,22 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx-example-postgres-multi-tenant" +version = "0.8.3" +dependencies = [ + "accounts", + "axum 0.8.1", + "color-eyre", + "dotenvy", + "payments", + "rand", + "rust_decimal", + "sqlx", + "tokio", + "tracing-subscriber", +] + [[package]] name = "sqlx-example-postgres-todos" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3b4d263834..8769b56ec3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,14 @@ members = 
[ "sqlx-postgres", "sqlx-sqlite", "examples/mysql/todos", - "examples/postgres/multi-tenant", "examples/postgres/axum-social-with-tests", "examples/postgres/chat", "examples/postgres/files", "examples/postgres/json", "examples/postgres/listen", - "examples/postgres/todos", "examples/postgres/mockable-todos", + "examples/postgres/multi-tenant", + "examples/postgres/todos", "examples/postgres/transaction", "examples/sqlite/todos", ] diff --git a/examples/postgres/multi-tenant/Cargo.toml b/examples/postgres/multi-tenant/Cargo.toml index f7dca28855..f93c91747a 100644 --- a/examples/postgres/multi-tenant/Cargo.toml +++ b/examples/postgres/multi-tenant/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "axum-multi-tenant" +name = "sqlx-example-postgres-multi-tenant" version.workspace = true license.workspace = true edition.workspace = true @@ -9,9 +9,6 @@ categories.workspace = true authors.workspace = true [dependencies] -accounts = { path = "accounts" } -payments = { path = "payments" } - tokio = { version = "1", features = ["rt-multi-thread", "macros"] } sqlx = { path = "../../..", version = "0.8.3", features = ["runtime-tokio", "postgres"] } @@ -26,5 +23,13 @@ rust_decimal = "1.36.0" rand = "0.8.5" +[dependencies.accounts] +package = "sqlx-example-postgres-multi-tenant-accounts" +path = "accounts" + +[dependencies.payments] +package = "sqlx-example-postgres-multi-tenant-payments" +path = "payments" + [lints] workspace = true diff --git a/examples/postgres/multi-tenant/accounts/Cargo.toml b/examples/postgres/multi-tenant/accounts/Cargo.toml index 0295dcec8a..33b185912c 100644 --- a/examples/postgres/multi-tenant/accounts/Cargo.toml +++ b/examples/postgres/multi-tenant/accounts/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "accounts" +name = "sqlx-example-postgres-multi-tenant-accounts" version = "0.1.0" edition = "2021" diff --git a/examples/postgres/multi-tenant/payments/Cargo.toml b/examples/postgres/multi-tenant/payments/Cargo.toml index 6a0e4d2672..1c6d31868b 100644 --- 
a/examples/postgres/multi-tenant/payments/Cargo.toml +++ b/examples/postgres/multi-tenant/payments/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "payments" +name = "sqlx-example-postgres-multi-tenant-payments" version = "0.1.0" edition = "2021" diff --git a/examples/postgres/multi-tenant/src/main.rs b/examples/postgres/multi-tenant/src/main.rs index 4aa1b9c5a8..94a96fcf2b 100644 --- a/examples/postgres/multi-tenant/src/main.rs +++ b/examples/postgres/multi-tenant/src/main.rs @@ -45,7 +45,10 @@ async fn main() -> eyre::Result<()> { txn.commit().await?; - println!("created account ID: {}, email: {user_email:?}, password: {user_password:?}", account_id.0); + println!( + "created account ID: {}, email: {user_email:?}, password: {user_password:?}", + account_id.0 + ); // POST /session // Log the user in. From f4d22fb008962ef4914a95a4ed99453e5d9474b9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 17:10:23 -0800 Subject: [PATCH 24/30] fixup after merge --- Cargo.lock | 447 ++++++++++++++++-- examples/postgres/multi-tenant/Cargo.toml | 4 +- .../postgres/multi-tenant/payments/Cargo.toml | 5 +- sqlx-postgres/src/connection/executor.rs | 8 +- 4 files changed, 418 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 07754e7c22..ab37075bc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,18 +4,18 @@ version = 4 [[package]] name = "addr2line" -version = "0.24.2" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] [[package]] -name = "adler2" -version = "2.0.0" +name = "adler" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 
[[package]] name = "ahash" @@ -127,7 +127,19 @@ checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73" dependencies = [ "base64ct", "blake2", - "password-hash", + "password-hash 0.4.2", +] + +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash 0.5.0", ] [[package]] @@ -369,16 +381,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", - "axum-core", - "axum-macros", + "axum-core 0.2.9", + "axum-macros 0.2.3", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", "itoa", - "matchit", + "matchit 0.5.0", "memchr", "mime", "percent-encoding", @@ -386,14 +398,49 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "tokio", - "tower", + "tower 0.4.13", "tower-http", "tower-layer", "tower-service", ] +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", + "axum-macros 0.5.0", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.2.9" @@ -403,13 +450,33 @@ dependencies = [ 
"async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "mime", "tower-layer", "tower-service", ] +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-macros" version = "0.2.3" @@ -422,6 +489,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "backoff" version = "0.4.0" @@ -438,17 +516,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", + "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", - "windows-targets 0.52.6", ] [[package]] @@ -843,6 +921,33 @@ dependencies = [ "cc", ] +[[package]] +name = "color-eyre" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -1275,6 +1380,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -1527,9 +1642,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.1" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1656,6 +1771,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1663,7 +1789,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.2.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + 
"bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1701,8 +1850,8 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1714,6 +1863,41 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "hyper 1.6.0", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -1898,6 +2082,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -2148,6 +2338,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ 
-2187,11 +2383,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ - "adler2", + "adler", ] [[package]] @@ -2290,6 +2486,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -2355,9 +2561,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.7" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -2428,6 +2634,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking" version = "2.2.1" @@ -2468,6 +2686,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -3141,18 +3370,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -3180,6 +3409,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -3251,6 +3490,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -3463,6 +3711,7 @@ dependencies = [ "time", "tokio", "tokio-stream", + "toml", "tracing", "url", "uuid", @@ -3485,8 +3734,8 @@ name = "sqlx-example-postgres-axum-social" version = "0.1.0" dependencies = [ "anyhow", - "argon2", - "axum", + "argon2 0.4.1", + "axum 0.5.17", "dotenvy", "once_cell", "rand", @@ -3498,7 +3747,7 @@ dependencies = [ "thiserror 2.0.11", "time", "tokio", - "tower", + "tower 
0.4.13", "tracing", "uuid", "validator", @@ -3563,6 +3812,48 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx-example-postgres-multi-tenant" +version = "0.8.3" +dependencies = [ + "axum 0.8.1", + "color-eyre", + "dotenvy", + "rand", + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-tenant-accounts", + "sqlx-example-postgres-multi-tenant-payments", + "tokio", + "tracing-subscriber", +] + +[[package]] +name = "sqlx-example-postgres-multi-tenant-accounts" +version = "0.1.0" +dependencies = [ + "argon2 0.5.3", + "password-hash 0.5.0", + "rand", + "serde", + "sqlx", + "thiserror 1.0.69", + "time", + "tokio", + "uuid", +] + +[[package]] +name = "sqlx-example-postgres-multi-tenant-payments" +version = "0.1.0" +dependencies = [ + "rust_decimal", + "sqlx", + "sqlx-example-postgres-multi-tenant-accounts", + "time", + "uuid", +] + [[package]] name = "sqlx-example-postgres-todos" version = "0.1.0" @@ -3932,6 +4223,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.1" @@ -4024,6 +4321,16 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.37" @@ -4180,6 +4487,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + 
"futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.3.5" @@ -4190,11 +4513,11 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "http-range-header", "pin-project-lite", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", ] @@ -4241,6 +4564,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-error" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" +dependencies = [ + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -4369,9 +4728,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.1" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" +checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" dependencies = [ "serde", ] @@ -4418,6 +4777,12 @@ dependencies = [ 
"syn 1.0.109", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "value-bag" version = "1.10.0" diff --git a/examples/postgres/multi-tenant/Cargo.toml b/examples/postgres/multi-tenant/Cargo.toml index f93c91747a..200fcfd2e8 100644 --- a/examples/postgres/multi-tenant/Cargo.toml +++ b/examples/postgres/multi-tenant/Cargo.toml @@ -24,12 +24,12 @@ rust_decimal = "1.36.0" rand = "0.8.5" [dependencies.accounts] -package = "sqlx-example-postgres-multi-tenant-accounts" path = "accounts" +package = "sqlx-example-postgres-multi-tenant-accounts" [dependencies.payments] -package = "sqlx-example-postgres-multi-tenant-accounts" path = "payments" +package = "sqlx-example-postgres-multi-tenant-payments" [lints] workspace = true diff --git a/examples/postgres/multi-tenant/payments/Cargo.toml b/examples/postgres/multi-tenant/payments/Cargo.toml index 1c6d31868b..1f7d7c3f75 100644 --- a/examples/postgres/multi-tenant/payments/Cargo.toml +++ b/examples/postgres/multi-tenant/payments/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -accounts = { path = "../accounts" } sqlx = { workspace = true, features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml"] } @@ -13,5 +12,9 @@ rust_decimal = "1.36.0" time = "0.3.37" uuid = "1.12.1" +[dependencies.accounts] +path = "../accounts" +package = "sqlx-example-postgres-multi-tenant-accounts" + [dev-dependencies] sqlx = { workspace = true, features = ["runtime-tokio"] } diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 28e5e72ed0..328d8fd343 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -80,7 +80,9 @@ async fn prepare( let parameters = conn.handle_parameter_description(parameters).await?; - let (columns, column_names) = 
conn.handle_row_description(rows, true, fetch_column_origin).await?; + let (columns, column_names) = conn + .handle_row_description(rows, true, fetch_column_origin) + .await?; // ensure that if we did fetch custom data, we wait until we are fully ready before // continuing @@ -449,7 +451,9 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None, true).await?; + let (_, metadata) = self + .get_or_prepare(sql, parameters, true, None, true) + .await?; Ok(PgStatement { sql: Cow::Borrowed(sql), From 15df1593c6ab44979533a4bf6b52bd53e0a2d06e Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 17:22:39 -0800 Subject: [PATCH 25/30] fix(ci): enable `sqlx-toml` in CLI build for examples --- .github/workflows/examples.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 4873a8c5cd..6fbb4b31eb 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -27,7 +27,7 @@ jobs: --bin sqlx --release --no-default-features - --features mysql,postgres,sqlite + --features mysql,postgres,sqlite,sqlx-toml - uses: actions/upload-artifact@v4 with: From 4fb7102cb2cddc66d2ad39eaac238e86a25d0345 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 17:27:57 -0800 Subject: [PATCH 26/30] fix: CI, README for `multi-tenant` --- .github/workflows/examples.yml | 8 ++++---- examples/postgres/multi-tenant/README.md | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 6fbb4b31eb..8d25e96f81 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -178,16 +178,16 @@ jobs: - name: Multi-Tenant (Setup) working-directory: examples/postgres/multi-tenant env: - DATABASE_URL: 
postgres://postgres:password@localhost:5432/mockable-todos + DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant run: | (cd accounts && sqlx db setup) (cd payments && sqlx migrate run) sqlx migrate run - - name: Mockable TODOs (Run) + - name: Multi-Tenant (Run) env: - DATABASE_URL: postgres://postgres:password@localhost:5432/mockable-todos - run: cargo run -p sqlx-example-postgres-mockable-todos + DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant + run: cargo run -p sqlx-example-postgres-multi-tenant - name: TODOs (Setup) working-directory: examples/postgres/todos diff --git a/examples/postgres/multi-tenant/README.md b/examples/postgres/multi-tenant/README.md index 9f96ff72f1..8122d852a7 100644 --- a/examples/postgres/multi-tenant/README.md +++ b/examples/postgres/multi-tenant/README.md @@ -19,7 +19,7 @@ This example uses schema-qualified names everywhere for clarity. It can be tempting to change the `search_path` of the connection (MySQL, Postgres) to eliminate the need for schema prefixes, but this can cause some really confusing issues when names conflict. -This example will generate a `_sqlx_migrations` table in three different schemas, and if `search_path` is set +This example will generate a `_sqlx_migrations` table in three different schemas; if `search_path` is set to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified, it would throw an error. @@ -27,11 +27,23 @@ it would throw an error. This example requires running three different sets of migrations. -Ensure `sqlx-cli` is installed with Postgres support. +Ensure `sqlx-cli` is installed with Postgres and `sqlx.toml` support: -Start a Postgres server. +``` +cargo install sqlx-cli --features postgres,sqlx-toml +``` + +Start a Postgres server (shown here using Docker, `run` command also works with `podman`): -Create `.env` with `DATABASE_URL` or set it in your shell environment. 
+``` +docker run -d -e POSTGRES_PASSWORD=password -p 5432:5432 --name postgres postgres:latest +``` + +Create `.env` with `DATABASE_URL` or set the variable in your shell environment; + +``` +DATABASE_URL=postgres://postgres:password@localhost/example-multi-tenant +``` Run the following commands: From 2b6915000c4b6e10c9774c6a5c6fb79efd6786b7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 18:16:09 -0800 Subject: [PATCH 27/30] fix: clippy warnings --- sqlx-postgres/src/migrate.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index 90ebd49a73..8275bda188 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -154,7 +154,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as(&*format!( + let row: Option<(i64,)> = query_as(&format!( "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" )) .fetch_optional(self) @@ -170,7 +170,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = query_as(&*format!( + let rows: Vec<(i64, Vec)> = query_as(&format!( "SELECT version, checksum FROM {table_name} ORDER BY version" )) .fetch_all(self) @@ -253,7 +253,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query(&*format!( + let _ = query(&format!( r#" UPDATE {table_name} SET execution_time = $1 @@ -306,7 +306,7 @@ async fn execute_migration( .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(&*format!( + let _ = query(&format!( r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) @@ -332,7 +332,7 @@ async fn revert_migration( 
.map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(&*format!(r#"DELETE FROM {table_name} WHERE version = $1"#)) + let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = $1"#)) .bind(migration.version) .execute(conn) .await?; From a9a4d00a800a370adcb3570e176625535c1f430d Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 18:17:40 -0800 Subject: [PATCH 28/30] fix: multi-tenant README --- examples/postgres/multi-tenant/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/postgres/multi-tenant/README.md b/examples/postgres/multi-tenant/README.md index 8122d852a7..3688202690 100644 --- a/examples/postgres/multi-tenant/README.md +++ b/examples/postgres/multi-tenant/README.md @@ -3,7 +3,7 @@ This example project involves three crates, each owning a different schema in one database, with their own set of migrations. -* The main crate, an Axum app. +* The main crate, a simple binary simulating the action of a REST API. * Owns the `public` schema (tables are referenced unqualified). * Migrations are moved to `src/migrations` using config key `migrate.migrations-dir` to visually separate them from the subcrate folders. 
From ac7f27093244cd81c83e7afd3cd87d6f21989d59 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 18:30:43 -0800 Subject: [PATCH 29/30] fix: sequential versioning inference for migrations --- sqlx-cli/src/opt.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 3230148e8b..afc4548888 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -452,8 +452,8 @@ fn next_timestamp() -> String { fn next_sequential(migrator: &Migrator) -> Option { let next_version = migrator .migrations - .windows(2) - .last() + .rchunks(2) + .next() .and_then(|migrations| { match migrations { [previous, latest] => { From 8ddcd0640e831878228925aa5cfa93a49eb7ff24 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 27 Feb 2025 18:36:50 -0800 Subject: [PATCH 30/30] fix: migration versioning with explicit overrides --- sqlx-cli/src/opt.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index afc4548888..a75fe002b3 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -433,15 +433,16 @@ impl AddMigrationOpts { pub fn version_prefix(&self, config: &Config, migrator: &Migrator) -> String { let default_versioning = &config.migrate.defaults.migration_versioning; - if self.timestamp || matches!(default_versioning, DefaultVersioning::Timestamp) { - return next_timestamp(); - } - - if self.sequential || matches!(default_versioning, DefaultVersioning::Sequential) { - return next_sequential(migrator).unwrap_or_else(|| fmt_sequential(1)); + match (self.timestamp, self.sequential, default_versioning) { + (true, false, _) | (false, false, DefaultVersioning::Timestamp) => next_timestamp(), + (false, true, _) | (false, false, DefaultVersioning::Sequential) => { + next_sequential(migrator).unwrap_or_else(|| fmt_sequential(1)) + } + (false, false, DefaultVersioning::Inferred) => { + 
next_sequential(migrator).unwrap_or_else(next_timestamp) + } + (true, true, _) => unreachable!("BUG: Clap should have rejected this case"), } - - next_sequential(migrator).unwrap_or_else(next_timestamp) } }