sqlx
This commit is contained in:
214
internal/pkg/sqldb/sqldb.go
Normal file
214
internal/pkg/sqldb/sqldb.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"management/internal/pkg/config"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
uniqueViolation = "23505"
|
||||
undefinedTable = "42P01"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDBNotFound = sql.ErrNoRows
|
||||
ErrDBDuplicatedEntry = errors.New("duplicated entry")
|
||||
ErrUndefinedTable = errors.New("undefined table")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Port int
|
||||
Name string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
|
||||
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
config.DB.Username,
|
||||
config.DB.Password,
|
||||
config.DB.Host,
|
||||
config.DB.Port,
|
||||
config.DB.DBName,
|
||||
)
|
||||
|
||||
db, err := sqlx.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("sqlx open db: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 设置最大空闲连接数(默认 2)
|
||||
db.SetMaxIdleConns(config.DB.MaxIdleConns)
|
||||
|
||||
// 设置最大打开连接数(默认 0 无限制)
|
||||
db.SetMaxOpenConns(config.DB.MaxOpenConns)
|
||||
|
||||
// 设置连接最大存活时间
|
||||
db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
|
||||
|
||||
// 设置连接最大空闲时间
|
||||
db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
|
||||
|
||||
cleanup := func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Error("sql db close error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, cleanup, nil
|
||||
}
|
||||
|
||||
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
switch data.(type) {
|
||||
case struct{}:
|
||||
log.Error("database.NamedExecContext (data is struct)", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
default:
|
||||
log.Error("database.NamedExecContext", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
|
||||
var pgError *pgconn.PgError
|
||||
if errors.As(err, &pgError) {
|
||||
switch pgError.Code {
|
||||
case undefinedTable:
|
||||
return ErrUndefinedTable
|
||||
case uniqueViolation:
|
||||
return ErrDBDuplicatedEntry
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
}(rows)
|
||||
|
||||
if !rows.Next() {
|
||||
return ErrDBNotFound
|
||||
}
|
||||
|
||||
if err := rows.StructScan(dest); err != nil {
|
||||
log.Error("NamedQueryStruct StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
}(rows)
|
||||
|
||||
var slice []T
|
||||
for rows.Next() {
|
||||
v := new(T)
|
||||
if err := rows.StructScan(v); err != nil {
|
||||
log.Error("NamedQuerySlice StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
slice = append(slice, *v)
|
||||
}
|
||||
*dest = slice
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func queryString(query string, args any) string {
|
||||
query, params, err := sqlx.Named(query, args)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
for _, param := range params {
|
||||
var value string
|
||||
switch v := param.(type) {
|
||||
case string:
|
||||
value = fmt.Sprintf("'%s'", v)
|
||||
case []byte:
|
||||
value = fmt.Sprintf("'%s'", string(v))
|
||||
default:
|
||||
value = fmt.Sprintf("%v", v)
|
||||
}
|
||||
query = strings.Replace(query, "?", value, 1)
|
||||
}
|
||||
|
||||
query = strings.ReplaceAll(query, "\t", "")
|
||||
query = strings.ReplaceAll(query, "\n", " ")
|
||||
|
||||
return strings.Trim(query, " ")
|
||||
}
|
||||
49
internal/pkg/sqldb/tran.go
Normal file
49
internal/pkg/sqldb/tran.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Beginner represents a value that can begin a transaction.
|
||||
type Beginner interface {
|
||||
Begin() (CommitRollbacker, error)
|
||||
}
|
||||
|
||||
// CommitRollbacker represents a value that can commit or rollback a transaction.
|
||||
type CommitRollbacker interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// DBBeginner implements the Beginner interface,
|
||||
type DBBeginner struct {
|
||||
sqlxDB *sqlx.DB
|
||||
}
|
||||
|
||||
// NewBeginner constructs a value that implements the beginner interface.
|
||||
func NewBeginner(sqlxDB *sqlx.DB) *DBBeginner {
|
||||
return &DBBeginner{
|
||||
sqlxDB: sqlxDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Begin implements the Beginner interface and returns a concrete value that
|
||||
// implements the CommitRollbacker interface.
|
||||
func (db *DBBeginner) Begin() (CommitRollbacker, error) {
|
||||
return db.sqlxDB.Beginx()
|
||||
}
|
||||
|
||||
// GetExtContext is a helper function that extracts the sqlx value
|
||||
// from the domain transactor interface for transactional use.
|
||||
func GetExtContext(tx CommitRollbacker) (sqlx.ExtContext, error) {
|
||||
ec, ok := tx.(sqlx.ExtContext)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Transactor(%T) not of a type *sql.Tx", tx)
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
Reference in New Issue
Block a user