2025-10-27 15:24:08 +08:00

379 lines
9.8 KiB
Go

// Package sqldb provides support for accessing the database.
package sqldb
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"management/internal/pkg/config"

	"github.com/drhin/logger"
	"github.com/jackc/pgx/v5/pgconn"
	_ "github.com/jackc/pgx/v5/stdlib"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
)
// PostgreSQL SQLSTATE error codes inspected when translating driver errors
// (*pgconn.PgError) into this package's sentinel errors. The same values are
// listed in lib/pq's errorCodeNames table:
// https://github.com/lib/pq/blob/master/error.go#L178

// uniqueViolation is SQLSTATE 23505 (unique_violation).
const uniqueViolation = "23505"

// undefinedTable is SQLSTATE 42P01 (undefined_table).
const undefinedTable = "42P01"
// Sentinel errors returned by the CRUD helpers in this package.

// ErrDBNotFound is returned when a single-row query yields no rows. It is the
// very same value as sql.ErrNoRows.
var ErrDBNotFound = sql.ErrNoRows

// ErrDBDuplicatedEntry is returned when a write violates a unique constraint
// (SQLSTATE 23505).
var ErrDBDuplicatedEntry = errors.New("duplicated entry")

// ErrUndefinedTable is returned when a statement references a table that does
// not exist (SQLSTATE 42P01).
var ErrUndefinedTable = errors.New("undefined table")
// Config holds the properties required to use the database.
//type Config struct {
// User string
// Password string
// Host string
// Name string
// Schema string
// MaxIdleConns int
// MaxOpenConns int
// DisableTLS bool
//}
// Open knows how to open a database connection based on the configuration.
//func Open(cfg Config) (*sqlx.DB, error) {
// sslMode := "require"
// if cfg.DisableTLS {
// sslMode = "disable"
// }
//
// q := make(url.Values)
// q.Set("sslmode", sslMode)
// q.Set("timezone", "utc")
// if cfg.Schema != "" {
// q.Set("search_path", cfg.Schema)
// }
//
// u := url.URL{
// Scheme: "postgres",
// User: url.UserPassword(cfg.User, cfg.Password),
// Host: cfg.Host,
// Path: cfg.Name,
// RawQuery: q.Encode(),
// }
//
// db, err := sqlx.Open("pgx", u.String())
// if err != nil {
// return nil, err
// }
// db.SetMaxIdleConns(cfg.MaxIdleConns)
// db.SetMaxOpenConns(cfg.MaxOpenConns)
//
// return db, nil
//}
// Config describes the connection and pool settings for a PostgreSQL
// database.
//
// NOTE(review): NewDB reads its settings from config.Config rather than from
// this type, so Config appears unused in this file; confirm before relying on it.
type Config struct {
	// Connection credentials and address.
	User, Password string
	Host           string
	Port           int
	Name           string

	// Pool tuning; the names match the *sqlx.DB setters used by NewDB.
	MaxIdleConns, MaxOpenConns       int
	ConnMaxLifetime, ConnMaxIdleTime time.Duration
}
// NewDB opens a PostgreSQL connection pool through the pgx stdlib driver,
// applies the pool settings from config.DB, and verifies connectivity with a
// ping bounded by a 3-second timeout.
//
// On success it returns the pool plus a cleanup function that closes it. A
// close error is logged, not returned. On failure the pool is closed before
// returning, so no connections are leaked.
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
	// Build the DSN with net/url so the credentials and database name are
	// percent-encoded. A raw Sprintf produces a broken DSN when the password
	// contains characters such as '@', ':', '/', '?' or '#'.
	dsn := (&url.URL{
		Scheme:   "postgresql",
		User:     url.UserPassword(config.DB.Username, config.DB.Password),
		Host:     fmt.Sprintf("%s:%d", config.DB.Host, config.DB.Port),
		Path:     "/" + config.DB.DBName,
		RawQuery: "sslmode=disable",
	}).String()

	db, err := sqlx.Open("pgx", dsn)
	if err != nil {
		return nil, nil, fmt.Errorf("sqlx open db: %w", err)
	}

	// Configure the pool before the first connection is established.
	// Maximum idle connections (database/sql default: 2).
	db.SetMaxIdleConns(config.DB.MaxIdleConns)
	// Maximum open connections (0 means unlimited).
	db.SetMaxOpenConns(config.DB.MaxOpenConns)
	// Maximum lifetime of a connection.
	db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
	// Maximum time a connection may sit idle.
	db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Release the pool; the ping failure is the error the caller needs,
		// so a secondary close error is intentionally dropped.
		_ = db.Close()
		return nil, nil, fmt.Errorf("ping db: %w", err)
	}

	cleanup := func() {
		if err := db.Close(); err != nil {
			log.Error("sql db close error", err)
		}
	}

	return db, cleanup, nil
}
// StatusCheck returns nil if it can successfully talk to the database and a
// non-nil error otherwise.
//
// If ctx carries no deadline, a 1-second timeout is applied. The database is
// pinged repeatedly, waiting 100ms, 200ms, 300ms, ... between attempts, until
// a ping succeeds or ctx ends. In the latter case ctx.Err() is returned.
// Finally, a trivial query forces a full round trip through the database.
func StatusCheck(ctx context.Context, db *sqlx.DB) error {
	// If the caller didn't supply a deadline, use 1 second.
	if _, ok := ctx.Deadline(); !ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Second)
		defer cancel()
	}

	for attempts := 1; ; attempts++ {
		// PingContext (not Ping) so a hung connection attempt is abandoned
		// once ctx expires instead of blocking past the deadline.
		if err := db.PingContext(ctx); err == nil {
			break
		}

		// Back off before retrying, but wake up immediately if ctx is
		// cancelled instead of sleeping through the deadline.
		backoff := time.NewTimer(time.Duration(attempts) * 100 * time.Millisecond)
		select {
		case <-ctx.Done():
			backoff.Stop()
			return ctx.Err()
		case <-backoff.C:
		}
	}

	if err := ctx.Err(); err != nil {
		return err
	}

	// Run a simple query to determine connectivity.
	// Running this query forces a round trip through the database.
	const q = `SELECT TRUE`
	var tmp bool
	return db.QueryRowContext(ctx, q).Scan(&tmp)
}
// ExecContext is a helper function to execute a CUD operation with
// logging and tracing.
func ExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string) error {
return NamedExecContext(ctx, log, db, query, struct{}{})
}
// NamedExecContext executes a create/update/delete statement whose named
// parameters are bound from data by sqlx.NamedExecContext.
//
// Driver errors are translated as follows:
//   - SQLSTATE 42P01 becomes ErrUndefinedTable.
//   - SQLSTATE 23505 becomes ErrDBDuplicatedEntry.
//   - Any other error is returned unchanged.
//
// Every failure is logged at Info level together with the query rendered
// with its arguments inlined.
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
	defer func() {
		if err == nil {
			return
		}
		// Render the query only on failure. queryString re-binds the query
		// and formats every argument, which is wasted work on success.
		q := queryString(query, data)
		// NOTE(review): "type" is 6 when data is the empty struct (as passed
		// by ExecContext) and 5 otherwise; the codes are not defined here.
		switch data.(type) {
		case struct{}:
			log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
		default:
			log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 5), zap.Error(err))
		}
	}()

	_, execErr := sqlx.NamedExecContext(ctx, db, query, data)
	if execErr == nil {
		return nil
	}

	var pgErr *pgconn.PgError
	if errors.As(execErr, &pgErr) {
		switch pgErr.Code {
		case undefinedTable:
			return ErrUndefinedTable
		case uniqueViolation:
			return ErrDBDuplicatedEntry
		}
	}
	return execErr
}
// QuerySlice is a helper function for executing queries that return a
// collection of data to be unmarshalled into a slice.
func QuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, struct{}{}, dest, false)
}
// NamedQuerySlice is a helper function for executing queries that return a
// collection of data to be unmarshalled into a slice where field replacement is
// necessary.
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, data, dest, false)
}
// NamedQuerySliceUsingIn is a helper function for executing queries that return
// a collection of data to be unmarshalled into a slice where field replacement
// is necessary. Use this if the query has an IN clause.
func NamedQuerySliceUsingIn[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, data, dest, true)
}
// namedQuerySlice runs query with its named parameters bound from data and
// scans every returned row into a T, storing the collected rows in *dest.
// When withIn is true, the bound query is expanded with sqlx.In (for IN
// clauses) and rebound to the driver's placeholder style.
//
// A missing table (SQLSTATE 42P01) is reported as ErrUndefinedTable. On any
// error *dest is left untouched and the failure is logged with the rendered
// query. When no rows match, *dest is set to a nil slice.
func namedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T, withIn bool) (err error) {
	defer func() {
		if err != nil {
			// Rendering re-binds and formats all arguments, so only do it
			// when there is a failure to log.
			log.Info("database.NamedQuerySlice", zap.String("query", queryString(query, data)), zap.Int("type", 6), zap.Error(err))
		}
	}()

	var rows *sqlx.Rows
	if withIn {
		rows, err = func() (*sqlx.Rows, error) {
			named, args, bindErr := sqlx.Named(query, data)
			if bindErr != nil {
				return nil, bindErr
			}
			expanded, args, inErr := sqlx.In(named, args...)
			if inErr != nil {
				return nil, inErr
			}
			return db.QueryxContext(ctx, db.Rebind(expanded), args...)
		}()
	} else {
		rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
	}
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == undefinedTable {
			return ErrUndefinedTable
		}
		return err
	}
	// The close error is ignored: any read failure is surfaced via rows.Err.
	defer func() { _ = rows.Close() }()

	var slice []T
	for rows.Next() {
		var v T
		if scanErr := rows.StructScan(&v); scanErr != nil {
			return scanErr
		}
		slice = append(slice, v)
	}
	// rows.Next returns false both at the end of the result set and on a
	// mid-stream failure. Without this check, a broken read would silently
	// return a truncated slice.
	if iterErr := rows.Err(); iterErr != nil {
		return iterErr
	}

	*dest = slice
	return nil
}
// QueryStruct is a helper function for executing queries that return a
// single value to be unmarshalled into a struct type where field replacement is necessary.
func QueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest any) error {
return namedQueryStruct(ctx, log, db, query, struct{}{}, dest, false)
}
// NamedQueryStruct is a helper function for executing queries that return a
// single value to be unmarshalled into a struct type where field replacement is necessary.
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
return namedQueryStruct(ctx, log, db, query, data, dest, false)
}
// NamedQueryStructUsingIn is a helper function for executing queries that return
// a single value to be unmarshalled into a struct type where field replacement
// is necessary. Use this if the query has an IN clause.
func NamedQueryStructUsingIn(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
return namedQueryStruct(ctx, log, db, query, data, dest, true)
}
// namedQueryStruct runs query with its named parameters bound from data and
// scans the first returned row into dest. When withIn is true, the bound query
// is expanded with sqlx.In (for IN clauses) and rebound to the driver's
// placeholder style.
//
// Possible results:
//   - ErrDBNotFound when the result set is empty.
//   - ErrUndefinedTable for a missing table (SQLSTATE 42P01).
//   - Any other query, iteration, or scan error unchanged.
//
// Failures are logged with the rendered query.
func namedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any, withIn bool) (err error) {
	defer func() {
		if err != nil {
			// Rendering re-binds and formats all arguments, so only do it
			// when there is a failure to log.
			log.Info("database.NamedQueryStruct", zap.String("query", queryString(query, data)), zap.Int("type", 6), zap.Error(err))
		}
	}()

	var rows *sqlx.Rows
	if withIn {
		rows, err = func() (*sqlx.Rows, error) {
			named, args, bindErr := sqlx.Named(query, data)
			if bindErr != nil {
				return nil, bindErr
			}
			expanded, args, inErr := sqlx.In(named, args...)
			if inErr != nil {
				return nil, inErr
			}
			return db.QueryxContext(ctx, db.Rebind(expanded), args...)
		}()
	} else {
		rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
	}
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == undefinedTable {
			return ErrUndefinedTable
		}
		return err
	}
	// The close error is ignored: any read failure is surfaced via rows.Err.
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Next also returns false when fetching the first row fails. Only an
		// empty, error-free result means "not found".
		if iterErr := rows.Err(); iterErr != nil {
			return iterErr
		}
		return ErrDBNotFound
	}

	return rows.StructScan(dest)
}
// queryString renders query with its arguments inlined. It is for logging
// only.
//
// The named parameters in query are bound from args with sqlx.Named, which
// rewrites them to "?" placeholders. Each placeholder is then replaced, left
// to right, by its argument:
//   - strings and []byte are wrapped in single quotes;
//   - everything else is formatted with %v.
//
// Afterwards, tabs are removed, newlines become spaces, and leading and
// trailing spaces are trimmed. If binding fails, the binding error's message
// is returned instead.
//
// The result is not safe to execute, because quotes inside string values are
// not escaped.
func queryString(query string, args any) string {
	bound, params, err := sqlx.Named(query, args)
	if err != nil {
		return err.Error()
	}

	// Substitute placeholders in a single pass over the bound query. Unlike
	// repeated strings.Replace on the growing result, this never re-scans text
	// that came from a substituted value, so a '?' inside an argument cannot
	// steal the next argument's slot.
	var b strings.Builder
	b.Grow(len(bound))
	rest := bound
	for _, param := range params {
		before, after, found := strings.Cut(rest, "?")
		if !found {
			break
		}
		b.WriteString(before)
		switch v := param.(type) {
		case string:
			b.WriteString("'" + v + "'")
		case []byte:
			b.WriteString("'" + string(v) + "'")
		default:
			b.WriteString(fmt.Sprintf("%v", v))
		}
		rest = after
	}
	b.WriteString(rest)

	out := strings.ReplaceAll(b.String(), "\t", "")
	out = strings.ReplaceAll(out, "\n", " ")
	return strings.Trim(out, " ")
}