215 lines
4.7 KiB
Go
215 lines
4.7 KiB
Go
package sqldb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"management/internal/pkg/config"
|
|
|
|
"github.com/drhin/logger"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
"github.com/jmoiron/sqlx"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// PostgreSQL SQLSTATE codes that this package compares against
// pgconn.PgError.Code in order to translate driver errors into sentinels.
const (
	// uniqueViolation is the SQLSTATE reported for a unique-constraint failure.
	uniqueViolation = "23505"

	// undefinedTable is the SQLSTATE reported when a statement references a
	// table that does not exist.
	undefinedTable = "42P01"
)
|
|
|
|
// Sentinel errors returned by the query helpers in this package.
var (
	// ErrDBNotFound is the same value as sql.ErrNoRows; NamedQueryStruct
	// returns it when the query produces no row.
	ErrDBNotFound = sql.ErrNoRows

	// ErrDBDuplicatedEntry replaces a PostgreSQL unique-violation error.
	ErrDBDuplicatedEntry = errors.New("duplicated entry")

	// ErrUndefinedTable replaces a PostgreSQL undefined-table error.
	ErrUndefinedTable = errors.New("undefined table")
)
|
|
|
|
// Config groups the connection and pool settings for a database handle.
//
// NOTE(review): none of the functions visible in this file read this type;
// NewDB takes a *config.Config instead. Confirm whether it is still needed.
type Config struct {
	// Connection target and credentials.
	User     string
	Password string
	Host     string
	Port     int
	Name     string // presumably the database name; verify against callers

	// Pool tuning. The names match the database/sql setters
	// (SetMaxIdleConns, SetMaxOpenConns, SetConnMaxLifetime,
	// SetConnMaxIdleTime); presumably they are meant to feed those calls.
	MaxIdleConns    int
	MaxOpenConns    int
	ConnMaxLifetime time.Duration
	ConnMaxIdleTime time.Duration
}
|
|
|
|
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
|
|
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
|
|
config.DB.Username,
|
|
config.DB.Password,
|
|
config.DB.Host,
|
|
config.DB.Port,
|
|
config.DB.DBName,
|
|
)
|
|
|
|
db, err := sqlx.Open("pgx", dsn)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("sqlx open db: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 设置最大空闲连接数(默认 2)
|
|
db.SetMaxIdleConns(config.DB.MaxIdleConns)
|
|
|
|
// 设置最大打开连接数(默认 0 无限制)
|
|
db.SetMaxOpenConns(config.DB.MaxOpenConns)
|
|
|
|
// 设置连接最大存活时间
|
|
db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
|
|
|
|
// 设置连接最大空闲时间
|
|
db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
|
|
|
|
cleanup := func() {
|
|
if err := db.Close(); err != nil {
|
|
log.Error("sql db close error", err)
|
|
}
|
|
}
|
|
|
|
return db, cleanup, nil
|
|
}
|
|
|
|
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
|
|
defer func() {
|
|
if err != nil {
|
|
switch data.(type) {
|
|
case struct{}:
|
|
log.Error("database.NamedExecContext (data is struct)", err,
|
|
zap.String("query", query),
|
|
zap.Any("ERROR", err))
|
|
default:
|
|
log.Error("database.NamedExecContext", err,
|
|
zap.String("query", query),
|
|
zap.Any("ERROR", err))
|
|
}
|
|
}
|
|
}()
|
|
|
|
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
|
|
var pgError *pgconn.PgError
|
|
if errors.As(err, &pgError) {
|
|
switch pgError.Code {
|
|
case undefinedTable:
|
|
return ErrUndefinedTable
|
|
case uniqueViolation:
|
|
return ErrDBDuplicatedEntry
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
|
|
q := queryString(query, data)
|
|
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
|
if err != nil {
|
|
var pqErr *pgconn.PgError
|
|
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
|
return ErrUndefinedTable
|
|
}
|
|
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
|
zap.String("query", q),
|
|
zap.Any("data", data),
|
|
)
|
|
return err
|
|
}
|
|
defer func(rows *sqlx.Rows) {
|
|
err := rows.Close()
|
|
if err != nil {
|
|
log.Error("rows close error", err)
|
|
}
|
|
}(rows)
|
|
|
|
if !rows.Next() {
|
|
return ErrDBNotFound
|
|
}
|
|
|
|
if err := rows.StructScan(dest); err != nil {
|
|
log.Error("NamedQueryStruct StructScan error", err,
|
|
zap.String("query", q),
|
|
zap.Any("data", data),
|
|
)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
|
|
q := queryString(query, data)
|
|
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
|
if err != nil {
|
|
var pqErr *pgconn.PgError
|
|
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
|
return ErrUndefinedTable
|
|
}
|
|
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
|
zap.String("query", q),
|
|
zap.Any("data", data),
|
|
)
|
|
return err
|
|
}
|
|
defer func(rows *sqlx.Rows) {
|
|
err := rows.Close()
|
|
if err != nil {
|
|
log.Error("rows close error", err)
|
|
}
|
|
}(rows)
|
|
|
|
var slice []T
|
|
for rows.Next() {
|
|
v := new(T)
|
|
if err := rows.StructScan(v); err != nil {
|
|
log.Error("NamedQuerySlice StructScan error", err,
|
|
zap.String("query", q),
|
|
zap.Any("data", data),
|
|
)
|
|
return err
|
|
}
|
|
slice = append(slice, *v)
|
|
}
|
|
*dest = slice
|
|
|
|
return nil
|
|
}
|
|
|
|
func queryString(query string, args any) string {
|
|
query, params, err := sqlx.Named(query, args)
|
|
if err != nil {
|
|
return err.Error()
|
|
}
|
|
|
|
for _, param := range params {
|
|
var value string
|
|
switch v := param.(type) {
|
|
case string:
|
|
value = fmt.Sprintf("'%s'", v)
|
|
case []byte:
|
|
value = fmt.Sprintf("'%s'", string(v))
|
|
default:
|
|
value = fmt.Sprintf("%v", v)
|
|
}
|
|
query = strings.Replace(query, "?", value, 1)
|
|
}
|
|
|
|
query = strings.ReplaceAll(query, "\t", "")
|
|
query = strings.ReplaceAll(query, "\n", " ")
|
|
|
|
return strings.Trim(query, " ")
|
|
}
|