379 lines
9.8 KiB
Go
379 lines
9.8 KiB
Go
// Package sqldb provides support for accessing the database.
|
|
package sqldb
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"management/internal/pkg/config"

	"github.com/drhin/logger"
	"github.com/jackc/pgx/v5/pgconn"
	_ "github.com/jackc/pgx/v5/stdlib"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
)
|
|
|
|
// PostgreSQL SQLSTATE codes that this package maps onto sentinel errors.
// The names follow the errorCodeNames table in lib/pq:
// https://github.com/lib/pq/blob/master/error.go#L178
const (
	// uniqueViolation is the code for "unique_violation".
	uniqueViolation = "23505"

	// undefinedTable is the code for "undefined_table".
	undefinedTable = "42P01"
)
|
|
|
|
// Sentinel errors reported by the CRUD helpers in this package.
var (
	// ErrDBNotFound is the same value as sql.ErrNoRows; it is returned when
	// a single-row query yields no row.
	ErrDBNotFound = sql.ErrNoRows

	// ErrDBDuplicatedEntry is returned when the database reports a unique
	// constraint violation.
	ErrDBDuplicatedEntry = errors.New("duplicated entry")

	// ErrUndefinedTable is returned when the database reports that a
	// referenced table does not exist.
	ErrUndefinedTable = errors.New("undefined table")
)
|
|
|
|
// Config is the required properties to use the database.
|
|
//type Config struct {
|
|
// User string
|
|
// Password string
|
|
// Host string
|
|
// Name string
|
|
// Schema string
|
|
// MaxIdleConns int
|
|
// MaxOpenConns int
|
|
// DisableTLS bool
|
|
//}
|
|
|
|
// Open knows how to open a database connection based on the configuration.
|
|
//func Open(cfg Config) (*sqlx.DB, error) {
|
|
// sslMode := "require"
|
|
// if cfg.DisableTLS {
|
|
// sslMode = "disable"
|
|
// }
|
|
//
|
|
// q := make(url.Values)
|
|
// q.Set("sslmode", sslMode)
|
|
// q.Set("timezone", "utc")
|
|
// if cfg.Schema != "" {
|
|
// q.Set("search_path", cfg.Schema)
|
|
// }
|
|
//
|
|
// u := url.URL{
|
|
// Scheme: "postgres",
|
|
// User: url.UserPassword(cfg.User, cfg.Password),
|
|
// Host: cfg.Host,
|
|
// Path: cfg.Name,
|
|
// RawQuery: q.Encode(),
|
|
// }
|
|
//
|
|
// db, err := sqlx.Open("pgx", u.String())
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
// db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
//
|
|
// return db, nil
|
|
//}
|
|
|
|
// Config groups database connection and connection-pool properties.
//
// NOTE(review): NewDB in this file reads its settings from config.Config
// rather than from this type; confirm whether this struct is still in use.
type Config struct {
	// Credentials and server address.
	User, Password, Host string
	Port                 int

	// Name of the database.
	Name string

	// Connection-pool limits.
	MaxIdleConns, MaxOpenConns int

	// Connection-pool lifetimes.
	ConnMaxLifetime, ConnMaxIdleTime time.Duration
}
|
|
|
|
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
|
|
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
|
|
config.DB.Username,
|
|
config.DB.Password,
|
|
config.DB.Host,
|
|
config.DB.Port,
|
|
config.DB.DBName,
|
|
)
|
|
|
|
db, err := sqlx.Open("pgx", dsn)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("sqlx open db: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 设置最大空闲连接数(默认 2)
|
|
db.SetMaxIdleConns(config.DB.MaxIdleConns)
|
|
|
|
// 设置最大打开连接数(默认 0 无限制)
|
|
db.SetMaxOpenConns(config.DB.MaxOpenConns)
|
|
|
|
// 设置连接最大存活时间
|
|
db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
|
|
|
|
// 设置连接最大空闲时间
|
|
db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
|
|
|
|
cleanup := func() {
|
|
if err := db.Close(); err != nil {
|
|
log.Error("sql db close error", err)
|
|
}
|
|
}
|
|
|
|
return db, cleanup, nil
|
|
}
|
|
|
|
// StatusCheck returns nil if it can successfully talk to the database. It
|
|
// returns a non-nil error otherwise.
|
|
func StatusCheck(ctx context.Context, db *sqlx.DB) error {
|
|
|
|
// If the user doesn't give us a deadline set 1 second.
|
|
if _, ok := ctx.Deadline(); !ok {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, time.Second)
|
|
defer cancel()
|
|
}
|
|
|
|
for attempts := 1; ; attempts++ {
|
|
if err := db.Ping(); err == nil {
|
|
break
|
|
}
|
|
|
|
time.Sleep(time.Duration(attempts) * 100 * time.Millisecond)
|
|
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
|
|
// Run a simple query to determine connectivity.
|
|
// Running this query forces a round trip through the database.
|
|
const q = `SELECT TRUE`
|
|
var tmp bool
|
|
return db.QueryRowContext(ctx, q).Scan(&tmp)
|
|
}
|
|
|
|
// ExecContext is a helper function to execute a CUD operation with
|
|
// logging and tracing.
|
|
func ExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string) error {
|
|
return NamedExecContext(ctx, log, db, query, struct{}{})
|
|
}
|
|
|
|
// NamedExecContext is a helper function to execute a CUD operation with
|
|
// logging and tracing where field replacement is necessary.
|
|
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
|
|
q := queryString(query, data)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
switch data.(type) {
|
|
case struct{}:
|
|
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
|
default:
|
|
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 5), zap.Error(err))
|
|
}
|
|
}
|
|
}()
|
|
|
|
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
|
|
var pqerr *pgconn.PgError
|
|
if errors.As(err, &pqerr) {
|
|
switch pqerr.Code {
|
|
case undefinedTable:
|
|
return ErrUndefinedTable
|
|
case uniqueViolation:
|
|
return ErrDBDuplicatedEntry
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// QuerySlice is a helper function for executing queries that return a
|
|
// collection of data to be unmarshalled into a slice.
|
|
func QuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest *[]T) error {
|
|
return namedQuerySlice(ctx, log, db, query, struct{}{}, dest, false)
|
|
}
|
|
|
|
// NamedQuerySlice is a helper function for executing queries that return a
|
|
// collection of data to be unmarshalled into a slice where field replacement is
|
|
// necessary.
|
|
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
|
|
return namedQuerySlice(ctx, log, db, query, data, dest, false)
|
|
}
|
|
|
|
// NamedQuerySliceUsingIn is a helper function for executing queries that return
|
|
// a collection of data to be unmarshalled into a slice where field replacement
|
|
// is necessary. Use this if the query has an IN clause.
|
|
func NamedQuerySliceUsingIn[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
|
|
return namedQuerySlice(ctx, log, db, query, data, dest, true)
|
|
}
|
|
|
|
func namedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T, withIn bool) (err error) {
|
|
q := queryString(query, data)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
|
}
|
|
}()
|
|
|
|
var rows *sqlx.Rows
|
|
|
|
switch withIn {
|
|
case true:
|
|
rows, err = func() (*sqlx.Rows, error) {
|
|
named, args, err := sqlx.Named(query, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query, args, err := sqlx.In(named, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query = db.Rebind(query)
|
|
return db.QueryxContext(ctx, query, args...)
|
|
}()
|
|
|
|
default:
|
|
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
|
|
}
|
|
|
|
if err != nil {
|
|
var pqErr *pgconn.PgError
|
|
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
|
return ErrUndefinedTable
|
|
}
|
|
return err
|
|
}
|
|
defer func(rows *sqlx.Rows) {
|
|
_ = rows.Close()
|
|
}(rows)
|
|
|
|
var slice []T
|
|
for rows.Next() {
|
|
v := new(T)
|
|
if err := rows.StructScan(v); err != nil {
|
|
return err
|
|
}
|
|
slice = append(slice, *v)
|
|
}
|
|
*dest = slice
|
|
|
|
return nil
|
|
}
|
|
|
|
// QueryStruct is a helper function for executing queries that return a
|
|
// single value to be unmarshalled into a struct type where field replacement is necessary.
|
|
func QueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest any) error {
|
|
return namedQueryStruct(ctx, log, db, query, struct{}{}, dest, false)
|
|
}
|
|
|
|
// NamedQueryStruct is a helper function for executing queries that return a
|
|
// single value to be unmarshalled into a struct type where field replacement is necessary.
|
|
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
|
|
return namedQueryStruct(ctx, log, db, query, data, dest, false)
|
|
}
|
|
|
|
// NamedQueryStructUsingIn is a helper function for executing queries that return
|
|
// a single value to be unmarshalled into a struct type where field replacement
|
|
// is necessary. Use this if the query has an IN clause.
|
|
func NamedQueryStructUsingIn(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
|
|
return namedQueryStruct(ctx, log, db, query, data, dest, true)
|
|
}
|
|
|
|
func namedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any, withIn bool) (err error) {
|
|
q := queryString(query, data)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
|
}
|
|
}()
|
|
|
|
var rows *sqlx.Rows
|
|
|
|
switch withIn {
|
|
case true:
|
|
rows, err = func() (*sqlx.Rows, error) {
|
|
named, args, err := sqlx.Named(query, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query, args, err := sqlx.In(named, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query = db.Rebind(query)
|
|
return db.QueryxContext(ctx, query, args...)
|
|
}()
|
|
|
|
default:
|
|
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
|
|
}
|
|
|
|
if err != nil {
|
|
var pqErr *pgconn.PgError
|
|
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
|
return ErrUndefinedTable
|
|
}
|
|
return err
|
|
}
|
|
defer func(rows *sqlx.Rows) {
|
|
_ = rows.Close()
|
|
}(rows)
|
|
|
|
if !rows.Next() {
|
|
return ErrDBNotFound
|
|
}
|
|
|
|
if err := rows.StructScan(dest); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// queryString provides a pretty print version of the query and parameters.
|
|
func queryString(query string, args any) string {
|
|
query, params, err := sqlx.Named(query, args)
|
|
if err != nil {
|
|
return err.Error()
|
|
}
|
|
|
|
for _, param := range params {
|
|
var value string
|
|
switch v := param.(type) {
|
|
case string:
|
|
value = fmt.Sprintf("'%s'", v)
|
|
case []byte:
|
|
value = fmt.Sprintf("'%s'", string(v))
|
|
default:
|
|
value = fmt.Sprintf("%v", v)
|
|
}
|
|
query = strings.Replace(query, "?", value, 1)
|
|
}
|
|
|
|
query = strings.ReplaceAll(query, "\t", "")
|
|
query = strings.ReplaceAll(query, "\n", " ")
|
|
|
|
return strings.Trim(query, " ")
|
|
}
|