diff --git a/src/core/repos/action_token.go b/src/core/repos/action_token.go index df2c26d..b8c20a7 100644 --- a/src/core/repos/action_token.go +++ b/src/core/repos/action_token.go @@ -2,8 +2,8 @@ package repos import ( "backend/src/core/models" + "backend/src/integrations" "context" - "database/sql" ) type ActionTokenRepo interface { @@ -11,14 +11,14 @@ type ActionTokenRepo interface { PopActionToken(ctx context.Context, userId, value string, target models.ActionTokenTarget) (*models.ActionTokenDTO, error) } -func NewActionTokenRepo(db *sql.DB) ActionTokenRepo { +func NewActionTokenRepo(db integrations.SqlDB) ActionTokenRepo { return &actionTokenRepo{ db: db, } } type actionTokenRepo struct { - db *sql.DB + db integrations.SqlDB } func (a *actionTokenRepo) CreateActionToken(ctx context.Context, dto models.ActionTokenDTO) (*models.ActionTokenDTO, error) { diff --git a/src/core/repos/user_repo.go b/src/core/repos/user_repo.go index d64db63..e9b0c90 100644 --- a/src/core/repos/user_repo.go +++ b/src/core/repos/user_repo.go @@ -2,6 +2,7 @@ package repos import ( "backend/src/core/models" + "backend/src/integrations" "context" "database/sql" "errors" @@ -21,12 +22,12 @@ type UserRepo interface { GetUserByEmail(ctx context.Context, login string) (*models.UserDTO, error) } -func NewUserRepo(db *sql.DB) UserRepo { +func NewUserRepo(db integrations.SqlDB) UserRepo { return &userRepo{db} } type userRepo struct { - db *sql.DB + db integrations.SqlDB } func (u *userRepo) CreateUser(ctx context.Context, dto models.UserDTO) (*models.UserDTO, error) { diff --git a/src/integrations/postgresql.go b/src/integrations/postgresql.go index de727f3..c2db00f 100644 --- a/src/integrations/postgresql.go +++ b/src/integrations/postgresql.go @@ -3,23 +3,65 @@ package integrations import ( "context" "database/sql" + "database/sql/driver" "fmt" + "net/url" + "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/stdlib" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" ) -// TODO: wrapper, connection pool -func NewPostgresConn(ctx context.Context, connUrl string) (*sql.DB, error) { - connConf, err := pgx.ParseConnectionString(connUrl) +type SqlDB interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + Close() error + Conn(ctx context.Context) (*sql.Conn, error) + Driver() driver.Driver + Exec(query string, args ...any) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + Ping() error + PingContext(ctx context.Context) error + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Query(query string, args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRow(query string, args ...any) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + SetConnMaxIdleTime(d time.Duration) + SetConnMaxLifetime(d time.Duration) + SetMaxIdleConns(n int) + SetMaxOpenConns(n int) + Stats() sql.DBStats +} + +type Postgres struct { + *sql.DB + pool *pgxpool.Pool +} + +func NewPostgresConn(ctx context.Context, postgresUrl string) (SqlDB, error) { + connUrl, err := url.Parse(postgresUrl) if err != nil { - return nil, fmt.Errorf("failed parsing postgres connection string: %v", err) + return nil, err } - sqlDb := stdlib.OpenDB(connConf) - if err := sqlDb.PingContext(ctx); err != nil { + query, _ := url.ParseQuery(connUrl.RawQuery) + query.Set("pool_max_conns", "16") + connUrl.RawQuery = query.Encode() + + pool, err := pgxpool.New(ctx, connUrl.String()) + if err != nil { + return nil, err + } + + db := stdlib.OpenDBFromPool(pool) + if err := db.PingContext(ctx); err != nil { return nil, fmt.Errorf("failed pinging postgres db: %v", err) } - return sqlDb, nil + return &Postgres{ + DB: db, + pool: pool, + }, nil }