2025-06-18 17:50:02 +08:00

215 lines
4.7 KiB
Go

package sqldb
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"management/internal/pkg/config"
"github.com/drhin/logger"
"github.com/jackc/pgx/v5/pgconn"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
)
// uniqueViolation is the SQLSTATE code PostgreSQL reports for unique_violation.
const uniqueViolation = "23505"

// undefinedTable is the SQLSTATE code PostgreSQL reports for undefined_table.
const undefinedTable = "42P01"

// ErrDBNotFound is the very same value as sql.ErrNoRows, so errors.Is matches
// either one. NamedQueryStruct returns it when its query yields no rows.
var ErrDBNotFound = sql.ErrNoRows

// ErrDBDuplicatedEntry is returned by NamedExecContext when PostgreSQL rejects
// a statement with SQLSTATE 23505 (unique_violation).
var ErrDBDuplicatedEntry = errors.New("duplicated entry")

// ErrUndefinedTable is returned by the query/exec helpers in this package when
// PostgreSQL reports SQLSTATE 42P01 (undefined_table).
var ErrUndefinedTable = errors.New("undefined table")
// Config appears to hold the connection and connection-pool settings for a
// PostgreSQL database.
//
// NOTE(review): NewDB in this file reads its settings from config.Config
// (config.DB.*), not from this type; confirm whether Config is used elsewhere
// in the package or can be removed.
type Config struct {
	// Login credentials.
	User     string
	Password string
	// Server address; Port is numeric.
	Host string
	Port int
	// Database name.
	Name string
	// Pool tuning; presumably intended for sql.DB.SetMaxIdleConns,
	// SetMaxOpenConns, SetConnMaxLifetime and SetConnMaxIdleTime — confirm.
	MaxIdleConns    int
	MaxOpenConns    int
	ConnMaxLifetime time.Duration
	ConnMaxIdleTime time.Duration
}
// NewDB opens a PostgreSQL connection pool through the pgx stdlib driver,
// applies the pool limits from config.DB, and verifies connectivity with a
// ping bounded by a 3-second timeout. TLS is disabled (sslmode=disable).
//
// It returns the pool, a cleanup func that closes the pool (logging any close
// error), and an error. When an error is returned the pool has already been
// closed, so the caller has nothing to clean up.
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
	// Build a keyword/value DSN with every string value quoted. The previous
	// URL interpolation broke on credentials or names containing '@', ':',
	// '/', '?', '#' and on IPv6 hosts; quoted values pass through verbatim.
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		quoteConnValue(config.DB.Host),
		config.DB.Port,
		quoteConnValue(config.DB.Username),
		quoteConnValue(config.DB.Password),
		quoteConnValue(config.DB.DBName),
	)

	db, err := sqlx.Open("pgx", dsn)
	if err != nil {
		return nil, nil, fmt.Errorf("sqlx open db: %w", err)
	}

	// Maximum number of idle connections (database/sql default: 2).
	db.SetMaxIdleConns(config.DB.MaxIdleConns)
	// Maximum number of open connections (0, the default, means unlimited).
	db.SetMaxOpenConns(config.DB.MaxOpenConns)
	// Maximum lifetime of a connection before it is recycled.
	db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
	// Maximum time a connection may stay idle before it is closed.
	db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Release the pool opened above; without this it leaks on every
		// failed start-up attempt.
		_ = db.Close()
		return nil, nil, fmt.Errorf("ping db: %w", err)
	}

	cleanup := func() {
		if err := db.Close(); err != nil {
			log.Error("sql db close error", err)
		}
	}
	return db, cleanup, nil
}

// quoteConnValue renders v as a single-quoted value for a libpq-style
// keyword/value connection string, backslash-escaping '\' and '\'' so the
// value is read back exactly as given.
func quoteConnValue(v string) string {
	v = strings.ReplaceAll(v, `\`, `\\`)
	v = strings.ReplaceAll(v, `'`, `\'`)
	return "'" + v + "'"
}
// NamedExecContext executes a named statement against db, binding the
// statement's named parameters from data, and discards the sql.Result.
//
// PostgreSQL errors are translated into package sentinels: SQLSTATE 42P01
// becomes ErrUndefinedTable and 23505 becomes ErrDBDuplicatedEntry. Any other
// error is returned unchanged. Every failure is logged together with the query.
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) error {
	_, err := sqlx.NamedExecContext(ctx, db, query, data)
	if err == nil {
		return nil
	}

	// Log the driver error itself, before translation: the sentinels returned
	// below carry no PostgreSQL detail (message, constraint, table).
	log.Error("database.NamedExecContext", err,
		zap.String("query", query),
		zap.Any("ERROR", err))

	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		switch pgErr.Code {
		case undefinedTable:
			return ErrUndefinedTable
		case uniqueViolation:
			return ErrDBDuplicatedEntry
		}
	}
	return err
}
// NamedQueryStruct runs a named query against db, binding its named parameters
// from data, and scans the first result row into dest via StructScan.
//
// It returns ErrUndefinedTable when PostgreSQL reports SQLSTATE 42P01,
// ErrDBNotFound when the query yields no rows, and otherwise the query,
// iteration, or scan error (each of which is also logged with the query).
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
	// Execute the original named query so sqlx binds data as real driver
	// parameters. queryString's interpolated text is for log output only;
	// executing it (as before) inlined unescaped values into the SQL.
	rows, err := sqlx.NamedQueryContext(ctx, db, query, data)
	if err != nil {
		var pqErr *pgconn.PgError
		if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
			return ErrUndefinedTable
		}
		log.Error("NamedQueryStruct NamedQueryContext error", err,
			zap.String("query", queryString(query, data)),
			zap.Any("data", data),
		)
		return err
	}
	defer func() {
		if cerr := rows.Close(); cerr != nil {
			log.Error("rows close error", cerr)
		}
	}()

	if !rows.Next() {
		// Next reports false both at end of results and on an iteration
		// error; only the former means "not found".
		if err := rows.Err(); err != nil {
			log.Error("NamedQueryStruct rows error", err,
				zap.String("query", queryString(query, data)),
				zap.Any("data", data),
			)
			return err
		}
		return ErrDBNotFound
	}

	if err := rows.StructScan(dest); err != nil {
		log.Error("NamedQueryStruct StructScan error", err,
			zap.String("query", queryString(query, data)),
			zap.Any("data", data),
		)
		return err
	}
	return nil
}
// NamedQuerySlice runs a named query against db, binding its named parameters
// from data, and StructScans every result row into a T. On success *dest is
// replaced with the collected rows (nil when the query yields none).
//
// It returns ErrUndefinedTable when PostgreSQL reports SQLSTATE 42P01, and
// otherwise the query, scan, or iteration error (each also logged with the
// query). *dest is left untouched on any error.
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
	// Execute the original named query so sqlx binds data as real driver
	// parameters. queryString's interpolated text is for log output only.
	rows, err := sqlx.NamedQueryContext(ctx, db, query, data)
	if err != nil {
		var pqErr *pgconn.PgError
		if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
			return ErrUndefinedTable
		}
		log.Error("NamedQuerySlice NamedQueryContext error", err,
			zap.String("query", queryString(query, data)),
			zap.Any("data", data),
		)
		return err
	}
	defer func() {
		if cerr := rows.Close(); cerr != nil {
			log.Error("rows close error", cerr)
		}
	}()

	var slice []T
	for rows.Next() {
		var v T
		if err := rows.StructScan(&v); err != nil {
			log.Error("NamedQuerySlice StructScan error", err,
				zap.String("query", queryString(query, data)),
				zap.Any("data", data),
			)
			return err
		}
		slice = append(slice, v)
	}
	// The loop also ends on an iteration error; without this check a
	// truncated result would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		log.Error("NamedQuerySlice rows error", err,
			zap.String("query", queryString(query, data)),
			zap.Any("data", data),
		)
		return err
	}

	*dest = slice
	return nil
}
// queryString renders query with the values from args substituted for its
// named parameters, producing a single-line, human-readable statement for log
// output. The result must NOT be executed: values are quoted for readability
// only (strings/[]byte single-quoted with ” doubling, nil as NULL, everything
// else via %v). If args cannot be bound to query, the binding error text is
// returned instead.
func queryString(query string, args any) string {
	bound, params, err := sqlx.Named(query, args)
	if err != nil {
		return err.Error()
	}

	// Substitute the '?' placeholders in one left-to-right pass. Repeated
	// strings.Replace(…, "?", …, 1) would hit a '?' inside a value that was
	// already inserted instead of the next real placeholder.
	var b strings.Builder
	b.Grow(len(bound))
	rest := bound
	for _, param := range params {
		before, after, found := strings.Cut(rest, "?")
		if !found {
			break
		}
		b.WriteString(before)
		switch v := param.(type) {
		case nil:
			b.WriteString("NULL")
		case string:
			b.WriteString("'" + strings.ReplaceAll(v, "'", "''") + "'")
		case []byte:
			b.WriteString("'" + strings.ReplaceAll(string(v), "'", "''") + "'")
		default:
			fmt.Fprintf(&b, "%v", v)
		}
		rest = after
	}
	b.WriteString(rest)

	// Collapse multi-line query text into one log-friendly line.
	out := strings.ReplaceAll(b.String(), "\t", "")
	out = strings.ReplaceAll(out, "\n", " ")
	return strings.Trim(out, " ")
}