// Package sqldb provides support for access the database. package sqldb import ( "context" "database/sql" "errors" "fmt" "strings" "time" "management/internal/pkg/config" "github.com/drhin/logger" "github.com/jackc/pgx/v5/pgconn" _ "github.com/jackc/pgx/v5/stdlib" "github.com/jmoiron/sqlx" "go.uber.org/zap" ) // lib/pq errorCodeNames // https://github.com/lib/pq/blob/master/error.go#L178 const ( uniqueViolation = "23505" undefinedTable = "42P01" ) // Set of error variables for CRUD operations. var ( ErrDBNotFound = sql.ErrNoRows ErrDBDuplicatedEntry = errors.New("duplicated entry") ErrUndefinedTable = errors.New("undefined table") ) // Config is the required properties to use the database. //type Config struct { // User string // Password string // Host string // Name string // Schema string // MaxIdleConns int // MaxOpenConns int // DisableTLS bool //} // Open knows how to open a database connection based on the configuration. //func Open(cfg Config) (*sqlx.DB, error) { // sslMode := "require" // if cfg.DisableTLS { // sslMode = "disable" // } // // q := make(url.Values) // q.Set("sslmode", sslMode) // q.Set("timezone", "utc") // if cfg.Schema != "" { // q.Set("search_path", cfg.Schema) // } // // u := url.URL{ // Scheme: "postgres", // User: url.UserPassword(cfg.User, cfg.Password), // Host: cfg.Host, // Path: cfg.Name, // RawQuery: q.Encode(), // } // // db, err := sqlx.Open("pgx", u.String()) // if err != nil { // return nil, err // } // db.SetMaxIdleConns(cfg.MaxIdleConns) // db.SetMaxOpenConns(cfg.MaxOpenConns) // // return db, nil //} type Config struct { User string Password string Host string Port int Name string MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration ConnMaxIdleTime time.Duration } func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) { dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", config.DB.Username, config.DB.Password, config.DB.Host, config.DB.Port, config.DB.DBName, ) db, err := 
sqlx.Open("pgx", dsn) if err != nil { return nil, nil, fmt.Errorf("sqlx open db: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() if err := db.PingContext(ctx); err != nil { return nil, nil, err } // 设置最大空闲连接数(默认 2) db.SetMaxIdleConns(config.DB.MaxIdleConns) // 设置最大打开连接数(默认 0 无限制) db.SetMaxOpenConns(config.DB.MaxOpenConns) // 设置连接最大存活时间 db.SetConnMaxLifetime(config.DB.ConnMaxLifetime) // 设置连接最大空闲时间 db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime) cleanup := func() { if err := db.Close(); err != nil { log.Error("sql db close error", err) } } return db, cleanup, nil } // StatusCheck returns nil if it can successfully talk to the database. It // returns a non-nil error otherwise. func StatusCheck(ctx context.Context, db *sqlx.DB) error { // If the user doesn't give us a deadline set 1 second. if _, ok := ctx.Deadline(); !ok { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, time.Second) defer cancel() } for attempts := 1; ; attempts++ { if err := db.Ping(); err == nil { break } time.Sleep(time.Duration(attempts) * 100 * time.Millisecond) if ctx.Err() != nil { return ctx.Err() } } if ctx.Err() != nil { return ctx.Err() } // Run a simple query to determine connectivity. // Running this query forces a round trip through the database. const q = `SELECT TRUE` var tmp bool return db.QueryRowContext(ctx, q).Scan(&tmp) } // ExecContext is a helper function to execute a CUD operation with // logging and tracing. func ExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string) error { return NamedExecContext(ctx, log, db, query, struct{}{}) } // NamedExecContext is a helper function to execute a CUD operation with // logging and tracing where field replacement is necessary. 
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) { q := queryString(query, data) defer func() { if err != nil { switch data.(type) { case struct{}: log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 6), zap.Error(err)) default: log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 5), zap.Error(err)) } } }() if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil { var pqerr *pgconn.PgError if errors.As(err, &pqerr) { switch pqerr.Code { case undefinedTable: return ErrUndefinedTable case uniqueViolation: return ErrDBDuplicatedEntry } } return err } return nil } // QuerySlice is a helper function for executing queries that return a // collection of data to be unmarshalled into a slice. func QuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest *[]T) error { return namedQuerySlice(ctx, log, db, query, struct{}{}, dest, false) } // NamedQuerySlice is a helper function for executing queries that return a // collection of data to be unmarshalled into a slice where field replacement is // necessary. func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error { return namedQuerySlice(ctx, log, db, query, data, dest, false) } // NamedQuerySliceUsingIn is a helper function for executing queries that return // a collection of data to be unmarshalled into a slice where field replacement // is necessary. Use this if the query has an IN clause. 
func NamedQuerySliceUsingIn[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error { return namedQuerySlice(ctx, log, db, query, data, dest, true) } func namedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T, withIn bool) (err error) { q := queryString(query, data) defer func() { if err != nil { log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err)) } }() var rows *sqlx.Rows switch withIn { case true: rows, err = func() (*sqlx.Rows, error) { named, args, err := sqlx.Named(query, data) if err != nil { return nil, err } query, args, err := sqlx.In(named, args...) if err != nil { return nil, err } query = db.Rebind(query) return db.QueryxContext(ctx, query, args...) }() default: rows, err = sqlx.NamedQueryContext(ctx, db, query, data) } if err != nil { var pqErr *pgconn.PgError if errors.As(err, &pqErr) && pqErr.Code == undefinedTable { return ErrUndefinedTable } return err } defer func(rows *sqlx.Rows) { _ = rows.Close() }(rows) var slice []T for rows.Next() { v := new(T) if err := rows.StructScan(v); err != nil { return err } slice = append(slice, *v) } *dest = slice return nil } // QueryStruct is a helper function for executing queries that return a // single value to be unmarshalled into a struct type where field replacement is necessary. func QueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest any) error { return namedQueryStruct(ctx, log, db, query, struct{}{}, dest, false) } // NamedQueryStruct is a helper function for executing queries that return a // single value to be unmarshalled into a struct type where field replacement is necessary. 
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
	return namedQueryStruct(ctx, log, db, query, data, dest, false)
}

// NamedQueryStructUsingIn is a helper function for executing queries that return
// a single value to be unmarshalled into a struct type where field replacement
// is necessary. Use this if the query has an IN clause.
func NamedQueryStructUsingIn(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
	return namedQueryStruct(ctx, log, db, query, data, dest, true)
}

// namedQueryStruct runs query with the named parameters taken from data and
// scans the first returned row into dest.
//
// When withIn is true the query is expanded with sqlx.In and rebound for the
// driver so slice arguments can be used in an IN clause. ErrDBNotFound is
// returned when the query yields no rows, and a PostgreSQL "undefined table"
// error from running the query is reported as ErrUndefinedTable. Any returned
// error is also logged together with a pretty-printed query.
func namedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any, withIn bool) (err error) {
	// The pretty-printed query is only needed when something failed, so it
	// is built inside the deferred logger rather than on every call.
	defer func() {
		if err != nil {
			log.Info("database.NamedQueryStruct", zap.String("query", queryString(query, data)), zap.Int("type", 6), zap.Error(err))
		}
	}()

	var rows *sqlx.Rows
	if withIn {
		rows, err = func() (*sqlx.Rows, error) {
			named, args, err := sqlx.Named(query, data)
			if err != nil {
				return nil, err
			}

			expanded, args, err := sqlx.In(named, args...)
			if err != nil {
				return nil, err
			}

			return db.QueryxContext(ctx, db.Rebind(expanded), args...)
		}()
	} else {
		rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
	}

	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == undefinedTable {
			return ErrUndefinedTable
		}
		return err
	}
	// A close error carries nothing actionable once the row has been read
	// (or reading has already failed), so it is deliberately ignored.
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Next returns false both for an empty result set and on failure;
		// only the former is a genuine "not found".
		if err := rows.Err(); err != nil {
			return err
		}
		return ErrDBNotFound
	}

	return rows.StructScan(dest)
}

// queryString provides a pretty print version of the query and parameters.
func queryString(query string, args any) string { query, params, err := sqlx.Named(query, args) if err != nil { return err.Error() } for _, param := range params { var value string switch v := param.(type) { case string: value = fmt.Sprintf("'%s'", v) case []byte: value = fmt.Sprintf("'%s'", string(v)) default: value = fmt.Sprintf("%v", v) } query = strings.Replace(query, "?", value, 1) } query = strings.ReplaceAll(query, "\t", "") query = strings.ReplaceAll(query, "\n", " ") return strings.Trim(query, " ") }