package sqldb import ( "context" "database/sql" "errors" "fmt" "strings" "time" "management/internal/pkg/config" "github.com/drhin/logger" "github.com/jackc/pgx/v5/pgconn" _ "github.com/jackc/pgx/v5/stdlib" "github.com/jmoiron/sqlx" "go.uber.org/zap" ) const ( uniqueViolation = "23505" undefinedTable = "42P01" ) var ( ErrDBNotFound = sql.ErrNoRows ErrDBDuplicatedEntry = errors.New("duplicated entry") ErrUndefinedTable = errors.New("undefined table") ) type Config struct { User string Password string Host string Port int Name string MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration ConnMaxIdleTime time.Duration } func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) { dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", config.DB.Username, config.DB.Password, config.DB.Host, config.DB.Port, config.DB.DBName, ) db, err := sqlx.Open("pgx", dsn) if err != nil { return nil, nil, fmt.Errorf("sqlx open db: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() if err := db.PingContext(ctx); err != nil { return nil, nil, err } // 设置最大空闲连接数(默认 2) db.SetMaxIdleConns(config.DB.MaxIdleConns) // 设置最大打开连接数(默认 0 无限制) db.SetMaxOpenConns(config.DB.MaxOpenConns) // 设置连接最大存活时间 db.SetConnMaxLifetime(config.DB.ConnMaxLifetime) // 设置连接最大空闲时间 db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime) cleanup := func() { if err := db.Close(); err != nil { log.Error("sql db close error", err) } } return db, cleanup, nil } func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) { defer func() { if err != nil { switch data.(type) { case struct{}: log.Error("database.NamedExecContext (data is struct)", err, zap.String("query", query), zap.Any("ERROR", err)) default: log.Error("database.NamedExecContext", err, zap.String("query", query), zap.Any("ERROR", err)) } } }() if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil { var pgError *pgconn.PgError if errors.As(err, &pgError) { switch pgError.Code { case undefinedTable: return ErrUndefinedTable case uniqueViolation: return ErrDBDuplicatedEntry } } return err } return nil } func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) { q := queryString(query, data) rows, err := sqlx.NamedQueryContext(ctx, db, q, data) if err != nil { var pqErr *pgconn.PgError if errors.As(err, &pqErr) && pqErr.Code == undefinedTable { return ErrUndefinedTable } log.Error("NamedQueryStruct NamedQueryContext error", err, zap.String("query", q), zap.Any("data", data), ) return err } defer func(rows *sqlx.Rows) { err := rows.Close() if err != nil { log.Error("rows close error", err) } }(rows) if !rows.Next() { return ErrDBNotFound } if err := rows.StructScan(dest); err != nil { log.Error("NamedQueryStruct StructScan error", err, zap.String("query", q), zap.Any("data", data), ) return err } return nil } func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) { q := queryString(query, data) rows, err := sqlx.NamedQueryContext(ctx, db, q, data) if err != nil { var pqErr *pgconn.PgError if errors.As(err, &pqErr) && pqErr.Code == undefinedTable { return ErrUndefinedTable } log.Error("NamedQueryStruct NamedQueryContext error", err, zap.String("query", q), zap.Any("data", data), ) return err } defer func(rows *sqlx.Rows) { err := rows.Close() if err != nil { log.Error("rows close error", err) } }(rows) var slice []T for rows.Next() { v := new(T) if err := rows.StructScan(v); err != nil { log.Error("NamedQuerySlice StructScan error", err, zap.String("query", q), zap.Any("data", data), ) return err } slice = append(slice, *v) } *dest = slice return nil } func queryString(query string, args any) string { query, params, err := sqlx.Named(query, args) if err != nil { return err.Error() } for _, param := range params { var value string switch v := param.(type) { case string: value = fmt.Sprintf("'%s'", v) case []byte: value = fmt.Sprintf("'%s'", string(v)) default: value = fmt.Sprintf("%v", v) } query = strings.Replace(query, "?", value, 1) } query = strings.ReplaceAll(query, "\t", "") query = strings.ReplaceAll(query, "\n", " ") return strings.Trim(query, " ") }