diff --git a/.gitignore b/.gitignore index 4ec852f..4d2c35f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea/ .vscode/ vendor/ +coverage.out diff --git a/constants.go b/constants.go index e2269d8..b75fa39 100644 --- a/constants.go +++ b/constants.go @@ -48,23 +48,6 @@ const ( DialectMSSQL Dialect = "mssql" ) -// PostgresErrCode defines the type for Postgres error codes. -type PostgresErrCode string - -// Postgres error codes (will be filled gradually). -const ( - PgxErrCodeUniqueViolation PostgresErrCode = "23505" - PgxErrCodeDeadlockDetected PostgresErrCode = "40P01" - PgxErrCodeSerializationFailure PostgresErrCode = "40001" - PgxErrFeatureNotSupported PostgresErrCode = "0A000" - - // nolint: staticcheck // lib/pq using is deprecated. Use pgx Postgres driver. - PostgresErrCodeUniqueViolation PostgresErrCode = "unique_violation" - // nolint: staticcheck // lib/pq using is deprecated. Use pgx Postgres driver. - PostgresErrCodeDeadlockDetected PostgresErrCode = "deadlock_detected" - PostgresErrCodeSerializationFailure PostgresErrCode = "serialization_failure" -) - // PostgresSSLMode defines possible values for Postgres sslmode connection parameter. type PostgresSSLMode string diff --git a/dbrutil/retry_test.go b/dbrutil/retry_test.go index 68ee91b..a360eaa 100644 --- a/dbrutil/retry_test.go +++ b/dbrutil/retry_test.go @@ -16,14 +16,14 @@ import ( "github.com/stretchr/testify/require" "github.com/acronis/go-dbkit" - _ "github.com/acronis/go-dbkit/pgx" + "github.com/acronis/go-dbkit/pgx" ) // Test that retriable errors stays retriable even wrapped in Tx structures func TestTxErrorsIsRetriable(t *testing.T) { - retriable := []dbkit.PostgresErrCode{ - dbkit.PgxErrCodeDeadlockDetected, - dbkit.PgxErrCodeSerializationFailure, + retriable := []pgx.ErrCode{ + pgx.ErrCodeDeadlockDetected, + pgx.ErrCodeSerializationFailure, } mkerr := func(code string) []error { diff --git a/mssql/mssql.go b/mssql/mssql.go index f47170d..dd7e217 100644 --- a/mssql/mssql.go +++ b/mssql/mssql.go @@ -24,7 +24,7 @@ func init() { dbkit.RegisterIsRetryableFunc(&mssql.Driver{}, func(err error) bool { var msErr mssql.Error if errors.As(err, &msErr) { - if msErr.Number == int32(MSSQLErrDeadlock) { // deadlock error + if msErr.Number == int32(ErrDeadlock) { // deadlock error return true } } @@ -37,12 +37,13 @@ type ErrCode int32 // MSSQL error codes (will be filled gradually). const ( - MSSQLErrDeadlock ErrCode = 1205 - MSSQLErrCodeUniqueViolation ErrCode = 2627 - MSSQLErrCodeUniqueIndexViolation ErrCode = 2601 + ErrDeadlock ErrCode = 1205 + ErrCodeUniqueViolation ErrCode = 2627 + ErrCodeUniqueIndexViolation ErrCode = 2601 ) -// CheckMSSQLError checks if the passed error relates to MSSQL and it's internal code matches the one from the argument. +// CheckMSSQLError checks if the passed error relates to MSSQL, +// and it's internal code matches the one from the argument. func CheckMSSQLError(err error, errCode ErrCode) bool { var msErr mssql.Error if errors.As(err, &msErr) { diff --git a/mysql/mysql.go b/mysql/mysql.go index dcf36f4..cbda613 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH. Released under MIT license. */ -// Package mysql provides helpers for working MySQL database. +// Package mysql provides helpers for working with the MySQL database using the github.com/go-sql-driver/mysql driver. // Should be imported explicitly. // To register mysql as retryable func use side effect import like so: // @@ -25,7 +25,7 @@ func init() { var mySQLError *mysql.MySQLError if errors.As(err, &mySQLError) { switch mySQLError.Number { - case uint16(MySQLErrDeadlock), uint16(MySQLErrLockTimedOut): + case uint16(ErrDeadlock), uint16(ErrLockTimedOut): return true } } @@ -36,21 +36,21 @@ func init() { }) } -// MySQLErrCode defines the type for MySQL error codes. -// nolint: revive -type MySQLErrCode uint16 +// ErrCode defines the type for MySQL error codes. +type ErrCode uint16 // MySQL error codes (will be filled gradually). const ( - MySQLErrCodeDupEntry MySQLErrCode = 1062 - MySQLErrDeadlock MySQLErrCode = 1213 - MySQLErrLockTimedOut MySQLErrCode = 1205 + ErrCodeDupEntry ErrCode = 1062 + ErrDeadlock ErrCode = 1213 + ErrLockTimedOut ErrCode = 1205 ) -// CheckMySQLError checks if the passed error relates to MySQL and it's internal code matches the one from the argument. -func CheckMySQLError(err error, errCode MySQLErrCode) bool { +// CheckMySQLError checks if the passed error relates to MySQL, +// and it's internal code matches the one from the argument. +func CheckMySQLError(err error, errCode ErrCode) bool { var mySQLError *mysql.MySQLError - if ok := errors.As(err, &mySQLError); ok { + if errors.As(err, &mySQLError) { return mySQLError.Number == uint16(errCode) } return false diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go index 3c64654..13a9f29 100644 --- a/mysql/mysql_test.go +++ b/mysql/mysql_test.go @@ -34,21 +34,21 @@ func TestMysqlIsRetryable(t *testing.T) { isRetryable := dbkit.GetIsRetryable(&mysql.MySQLDriver{}) require.NotNil(t, isRetryable) require.True(t, isRetryable(&mysql.MySQLError{ - Number: uint16(MySQLErrDeadlock), + Number: uint16(ErrDeadlock), })) require.True(t, isRetryable(&mysql.MySQLError{ - Number: uint16(MySQLErrLockTimedOut), + Number: uint16(ErrLockTimedOut), })) require.True(t, isRetryable(mysql.ErrInvalidConn)) require.False(t, isRetryable(driver.ErrBadConn)) require.True(t, isRetryable(fmt.Errorf("wrapped error: %w", &mysql.MySQLError{ - Number: uint16(MySQLErrDeadlock), + Number: uint16(ErrDeadlock), }))) } // TestCheckMySQLError covers behavior of CheckMySQLError func. func TestCheckMySQLError(t *testing.T) { - var deadlockErr MySQLErrCode = 1213 + var deadlockErr ErrCode = 1213 sqlErr := &mysql.MySQLError{ Number: 1213, Message: "deadlock found when trying to get lock", diff --git a/pgx/deadlock_test.go b/pgx/deadlock_test.go index f6f9bd2..403b333 100644 --- a/pgx/deadlock_test.go +++ b/pgx/deadlock_test.go @@ -16,6 +16,6 @@ import ( func TestDeadlockErrorHandling(t *gotesting.T) { testing.DeadlockTest(t, dbkit.DialectPgx, func(err error) bool { - return CheckPostgresError(err, dbkit.PgxErrCodeDeadlockDetected) + return CheckPostgresError(err, ErrCodeDeadlockDetected) }) } diff --git a/pgx/postgres.go b/pgx/pgx.go similarity index 75% rename from pgx/postgres.go rename to pgx/pgx.go index 3bf0266..c44af1e 100644 --- a/pgx/postgres.go +++ b/pgx/pgx.go @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH. Released under MIT license. */ -// Package pgx provides helpers for working Postgres database via jackc/pgx driver. +// Package pgx provides helpers for working with the Postgres database using the github.com/jackc/pgx driver. // Should be imported explicitly. // To register postgres as retryable func use side effect import like so: // @@ -25,10 +25,10 @@ func init() { dbkit.RegisterIsRetryableFunc(&pg.Driver{}, func(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { - switch errCode := dbkit.PostgresErrCode(pgErr.Code); errCode { - case dbkit.PgxErrCodeDeadlockDetected: + switch errCode := ErrCode(pgErr.Code); errCode { + case ErrCodeDeadlockDetected: return true - case dbkit.PgxErrCodeSerializationFailure: + case ErrCodeSerializationFailure: return true } if checkInvalidCachedPlanPgError(pgErr) { @@ -39,12 +39,22 @@ func init() { }) } +// ErrCode defines the type for Pgx error codes. +type ErrCode string + +// Pgx error codes (will be filled gradually). +const ( + ErrCodeUniqueViolation ErrCode = "23505" + ErrCodeDeadlockDetected ErrCode = "40P01" + ErrCodeSerializationFailure ErrCode = "40001" + ErrFeatureNotSupported ErrCode = "0A000" +) + // CheckPostgresError checks if the passed error relates to Postgres, // and it's internal code matches the one from the argument. -func CheckPostgresError(err error, errCode dbkit.PostgresErrCode) bool { +func CheckPostgresError(err error, errCode ErrCode) bool { var pgErr *pgconn.PgError - ok := errors.As(err, &pgErr) - if ok { + if errors.As(err, &pgErr) { return pgErr.Code == string(errCode) } return false @@ -69,6 +79,6 @@ func CheckInvalidCachedPlanError(err error) bool { // Source: https://github.com/jackc/pgconn/blob/9cf57526250f6cd3e6cbf4fd7269c882e66898ce/stmtcache/lru.go#L91-L103 func checkInvalidCachedPlanPgError(pgErr *pgconn.PgError) bool { return pgErr.Severity == "ERROR" && - pgErr.Code == string(dbkit.PgxErrFeatureNotSupported) && + pgErr.Code == string(ErrFeatureNotSupported) && pgErr.Message == "cached plan must not change result type" } diff --git a/pgx/postgres_test.go b/pgx/pgx_test.go similarity index 97% rename from pgx/postgres_test.go rename to pgx/pgx_test.go index f1bcff2..a5923a1 100644 --- a/pgx/postgres_test.go +++ b/pgx/pgx_test.go @@ -67,9 +67,9 @@ func TestPostgresIsRetryable(t *gotesting.T) { isRetryable := dbkit.GetIsRetryable(&pg.Driver{}) require.NotNil(t, isRetryable) // enum all retriable errors - retriable := []dbkit.PostgresErrCode{ - dbkit.PgxErrCodeDeadlockDetected, - dbkit.PgxErrCodeSerializationFailure, + retriable := []ErrCode{ + ErrCodeDeadlockDetected, + ErrCodeSerializationFailure, } for _, code := range retriable { var err error diff --git a/postgres/deadlock_test.go b/postgres/deadlock_test.go index 5de4abe..f33e494 100644 --- a/postgres/deadlock_test.go +++ b/postgres/deadlock_test.go @@ -16,6 +16,6 @@ import ( func TestDeadlockErrorHandling(t *gotesting.T) { testing.DeadlockTest(t, dbkit.DialectPostgres, func(err error) bool { - return CheckPostgresError(err, dbkit.PostgresErrCodeDeadlockDetected) + return CheckPostgresError(err, ErrCodeDeadlockDetected) }) } diff --git a/postgres/postgres.go b/postgres/postgres.go index f405b71..d936311 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -4,7 +4,7 @@ Copyright © 2024 Acronis International GmbH. Released under MIT license. */ -// Package postgres provides helpers for working Postgres database. +// Package postgres provides helpers for working with the Postgres database using the github.com/lib/pq driver. // Should be imported explicitly. // To register postgres as retryable func use side effect import like so: // @@ -14,21 +14,21 @@ package postgres import ( "errors" - pg "github.com/lib/pq" + "github.com/lib/pq" "github.com/acronis/go-dbkit" ) // nolint func init() { - dbkit.RegisterIsRetryableFunc(&pg.Driver{}, func(err error) bool { - var pgErr *pg.Error + dbkit.RegisterIsRetryableFunc(&pq.Driver{}, func(err error) bool { + var pgErr *pq.Error if errors.As(err, &pgErr) { - name := dbkit.PostgresErrCode(pgErr.Code.Name()) + name := ErrCode(pgErr.Code.Name()) switch name { - case dbkit.PostgresErrCodeDeadlockDetected: + case ErrCodeDeadlockDetected: return true - case dbkit.PostgresErrCodeSerializationFailure: + case ErrCodeSerializationFailure: return true } } @@ -36,10 +36,20 @@ func init() { }) } -// CheckPostgresError checks if the passed error relates to Postgres and it's internal code matches the one from the argument. -// nolint: staticcheck // lib/pq using is deprecated. Use pgx Postgres driver. -func CheckPostgresError(err error, errCode dbkit.PostgresErrCode) bool { - var pgErr *pg.Error +// ErrCode defines the type for Postgres error codes. +type ErrCode string + +// Postgres error codes (will be filled gradually). +const ( + ErrCodeUniqueViolation ErrCode = "unique_violation" + ErrCodeDeadlockDetected ErrCode = "deadlock_detected" + ErrCodeSerializationFailure ErrCode = "serialization_failure" +) + +// CheckPostgresError checks if the passed error relates to Postgres, +// and it's internal code matches the one from the argument. +func CheckPostgresError(err error, errCode ErrCode) bool { + var pgErr *pq.Error if errors.As(err, &pgErr) { return pgErr.Code.Name() == string(errCode) } diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index 254a173..371f07c 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -13,7 +13,8 @@ package sqlite import ( "errors" - sqlite3 "github.com/mattn/go-sqlite3" + + "github.com/mattn/go-sqlite3" "github.com/acronis/go-dbkit" ) @@ -32,7 +33,8 @@ func init() { }) } -// CheckSQLiteError checks if the passed error relates to SQLite and it's internal code matches the one from the argument. +// CheckSQLiteError checks if the passed error relates to SQLite, +// and it's internal code matches the one from the argument. func CheckSQLiteError(err error, errCode sqlite3.ErrNoExtended) bool { var sqliteErr sqlite3.Error if errors.As(err, &sqliteErr) {