215 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			215 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package sqldb
 | |
| 
 | |
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"

	"management/internal/pkg/config"

	"github.com/drhin/logger"
	"github.com/jackc/pgx/v5/pgconn"
	_ "github.com/jackc/pgx/v5/stdlib"
	"github.com/jmoiron/sqlx"
	"go.uber.org/zap"
)
 | |
| 
 | |
// undefinedTable is the PostgreSQL SQLSTATE for undefined_table; it is
// translated into ErrUndefinedTable.
const undefinedTable = "42P01"

// uniqueViolation is the PostgreSQL SQLSTATE for unique_violation; it is
// translated into ErrDBDuplicatedEntry.
const uniqueViolation = "23505"

// ErrDBNotFound is returned when a query expected a row but none came back.
// It is the very same value as sql.ErrNoRows.
var ErrDBNotFound = sql.ErrNoRows

// ErrDBDuplicatedEntry is returned when a statement fails with SQLSTATE
// 23505 (unique constraint violation).
var ErrDBDuplicatedEntry = errors.New("duplicated entry")

// ErrUndefinedTable is returned when a statement fails with SQLSTATE 42P01
// (the referenced table does not exist).
var ErrUndefinedTable = errors.New("undefined table")
 | |
| 
 | |
// Config holds database connection settings and connection-pool limits.
//
// NOTE(review): NewDB in this file reads its settings from config.Config
// instead of this type; confirm whether Config is still used elsewhere.
type Config struct {
	// Connection target and credentials.
	User, Password, Host string
	Port                 int
	Name                 string

	// Pool tuning; these appear to mirror the *sql.DB setters
	// (SetMaxIdleConns, SetMaxOpenConns, SetConnMaxLifetime, SetConnMaxIdleTime).
	MaxIdleConns, MaxOpenConns       int
	ConnMaxLifetime, ConnMaxIdleTime time.Duration
}
 | |
| 
 | |
| func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
 | |
| 	dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
 | |
| 		config.DB.Username,
 | |
| 		config.DB.Password,
 | |
| 		config.DB.Host,
 | |
| 		config.DB.Port,
 | |
| 		config.DB.DBName,
 | |
| 	)
 | |
| 
 | |
| 	db, err := sqlx.Open("pgx", dsn)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, fmt.Errorf("sqlx open db: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
 | |
| 	defer cancel()
 | |
| 
 | |
| 	if err := db.PingContext(ctx); err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	// 设置最大空闲连接数(默认 2)
 | |
| 	db.SetMaxIdleConns(config.DB.MaxIdleConns)
 | |
| 
 | |
| 	// 设置最大打开连接数(默认 0 无限制)
 | |
| 	db.SetMaxOpenConns(config.DB.MaxOpenConns)
 | |
| 
 | |
| 	// 设置连接最大存活时间
 | |
| 	db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
 | |
| 
 | |
| 	// 设置连接最大空闲时间
 | |
| 	db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
 | |
| 
 | |
| 	cleanup := func() {
 | |
| 		if err := db.Close(); err != nil {
 | |
| 			log.Error("sql db close error", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return db, cleanup, nil
 | |
| }
 | |
| 
 | |
| func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
 | |
| 	defer func() {
 | |
| 		if err != nil {
 | |
| 			switch data.(type) {
 | |
| 			case struct{}:
 | |
| 				log.Error("database.NamedExecContext (data is struct)", err,
 | |
| 					zap.String("query", query),
 | |
| 					zap.Any("ERROR", err))
 | |
| 			default:
 | |
| 				log.Error("database.NamedExecContext", err,
 | |
| 					zap.String("query", query),
 | |
| 					zap.Any("ERROR", err))
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
 | |
| 		var pgError *pgconn.PgError
 | |
| 		if errors.As(err, &pgError) {
 | |
| 			switch pgError.Code {
 | |
| 			case undefinedTable:
 | |
| 				return ErrUndefinedTable
 | |
| 			case uniqueViolation:
 | |
| 				return ErrDBDuplicatedEntry
 | |
| 			}
 | |
| 		}
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
 | |
| 	q := queryString(query, data)
 | |
| 	rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
 | |
| 	if err != nil {
 | |
| 		var pqErr *pgconn.PgError
 | |
| 		if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
 | |
| 			return ErrUndefinedTable
 | |
| 		}
 | |
| 		log.Error("NamedQueryStruct NamedQueryContext error", err,
 | |
| 			zap.String("query", q),
 | |
| 			zap.Any("data", data),
 | |
| 		)
 | |
| 		return err
 | |
| 	}
 | |
| 	defer func(rows *sqlx.Rows) {
 | |
| 		err := rows.Close()
 | |
| 		if err != nil {
 | |
| 			log.Error("rows close error", err)
 | |
| 		}
 | |
| 	}(rows)
 | |
| 
 | |
| 	if !rows.Next() {
 | |
| 		return ErrDBNotFound
 | |
| 	}
 | |
| 
 | |
| 	if err := rows.StructScan(dest); err != nil {
 | |
| 		log.Error("NamedQueryStruct StructScan error", err,
 | |
| 			zap.String("query", q),
 | |
| 			zap.Any("data", data),
 | |
| 		)
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
 | |
| 	q := queryString(query, data)
 | |
| 	rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
 | |
| 	if err != nil {
 | |
| 		var pqErr *pgconn.PgError
 | |
| 		if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
 | |
| 			return ErrUndefinedTable
 | |
| 		}
 | |
| 		log.Error("NamedQueryStruct NamedQueryContext error", err,
 | |
| 			zap.String("query", q),
 | |
| 			zap.Any("data", data),
 | |
| 		)
 | |
| 		return err
 | |
| 	}
 | |
| 	defer func(rows *sqlx.Rows) {
 | |
| 		err := rows.Close()
 | |
| 		if err != nil {
 | |
| 			log.Error("rows close error", err)
 | |
| 		}
 | |
| 	}(rows)
 | |
| 
 | |
| 	var slice []T
 | |
| 	for rows.Next() {
 | |
| 		v := new(T)
 | |
| 		if err := rows.StructScan(v); err != nil {
 | |
| 			log.Error("NamedQuerySlice StructScan error", err,
 | |
| 				zap.String("query", q),
 | |
| 				zap.Any("data", data),
 | |
| 			)
 | |
| 			return err
 | |
| 		}
 | |
| 		slice = append(slice, *v)
 | |
| 	}
 | |
| 	*dest = slice
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func queryString(query string, args any) string {
 | |
| 	query, params, err := sqlx.Named(query, args)
 | |
| 	if err != nil {
 | |
| 		return err.Error()
 | |
| 	}
 | |
| 
 | |
| 	for _, param := range params {
 | |
| 		var value string
 | |
| 		switch v := param.(type) {
 | |
| 		case string:
 | |
| 			value = fmt.Sprintf("'%s'", v)
 | |
| 		case []byte:
 | |
| 			value = fmt.Sprintf("'%s'", string(v))
 | |
| 		default:
 | |
| 			value = fmt.Sprintf("%v", v)
 | |
| 		}
 | |
| 		query = strings.Replace(query, "?", value, 1)
 | |
| 	}
 | |
| 
 | |
| 	query = strings.ReplaceAll(query, "\t", "")
 | |
| 	query = strings.ReplaceAll(query, "\n", " ")
 | |
| 
 | |
| 	return strings.Trim(query, " ")
 | |
| }
 |