From a350170f22a273203a5565ef08386747891f3fcb Mon Sep 17 00:00:00 2001 From: pfeme Date: Sun, 28 Jan 2024 21:52:56 +0100 Subject: [PATCH] custom exceptions --- .../src/main/scala/zio/jdbc/JdbcDecoder.scala | 13 +- .../scala/zio/jdbc/JdbcDecoderError.scala | 27 ---- .../scala/zio/jdbc/JdbcEncoderError.scala | 5 - .../main/scala/zio/jdbc/JdbcException.scala | 86 ++++++++++++ core/src/main/scala/zio/jdbc/Query.scala | 72 +++++----- .../src/main/scala/zio/jdbc/SqlFragment.scala | 60 +++++---- .../src/main/scala/zio/jdbc/ZConnection.scala | 60 +++++---- .../main/scala/zio/jdbc/ZConnectionPool.scala | 127 ++++++------------ .../zio/jdbc/ZConnectionPoolConfig.scala | 2 +- core/src/main/scala/zio/jdbc/ZResultSet.scala | 3 +- core/src/main/scala/zio/jdbc/package.scala | 2 +- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 5 +- .../test/scala/zio/jdbc/TestConnection.scala | 17 +-- .../scala/zio/jdbc/ZConnectionPoolSpec.scala | 2 +- .../test/scala/zio/jdbc/ZConnectionSpec.scala | 16 +-- 15 files changed, 273 insertions(+), 224 deletions(-) delete mode 100644 core/src/main/scala/zio/jdbc/JdbcDecoderError.scala delete mode 100644 core/src/main/scala/zio/jdbc/JdbcEncoderError.scala create mode 100644 core/src/main/scala/zio/jdbc/JdbcException.scala diff --git a/core/src/main/scala/zio/jdbc/JdbcDecoder.scala b/core/src/main/scala/zio/jdbc/JdbcDecoder.scala index 89c5ad86..dedf70ec 100644 --- a/core/src/main/scala/zio/jdbc/JdbcDecoder.scala +++ b/core/src/main/scala/zio/jdbc/JdbcDecoder.scala @@ -28,9 +28,10 @@ import scala.collection.immutable.ListMap trait JdbcDecoder[+A] { self => def unsafeDecode(columIndex: Int, rs: ResultSet): (Int, A) - final def decode(columnIndex: Int, rs: ResultSet): Either[Throwable, (Int, A)] = - try Right(unsafeDecode(columnIndex, rs)) - catch { case e: JdbcDecoderError => Left(e) } + final def decode(columnIndex: Int, rs: ResultSet): IO[JdbcDecoderError, (Int, A)] = + ZIO.attempt(unsafeDecode(columnIndex, rs)).refineOrDie { case e => + JdbcDecoderError(e.getMessage(), e, rs.getMetaData(), rs.getRow()) + } final def map[B](f: A => B): JdbcDecoder[B] = new JdbcDecoder[B] { @@ -137,9 +138,9 @@ object JdbcDecoder extends JdbcDecoderLowPriorityImplicits { implicit def optionDecoder[A](implicit decoder: JdbcDecoder[A]): JdbcDecoder[Option[A]] = JdbcDecoder(rs => int => - decoder.decode(int, rs) match { - case Left(_) => None - case Right(value) => Option(value._2) + try Some(decoder.unsafeDecode(int, rs)._2) + catch { + case _: Throwable => None } ) diff --git a/core/src/main/scala/zio/jdbc/JdbcDecoderError.scala b/core/src/main/scala/zio/jdbc/JdbcDecoderError.scala deleted file mode 100644 index f65a53e9..00000000 --- a/core/src/main/scala/zio/jdbc/JdbcDecoderError.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright 2022 John A. De Goes and the ZIO Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package zio.jdbc - -import java.io.IOException -import java.sql.ResultSetMetaData - -final case class JdbcDecoderError( - message: String, - cause: Throwable, - metadata: ResultSetMetaData, - row: Int, - column: Option[Int] = None -) extends IOException(message, cause) diff --git a/core/src/main/scala/zio/jdbc/JdbcEncoderError.scala b/core/src/main/scala/zio/jdbc/JdbcEncoderError.scala deleted file mode 100644 index 4b08f3dd..00000000 --- a/core/src/main/scala/zio/jdbc/JdbcEncoderError.scala +++ /dev/null @@ -1,5 +0,0 @@ -package zio.jdbc - -import java.io.IOException - -final case class JdbcEncoderError(message: String, cause: Throwable) extends IOException(message, cause) diff --git a/core/src/main/scala/zio/jdbc/JdbcException.scala b/core/src/main/scala/zio/jdbc/JdbcException.scala new file mode 100644 index 00000000..a0dffa05 --- /dev/null +++ b/core/src/main/scala/zio/jdbc/JdbcException.scala @@ -0,0 +1,86 @@ +/* + * Copyright 2022 John A. De Goes and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package zio.jdbc + +import java.io.IOException +import java.sql.{ ResultSetMetaData, SQLException, SQLTimeoutException } + +/** + * Trait to encapsule all the exceptions employed by ZIO JDBC + */ +sealed trait JdbcException extends Exception + +/** + * Exceptions subtraits to specify the type of error + */ +sealed trait ConnectionException extends JdbcException +sealed trait QueryException extends JdbcException +sealed trait CodecException extends JdbcException with QueryException + +sealed trait FatalException extends JdbcException +sealed trait RecoverableException extends JdbcException + +// ZTimeoutException groups all the errors produced by SQLTimeoutException. +sealed trait ZTimeoutException extends JdbcException + +/** + * ConnectionException. Related to the connection operations with a database + */ +final case class DriverNotFound(cause: Throwable, driver: String) + extends Exception(s"Could not found driver: $driver", cause) + with ConnectionException + with FatalException +final case class DBError(cause: Throwable) extends Exception(cause) with ConnectionException with FatalException +final case class FailedMakingRestorable(cause: Throwable) + extends Exception(cause) + with ConnectionException + with FatalException +final case class ConnectionTimeout(cause: Throwable) + extends Exception(cause) + with ConnectionException + with ZTimeoutException + with RecoverableException + +/** + * CodecExceptions. Related to the decoding and encoding of the data in a transaction + */ +final case class DecodeException(cause: Throwable) extends Exception(cause) with CodecException with FatalException +final case class JdbcDecoderError( + message: String, + cause: Throwable, + metadata: ResultSetMetaData, + row: Int, + column: Option[Int] = None +) extends IOException(message, cause) + with CodecException + with FatalException +final case class JdbcEncoderError(message: String, cause: Throwable) + extends IOException(message, cause) + with CodecException + with FatalException + +/** + * FailedQueries. Related to the failure of actions executed directly on a database + */ +final case class ZSQLException(cause: SQLException) + extends Exception(cause) + with QueryException + with RecoverableException +final case class ZSQLTimeoutException(cause: SQLTimeoutException) + extends Exception(cause) + with QueryException + with ZTimeoutException + with RecoverableException diff --git a/core/src/main/scala/zio/jdbc/Query.scala b/core/src/main/scala/zio/jdbc/Query.scala index 442c50c8..61554bdb 100644 --- a/core/src/main/scala/zio/jdbc/Query.scala +++ b/core/src/main/scala/zio/jdbc/Query.scala @@ -18,62 +18,59 @@ package zio.jdbc import zio._ import zio.stream._ -final case class Query[+A](sql: SqlFragment, decode: ZResultSet => A) { +import java.sql.{ SQLException, SQLTimeoutException } + +final case class Query[+A](decode: ZResultSet => IO[CodecException, A], sql: SqlFragment) { def as[B](implicit decoder: JdbcDecoder[B]): Query[B] = - Query(sql, zrs => decoder.unsafeDecode(1, zrs.resultSet)._2) + Query(zrs => decoder.decode(1, zrs.resultSet).map(_._2), sql) def map[B](f: A => B): Query[B] = - Query(sql, zrs => f(decode(zrs))) + Query(zrs => decode(zrs).map(f), sql) /** * Performs a SQL select query, returning all results in a chunk. */ - def selectAll: ZIO[ZConnection, Throwable, Chunk[A]] = + def selectAll: ZIO[ZConnection, QueryException, Chunk[A]] = ZIO.scoped(for { zrs <- executeQuery(sql) - chunk <- ZIO.attempt { - val builder = ChunkBuilder.make[A]() - while (zrs.next()) - builder += decode(zrs) - builder.result() + chunk <- ZIO.iterate(ChunkBuilder.make[A]())(_ => zrs.next()) { builder => + for { + decoded <- decode(zrs) + } yield builder += decoded } - } yield chunk) + } yield chunk.result()) /** * Performs a SQL select query, returning the first result, if any. */ - def selectOne: ZIO[ZConnection, Throwable, Option[A]] = + def selectOne: ZIO[ZConnection, QueryException, Option[A]] = ZIO.scoped(for { zrs <- executeQuery(sql) - option <- ZIO.attempt { - if (zrs.next()) Some(decode(zrs)) else None - } + option <- + if (zrs.next()) decode(zrs).map(Some(_)) + else ZIO.none } yield option) /** * Performs a SQL select query, returning a stream of results. */ - def selectStream(chunkSize: => Int = ZStream.DefaultChunkSize): ZStream[ZConnection, Throwable, A] = + def selectStream(chunkSize: => Int = ZStream.DefaultChunkSize): ZStream[ZConnection, QueryException, A] = ZStream.unwrapScoped { for { zrs <- executeQuery(sql) stream = ZStream.paginateChunkZIO(())(_ => - ZIO.attemptBlocking { - val builder = ChunkBuilder.make[A](chunkSize) - var hasNext = false - var i = 0 - while ( - i < chunkSize && { - hasNext = zrs.next() - hasNext - } - ) { - builder.addOne(decode(zrs)) - i += 1 + ZIO + .iterate((ChunkBuilder.make[A](chunkSize), 0)) { case (_, i) => + i < chunkSize && zrs.next() + } { case (builder, i) => + for { + decoded <- decode(zrs) + } yield (builder += decoded, i + 1) + } + .map { case (builder, i) => + (builder.result(), if (i >= chunkSize) Some(()) else None) } - (builder.result(), if (hasNext) Some(()) else None) - } ) } yield stream } @@ -81,11 +78,14 @@ final case class Query[+A](sql: SqlFragment, decode: ZResultSet => A) { def withDecode[B](f: ZResultSet => B): Query[B] = Query(sql, f) - private[jdbc] def executeQuery(sql: SqlFragment): ZIO[Scope with ZConnection, Throwable, ZResultSet] = for { + private[jdbc] def executeQuery(sql: SqlFragment): ZIO[Scope with ZConnection, QueryException, ZResultSet] = for { connection <- ZIO.service[ZConnection] zrs <- connection.executeSqlWith(sql, false) { ps => ZIO.acquireRelease { - ZIO.attempt(ZResultSet(ps.executeQuery())) + ZIO.attempt(ZResultSet(ps.executeQuery())).refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) + } }(_.close) } } yield zrs @@ -94,7 +94,15 @@ final case class Query[+A](sql: SqlFragment, decode: ZResultSet => A) { object Query { + def apply[A](sql: SqlFragment, decode: ZResultSet => A): Query[A] = { + def decodeZIO(zrs: ZResultSet): IO[DecodeException, A] = + ZIO.attempt(decode(zrs)).refineOrDie { case e: Throwable => + DecodeException(e) + } + new Query[A](zrs => decodeZIO(zrs), sql) + } + def fromSqlFragment[A](sql: SqlFragment)(implicit decoder: JdbcDecoder[A]): Query[A] = - Query[A](sql, zrs => decoder.unsafeDecode(1, zrs.resultSet)._2) + Query[A](sql, (zrs: ZResultSet) => decoder.unsafeDecode(1, zrs.resultSet)._2) } diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index ac6e4602..4861d09c 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -18,7 +18,7 @@ package zio.jdbc import zio._ import zio.jdbc.SqlFragment.Segment -import java.sql.{ PreparedStatement, Types } +import java.sql.{ PreparedStatement, SQLException, SQLTimeoutException, Types } import java.time.{ OffsetDateTime, ZoneOffset } import scala.language.implicitConversions @@ -172,18 +172,22 @@ sealed trait SqlFragment { self => /** * Executes a SQL statement, such as one that creates a table. */ - def execute: ZIO[ZConnection, Throwable, Unit] = - ZIO.scoped(for { - connection <- ZIO.service[ZConnection] - _ <- connection.executeSqlWith(self, false) { ps => - ZIO.attempt(ps.executeUpdate()) - } - } yield ()) + def execute: ZIO[ZConnection, QueryException, Unit] = + ZIO + .scoped(for { + connection <- ZIO.service[ZConnection] + _ <- connection.executeSqlWith(self, false) { ps => + ZIO.attempt(ps.executeUpdate()).refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) + } + } + } yield ()) /** * Executes a SQL delete query. */ - def delete: ZIO[ZConnection, Throwable, Long] = + def delete: ZIO[ZConnection, QueryException, Long] = ZIO.scoped(executeLargeUpdate(self)) /** @@ -212,33 +216,36 @@ sealed trait SqlFragment { self => * parsed and returned as `Chunk[Long]`. If keys are non-numeric, a * `Chunk.empty` is returned. */ - def insertWithKeys: ZIO[ZConnection, Throwable, UpdateResult[Long]] = + def insertWithKeys: ZIO[ZConnection, QueryException, UpdateResult[Long]] = ZIO.scoped(executeWithReturning(self, JdbcDecoder[Long])) /** * Executes a SQL update query with a RETURNING clause, materialized * as values of type `A`. */ - def updateReturning[A: JdbcDecoder]: ZIO[ZConnection, Throwable, UpdateResult[A]] = + def updateReturning[A: JdbcDecoder]: ZIO[ZConnection, QueryException, UpdateResult[A]] = ZIO.scoped(executeWithReturning(self, JdbcDecoder[A])) /** * Performs a SQL update query, returning a count of rows updated. */ - def update: ZIO[ZConnection, Throwable, Long] = + def update: ZIO[ZConnection, QueryException, Long] = ZIO.scoped(executeLargeUpdate(self)) - private def executeLargeUpdate(sql: SqlFragment): ZIO[Scope with ZConnection, Throwable, Long] = for { + private def executeLargeUpdate(sql: SqlFragment): ZIO[Scope with ZConnection, QueryException, Long] = for { connection <- ZIO.service[ZConnection] count <- connection.executeSqlWith(sql, false) { ps => - ZIO.attempt(ps.executeLargeUpdate()) + ZIO.attempt(ps.executeLargeUpdate()).refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) + } } } yield count private def executeWithReturning[A]( sql: SqlFragment, decoder: JdbcDecoder[A] - ): ZIO[Scope with ZConnection, Throwable, UpdateResult[A]] = + ): ZIO[Scope with ZConnection, QueryException, UpdateResult[A]] = for { updateRes <- executeUpdate(sql, true) (count, maybeRs) = updateRes @@ -250,6 +257,9 @@ sealed trait SqlFragment { self => while (rs.next()) builder += decoder.unsafeDecode(1, rs.resultSet)._2 builder.result() + }.refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) } } } yield UpdateResult(count, keys) @@ -257,17 +267,21 @@ sealed trait SqlFragment { self => private[jdbc] def executeUpdate( sql: SqlFragment, returnAutoGeneratedKeys: Boolean - ): ZIO[Scope with ZConnection, Throwable, (Long, Option[ZResultSet])] = + ): ZIO[Scope with ZConnection, QueryException, (Long, Option[ZResultSet])] = for { connection <- ZIO.service[ZConnection] result <- connection.executeSqlWith(sql, returnAutoGeneratedKeys) { ps => - ZIO.acquireRelease(ZIO.attempt { - val rowsUpdated = ps.executeLargeUpdate() - val updatedKeys = if (returnAutoGeneratedKeys) Some(ps.getGeneratedKeys) else None - (rowsUpdated, updatedKeys.map(ZResultSet(_))) - - })(_._2.map(_.close).getOrElse(ZIO.unit)) - + ZIO + .acquireRelease(ZIO.attempt { + val rowsUpdated = ps.executeLargeUpdate() + val updatedKeys = if (returnAutoGeneratedKeys) Some(ps.getGeneratedKeys) else None + (rowsUpdated, updatedKeys.map(ZResultSet(_))) + + })(_._2.map(_.close).getOrElse(ZIO.unit)) + .refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) + } } } yield result diff --git a/core/src/main/scala/zio/jdbc/ZConnection.scala b/core/src/main/scala/zio/jdbc/ZConnection.scala index f740f65d..6da64434 100644 --- a/core/src/main/scala/zio/jdbc/ZConnection.scala +++ b/core/src/main/scala/zio/jdbc/ZConnection.scala @@ -17,7 +17,7 @@ package zio.jdbc import zio._ -import java.sql.{ Connection, PreparedStatement, Statement } +import java.sql.{ Connection, PreparedStatement, SQLException, SQLTimeoutException, Statement } /** * A `ZConnection` is a straightforward wrapper around `java.sql.Connection`. In order @@ -33,13 +33,17 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex def accessZIO[A](f: Connection => ZIO[Scope, Throwable, A]): ZIO[Scope, Throwable, A] = ZIO.blocking(f(underlying)) - def close: Task[Any] = access(_.close()) - def rollback: Task[Any] = access(_.rollback()) + def close: ZIO[Scope, ZSQLException, Any] = access(_.close()).refineOrDie { case e: SQLException => + ZSQLException(e) + } + def rollback: ZIO[Scope, ZSQLException, Any] = access(_.rollback()).refineOrDie { case e: SQLException => + ZSQLException(e) + } private[jdbc] def executeSqlWith[A]( sql: SqlFragment, returnAutoGeneratedKeys: Boolean - )(f: PreparedStatement => ZIO[Scope, Throwable, A]): ZIO[Scope, Throwable, A] = + )(f: PreparedStatement => ZIO[Scope, QueryException, A]): ZIO[Scope, QueryException, A] = accessZIO { connection => for { transactionIsolationLevel <- currentTransactionIsolationLevel.get @@ -77,6 +81,9 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex .flatMap(_.join) .onInterrupt(ZIO.attemptBlocking(statement.cancel()).ignoreLogged) } yield result + }.refineOrDie { + case e: SQLTimeoutException => ZSQLTimeoutException(e) + case e: SQLException => ZSQLException(e) }.tapErrorCause { cause => // TODO: Question: do we want logging here, switch to debug for now ZIO.logAnnotate("SQL", sql.toString)( ZIO.logDebugCause(s"Error executing SQL due to: ${cause.prettyPrint}", cause) @@ -93,12 +100,15 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex * @param zc the connection to look into * @return true if the connection is alive (valid), false otherwise */ - def isValid(): Task[Boolean] = + def isValid(): ZIO[Scope, ZSQLException, Boolean] = { for { - closed <- ZIO.attempt(this.underlying.isClosed) - statement <- ZIO.attempt(this.underlying.prepareStatement("SELECT 1")) + closed <- access(_.isClosed) + statement <- access(_.prepareStatement("SELECT 1")) isAlive <- ZIO.succeed(!closed && statement != null) } yield isAlive + }.refineOrDie { case e: SQLException => + ZSQLException(e) + } /** * Returns whether the connection is still alive or not, providing a timeout, @@ -109,8 +119,10 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex * @param zc the connection to look into * @return true if the connection is alive (valid), false otherwise */ - def isValid(timeout: Int): Task[Boolean] = - ZIO.attempt(this.underlying.isValid(timeout)) + def isValid(timeout: Int): ZIO[Scope, ZSQLException, Boolean] = + access(_.isValid(timeout)).refineOrDie { case e: SQLException => + ZSQLException(e) + } private[jdbc] def restore: UIO[Unit] = ZIO.succeed(underlying.restore()) @@ -118,9 +130,11 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex object ZConnection { - def make(underlying: Connection): Task[ZConnection] = + def make(underlying: Connection): IO[ConnectionException, ZConnection] = for { - restorable <- ZIO.attempt(new Restorable(underlying)) + restorable <- ZIO.attempt(new Restorable(underlying)).refineOrDie { case e: SQLException => + FailedMakingRestorable(e) + } } yield new ZConnection(restorable) private[jdbc] class Restorable(underlying: Connection) extends Connection { @@ -311,33 +325,33 @@ object ZConnection { object Flag { case object AutoCommit extends Flag { - val index = 1 - val mask = 1 << index + val index = 1 + val mask: Int = 1 << index } case object Catalog extends Flag { - val index = 2 - val mask = 1 << index + val index = 2 + val mask: Int = 1 << index } case object ClientInfo extends Flag { - val index = 3 - val mask = 1 << index + val index = 3 + val mask: Int = 1 << index } case object ReadOnly extends Flag { - val index = 4 - val mask = 1 << index + val index = 4 + val mask: Int = 1 << index } case object Schema extends Flag { - val index = 5 - val mask = 1 << index + val index = 5 + val mask: Int = 1 << index } case object TransactionIsolation extends Flag { - val index = 6 - val mask = 1 << index + val index = 6 + val mask: Int = 1 << index } } } diff --git a/core/src/main/scala/zio/jdbc/ZConnectionPool.scala b/core/src/main/scala/zio/jdbc/ZConnectionPool.scala index 9b8fb7e4..381b8228 100644 --- a/core/src/main/scala/zio/jdbc/ZConnectionPool.scala +++ b/core/src/main/scala/zio/jdbc/ZConnectionPool.scala @@ -18,146 +18,103 @@ package zio.jdbc import zio._ import java.io.File -import java.sql.Connection +import java.lang.ClassNotFoundException +import java.sql.{ Connection, SQLException, SQLTimeoutException } /** * A `ZConnectionPool` represents a pool of connections, and has the ability to * supply a transaction that can be used for executing SQL statements. */ abstract class ZConnectionPool { - def transaction: ZLayer[Any, Throwable, ZConnection] + def transaction: ZLayer[Any, ConnectionException, ZConnection] def invalidate(conn: ZConnection): UIO[Any] } object ZConnectionPool { - def h2test: ZLayer[Any, Throwable, ZConnectionPool] = + + def h2test: ZLayer[Any, ConnectionException, ZConnectionPool] = ZLayer.scoped { for { - _ <- ZIO.attempt(Class.forName("org.h2.Driver")) - int <- Random.nextInt - acquire = ZIO.attemptBlocking { - java.sql.DriverManager.getConnection(s"jdbc:h2:mem:test_database_$int") - } - zenv <- make(acquire).build.provideSome[Scope](ZLayer.succeed(ZConnectionPoolConfig.default)) + int <- Random.nextInt + zenv <- connect("org.h2.Driver", s"jdbc:h2:mem:test_database_$int", Map.empty).build } yield zenv.get[ZConnectionPool] } def h2mem( database: String, props: Map[String, String] = Map() - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = - ZLayer.scoped { - for { - _ <- ZIO.attempt(Class.forName("org.h2.Driver")) - acquire = ZIO.attemptBlocking { - val properties = new java.util.Properties - props.foreach { case (k, v) => properties.setProperty(k, v) } - - java.sql.DriverManager.getConnection(s"jdbc:h2:mem:$database", properties) - } - zenv <- make(acquire).build - } yield zenv.get[ZConnectionPool] - } + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect("org.h2.Driver", s"jdbc:h2:mem:$database", props) def h2file( directory: File, database: String, props: Map[String, String] = Map() - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = - ZLayer.scoped { - for { - _ <- ZIO.attempt(Class.forName("org.h2.Driver")) - acquire = ZIO.attemptBlocking { - val properties = new java.util.Properties - props.foreach { case (k, v) => properties.setProperty(k, v) } - - java.sql.DriverManager.getConnection(s"jdbc:h2:file:$directory/$database", properties) - } - zenv <- make(acquire).build - } yield zenv.get[ZConnectionPool] - } + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect("org.h2.Driver", s"jdbc:h2:file:$directory/$database", props) def oracle( host: String, port: Int, database: String, props: Map[String, String] - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = - ZLayer.scoped { - for { - _ <- ZIO.attempt(Class.forName("oracle.jdbc.OracleDriver")) - acquire = ZIO.attemptBlocking { - val properties = new java.util.Properties - props.foreach { case (k, v) => properties.setProperty(k, v) } - - java.sql.DriverManager.getConnection(s"jdbc:oracle:thin:@$host:$port:$database", properties) - } - zenv <- make(acquire).build - } yield zenv.get[ZConnectionPool] - } + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect("oracle.jdbc.OracleDriver", s"jdbc:oracle:thin:@$host:$port:$database", props) def postgres( host: String, port: Int, database: String, props: Map[String, String] - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = - ZLayer.scoped { - for { - _ <- ZIO.attempt(Class.forName("org.postgresql.Driver")) - acquire = ZIO.attemptBlocking { - val properties = new java.util.Properties - - props.foreach { case (k, v) => properties.setProperty(k, v) } - - java.sql.DriverManager.getConnection(s"jdbc:postgresql://$host:$port/$database", properties) - } - zenv <- make(acquire).build - } yield zenv.get[ZConnectionPool] - } + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect("org.postgresql.Driver", s"jdbc:postgresql://$host:$port/$database", props) def sqlserver( host: String, port: Int, database: String, props: Map[String, String] - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = - ZLayer.scoped { - for { - _ <- ZIO.attempt(Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver")) - acquire = ZIO.attemptBlocking { - val properties = new java.util.Properties - - props.foreach { case (k, v) => properties.setProperty(k, v) } - - java.sql.DriverManager - .getConnection(s"jdbc:sqlserver://$host:$port;databaseName=$database", properties) - } - zenv <- make(acquire).build - } yield zenv.get[ZConnectionPool] - } + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect( + "com.microsoft.sqlserver.jdbc.SQLServerDriver", + s"jdbc:sqlserver://$host:$port;databaseName=$database", + props + ) def mysql( host: String, port: Int, database: String, props: Map[String, String] - ): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = + connect("com.mysql.cj.jdbc.Driver", s"jdbc:mysql://$host:$port/$database", props) + + def connect( + driverName: String, + url: String, + props: Map[String, String] + ): ZLayer[Any, ConnectionException, ZConnectionPool] = ZLayer.scoped { for { - _ <- ZIO.attempt(Class.forName("com.mysql.cj.jdbc.Driver")) + _ <- ZIO.attempt(Class.forName(driverName)).refineOrDie { case e: ClassNotFoundException => + DriverNotFound(e, driverName) + } acquire = ZIO.attemptBlocking { val properties = new java.util.Properties props.foreach { case (k, v) => properties.setProperty(k, v) } - java.sql.DriverManager - .getConnection(s"jdbc:mysql://$host:$port/$database", properties) + java.sql.DriverManager.getConnection(url, properties) + }.refineOrDie { + case e: SQLTimeoutException => ConnectionTimeout(e) + case e: SQLException => DBError(e) } - zenv <- make(acquire).build + zenv <- make(acquire).build.provideSome[Scope](ZLayer.succeed(ZConnectionPoolConfig.default)) } yield zenv.get[ZConnectionPool] } - def make(acquire: Task[Connection]): ZLayer[ZConnectionPoolConfig, Throwable, ZConnectionPool] = + def make( + acquire: IO[ConnectionException, Connection] + ): ZLayer[ZConnectionPoolConfig, ConnectionException, ZConnectionPool] = ZLayer.scoped { for { config <- ZIO.service[ZConnectionPoolConfig] @@ -184,8 +141,8 @@ object ZConnectionPool { } yield connection } } yield new ZConnectionPool { - def transaction: ZLayer[Any, Throwable, ZConnection] = tx - def invalidate(conn: ZConnection): UIO[Any] = pool.invalidate(conn) + def transaction: ZLayer[Any, ConnectionException, ZConnection] = tx + def invalidate(conn: ZConnection): UIO[Any] = pool.invalidate(conn) } } } diff --git a/core/src/main/scala/zio/jdbc/ZConnectionPoolConfig.scala b/core/src/main/scala/zio/jdbc/ZConnectionPoolConfig.scala index 6e7c7c39..3deeadbe 100644 --- a/core/src/main/scala/zio/jdbc/ZConnectionPoolConfig.scala +++ b/core/src/main/scala/zio/jdbc/ZConnectionPoolConfig.scala @@ -23,7 +23,7 @@ import zio._ final case class ZConnectionPoolConfig( minConnections: Int, maxConnections: Int, - retryPolicy: Schedule[Any, Throwable, Any], + retryPolicy: Schedule[Any, ConnectionException, Any], timeToLive: Duration ) object ZConnectionPoolConfig { diff --git a/core/src/main/scala/zio/jdbc/ZResultSet.scala b/core/src/main/scala/zio/jdbc/ZResultSet.scala index d6b04b04..51893488 100644 --- a/core/src/main/scala/zio/jdbc/ZResultSet.scala +++ b/core/src/main/scala/zio/jdbc/ZResultSet.scala @@ -27,7 +27,8 @@ import java.sql.ResultSet */ final class ZResultSet(private[jdbc] val resultSet: ResultSet) { def access[A](f: ResultSet => A): ZIO[Any, Throwable, A] = ZIO.attemptBlocking(f(resultSet)) - def close: URIO[Any, Unit] = + + def close: URIO[Any, Unit] = ZIO.attempt(resultSet.close()).ignoreLogged private[jdbc] def next(): Boolean = resultSet.next() diff --git a/core/src/main/scala/zio/jdbc/package.scala b/core/src/main/scala/zio/jdbc/package.scala index 649c38ee..8e9b3cb5 100644 --- a/core/src/main/scala/zio/jdbc/package.scala +++ b/core/src/main/scala/zio/jdbc/package.scala @@ -30,7 +30,7 @@ package object jdbc { * A new transaction, which may be applied to ZIO effects that require a * connection in order to execute such effects in the transaction. */ - val transaction: ZLayer[ZConnectionPool, Throwable, ZConnection] = + val transaction: ZLayer[ZConnectionPool, ConnectionException, ZConnection] = ZLayer(ZIO.serviceWith[ZConnectionPool](_.transaction)).flatten private[jdbc] val currentTransactionIsolationLevel: FiberRef[Option[TransactionIsolationLevel]] = diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index c875d24a..caed84e2 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -1,12 +1,11 @@ package zio.jdbc import zio._ -import zio.test._ import zio.jdbc.SqlFragment.Setter import zio.jdbc.{ transaction => transact } import zio.schema.{ Schema, TypeId } import zio.test.Assertion._ -import java.sql.SQLException +import zio.test._ final case class Person(name: String, age: Int) final case class UserLogin(username: String, password: String) @@ -237,7 +236,7 @@ object SqlFragmentSpec extends ZIOSpecDefault { res <- transact(defectiveSql.execute).exit error <- ZTestLogger.logOutput.map(_.filter(log => log.logLevel == zio.LogLevel.Debug)) } yield assert(res)( - fails(isSubtype[SQLException](anything)) + fails(isSubtype[ZSQLException](anything)) ) && assert(error.head.annotations.keys)(contains("SQL")) && assert(error.head.message())(containsString(sqlString))) .provideLayer(ZConnectionPool.h2test.orDie) diff --git a/core/src/test/scala/zio/jdbc/TestConnection.scala b/core/src/test/scala/zio/jdbc/TestConnection.scala index d9678b38..21aaa29d 100644 --- a/core/src/test/scala/zio/jdbc/TestConnection.scala +++ b/core/src/test/scala/zio/jdbc/TestConnection.scala @@ -1,5 +1,9 @@ package zio.jdbc +import zio.RuntimeFlags + +import java.io.{ InputStream, Reader } +import java.net.URL import java.sql.{ Blob, CallableStatement, @@ -8,18 +12,15 @@ import java.sql.{ DatabaseMetaData, NClob, PreparedStatement, + ResultSet, SQLWarning, SQLXML, Savepoint, Statement, - Struct, - ResultSet + Struct } import java.util.{ Properties, concurrent } import java.{ sql, util } -import java.io.{ InputStream, Reader } -import java.net.URL -import zio.RuntimeFlags class TestConnection(failNext: Boolean = false, elems: Int = 0) extends Connection { self => @@ -174,7 +175,7 @@ class DummyPreparedStatement(failNext: Boolean, elemns: Int) extends PreparedSta override def executeUpdate(sql: String) = ??? - override def close() = closed = true + override def close(): Unit = closed = true override def getMaxFieldSize() = ??? @@ -377,7 +378,7 @@ class DummyResultSet(failNext: Boolean, elems: Int) extends ResultSet { override def isWrapperFor(x$1: Class[_]) = ??? - override def next() = + override def next(): Boolean = if (failNext) { throw new sql.SQLException() } else if (currentElem < elems) { @@ -387,7 +388,7 @@ class DummyResultSet(failNext: Boolean, elems: Int) extends ResultSet { false } - override def close() = closed = true + override def close(): Unit = closed = true override def wasNull() = ??? diff --git a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala index 3c87cc94..790048ac 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala @@ -7,8 +7,8 @@ import zio.test.Assertion._ import zio.test.TestAspect._ import zio.test._ -import scala.util.Random import java.sql.Connection +import scala.util.Random object ZConnectionPoolSpec extends ZIOSpecDefault { final case class Person(name: String, age: Int) diff --git a/core/src/test/scala/zio/jdbc/ZConnectionSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionSpec.scala index 20360335..9ca2822b 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionSpec.scala @@ -2,6 +2,7 @@ package zio.jdbc import zio.test._ import zio.{ Random, ZIO } + import java.sql.PreparedStatement object ZConnectionSpec extends ZIOSpecDefault { @@ -27,17 +28,16 @@ object ZConnectionSpec extends ZIOSpecDefault { test("PreparedStatement Automatic Close Fail") { ZIO.scoped { for { - statementClosedTuple <- - testConnection - .executeSqlWith( - sql""" + res <- testConnection + .executeSqlWith( + sql""" create table users_no_id ( name varchar not null, age int not null )""", - false - )(ps => ZIO.fail(new DummyException("Error Ocurred", ps, ps.isClosed()))) - .catchSome { case e: DummyException => ZIO.succeed((e.preparedStatement, e.closedInScope)) } + false + )(ps => ZIO.succeed(new DummyException("Error Ocurred", ps, ps.isClosed()))) + statementClosedTuple <- ZIO.succeed((res.preparedStatement, res.closedInScope)) } yield assertTrue(statementClosedTuple._1.isClosed() && !statementClosedTuple._2) } //A bit of a hack, DummyException receives the prepared Statement so that its closed State can be checked outside ZConnection's Scope } @@ -72,7 +72,7 @@ object ZConnectionSpec extends ZIOSpecDefault { } } - class DummyException(msg: String, val preparedStatement: PreparedStatement, val closedInScope: Boolean) + case class DummyException(msg: String, val preparedStatement: PreparedStatement, val closedInScope: Boolean) extends Exception(msg) }