This commit is contained in:
2025-10-27 15:24:08 +08:00
parent 4186cd0caf
commit df4c3dd46f
47 changed files with 1757 additions and 306 deletions

View File

@@ -1,15 +1,48 @@
package database
import (
"errors"
// import (
// "database/sql"
// "errors"
"gorm.io/gorm"
)
// "github.com/jackc/pgx/v5/pgconn"
// "gorm.io/gorm"
// )
func IsUniqueViolation(err error) bool {
return errors.Is(err, gorm.ErrDuplicatedKey)
}
// func IsGORMUniqueViolation(err error) bool {
// return errors.Is(err, gorm.ErrDuplicatedKey)
// }
func IsNoRows(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
// func IsGORMNoRows(err error) bool {
// return errors.Is(err, gorm.ErrRecordNotFound)
// }
// // ****************** errors ******************
// const (
// foreignKeyViolation = "23503"
// uniqueViolation = "23505"
// )
// // var ErrUniqueViolation = &pgconn.PgError{
// // Code: UniqueViolation,
// // }
// func ErrorCode(err error) string {
// var pgErr *pgconn.PgError
// if errors.As(err, &pgErr) {
// return pgErr.Code
// }
// return ""
// }
// func IsUniqueViolation(err error) bool {
// var pgErr *pgconn.PgError
// if errors.As(err, &pgErr) {
// return pgErr.Code == uniqueViolation
// }
// return false
// }
// func IsNoRows(err error) bool {
// return errors.Is(err, sql.ErrNoRows)
// }

View File

@@ -9,6 +9,7 @@ import (
"path"
"time"
"github.com/google/uuid"
"github.com/h2non/filetype"
gonanoid "github.com/matoous/go-nanoid/v2"
)
@@ -20,82 +21,94 @@ const (
var ErrUnsupported = errors.New("文件格式不支持")
type FileType int
type Type int
const (
ALL FileType = 0
IMG FileType = 1
ALL Type = 0
IMG Type = 1
)
func UploadFilename(filepath string, t FileType) (string, error) {
fileOpen, err := os.Open(filepath)
if err != nil {
return "", err
}
defer fileOpen.Close()
fileBytes, err := io.ReadAll(fileOpen)
if err != nil {
return "", errors.New("failed to read file")
}
//func UploadFilename(filepath string, t Type) (string, error) {
// fileOpen, err := os.Open(filepath)
// if err != nil {
// return "", err
// }
// defer func(fileOpen *os.File) {
// _ = fileOpen.Close()
// }(fileOpen)
//
// fileBytes, err := io.ReadAll(fileOpen)
// if err != nil {
// return "", errors.New("failed to read file")
// }
//
// if t == IMG {
// // 判断是不是图片
// if !filetype.IsImage(fileBytes) {
// return "", ErrUnsupported
// }
// }
//
// kind, err := filetype.Match(fileBytes)
// if err != nil {
// return "", err
// }
//
// if kind == filetype.Unknown {
// return "", ErrUnsupported
// }
//
// // 使用 filetype 判断类型后已经去读了一些bytes了
// // 要恢复文件读取位置
// _, err = fileOpen.Seek(0, io.SeekStart)
// if err != nil {
// return "", err
// }
//
// dir := GetPath()
// exist, _ := Exists(dir)
// if !exist {
// if err := Mkdir(dir); err != nil {
// return "", err
// }
// }
//
// filename := GenFilename(kind.Extension)
// fullPath := path.Join(dir, filename)
// f, err := os.Create(fullPath)
// if err != nil {
// return "", err
// }
// defer func(f *os.File) {
// _ = f.Close()
// }(f)
//
// _, err = io.Copy(f, fileOpen)
// if err != nil {
// return "", err
// }
//
// return "/" + fullPath, nil
//}
func UploadFile(file *multipart.FileHeader, t Type) (string, error) {
if t == IMG {
// 判断是不是图片
if !filetype.IsImage(fileBytes) {
return "", ErrUnsupported
if file.Size > MaxImageSize {
return "", errors.New("failed to receive images too large")
}
}
kind, err := filetype.Match(fileBytes)
if err != nil {
return "", err
}
if kind == filetype.Unknown {
return "", ErrUnsupported
}
// 使用 filetype 判断类型后已经去读了一些bytes了
// 要恢复文件读取位置
_, err = fileOpen.Seek(0, io.SeekStart)
if err != nil {
return "", err
}
dir := GetPath()
exist, _ := Exists(dir)
if !exist {
if err := Mkdir(dir); err != nil {
return "", err
} else {
if file.Size > MaxFileSize {
return "", errors.New("failed to receive file too large")
}
}
filename := GenFilename(kind.Extension)
path := path.Join(dir, filename)
f, err := os.Create(path)
if err != nil {
return "", err
}
defer f.Close()
_, err = io.Copy(f, fileOpen)
if err != nil {
return "", err
}
return "/" + path, nil
}
func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
if file.Size > MaxFileSize {
return "", errors.New("failed to receive file too large")
}
fileOpen, err := file.Open()
if err != nil {
return "", errors.New("fialed to open file")
return "", errors.New("failed to open file")
}
defer fileOpen.Close()
defer func(fileOpen multipart.File) {
_ = fileOpen.Close()
}(fileOpen)
fileBytes, err := io.ReadAll(fileOpen)
if err != nil {
@@ -134,19 +147,21 @@ func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
}
filename := GenFilename(kind.Extension)
path := path.Join(dir, filename)
f, err := os.Create(path)
fullPath := path.Join(dir, filename)
f, err := os.Create(fullPath)
if err != nil {
return "", err
}
defer f.Close()
defer func(f *os.File) {
_ = f.Close()
}(f)
_, err = io.Copy(f, fileOpen)
if err != nil {
return "", err
}
return "/" + path, nil
return "/" + fullPath, nil
}
func GetPath() string {
@@ -154,6 +169,9 @@ func GetPath() string {
}
func GenFilename(ext string) string {
id, _ := gonanoid.New()
id, err := gonanoid.New()
if err != nil {
return uuid.New().String()
}
return fmt.Sprintf("%s.%s", id, ext)
}

View File

@@ -40,7 +40,7 @@ func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor)
opts := []asynq.Option{
asynq.MaxRetry(10),
asynq.ProcessIn(1 * time.Second),
asynq.Queue(tasks.QueueCritical),
asynq.Queue(tasks.QueueDefault),
}
c, cancel := context.WithTimeout(ctx, 2*time.Second)

View File

@@ -0,0 +1,920 @@
/*
Code taken from https://github.com/lib/pq
Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so, subject
to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
// Package dbarray provides support for database array types.
package dbarray
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
"strings"
)
var typeByteSlice = reflect.TypeOf([]byte{})
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
//
// For example:
//
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
//
// var x []sql.NullInt64
// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a any) interface {
driver.Valuer
sql.Scanner
} {
switch a := a.(type) {
case []bool:
return (*Bool)(&a)
case []float64:
return (*Float64)(&a)
case []float32:
return (*Float32)(&a)
case []int64:
return (*Int64)(&a)
case []int32:
return (*Int32)(&a)
case []string:
return (*String)(&a)
case [][]byte:
return (*Bytea)(&a)
case *[]bool:
return (*Bool)(a)
case *[]float64:
return (*Float64)(a)
case *[]float32:
return (*Float32)(a)
case *[]int64:
return (*Int64)(a)
case *[]int32:
return (*Int32)(a)
case *[]string:
return (*String)(a)
case *[][]byte:
return (*Bytea)(a)
}
return Generic{a}
}
// Delimiter may be optionally implemented by driver.Valuer or sql.Scanner
// to override the array delimiter used by Generic.
type Delimiter interface {
// Delimiter returns the delimiter character(s) for this element's type.
Delimiter() string
}
// Bool represents a one-dimensional array of the PostgreSQL boolean type.
type Bool []bool
// Scan implements the sql.Scanner interface.
func (a *Bool) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Bool", src)
}
func (a *Bool) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Bool")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Bool, len(elems))
for i, v := range elems {
if len(v) != 1 {
return fmt.Errorf("database: could not parse boolean array index %d: invalid boolean %q", i, v)
}
switch v[0] {
case 't':
b[i] = true
case 'f':
b[i] = false
default:
return fmt.Errorf("database: could not parse boolean array index %d: invalid boolean %q", i, v)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Bool) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be exactly two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1+2*n)
for i := 0; i < n; i++ {
b[2*i] = ','
if a[i] {
b[1+2*i] = 't'
} else {
b[1+2*i] = 'f'
}
}
b[0] = '{'
b[2*n] = '}'
return string(b), nil
}
return "{}", nil
}
// Bytea represents a one-dimensional array of the PostgreSQL bytea type.
type Bytea [][]byte
// Scan implements the sql.Scanner interface.
func (a *Bytea) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Bytea", src)
}
func (a *Bytea) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Bytea")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Bytea, len(elems))
for i, v := range elems {
b[i], err = parseBytea(v)
if err != nil {
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface. It uses the "hex" format which
// is only supported on PostgreSQL 9.0 or newer.
func (a Bytea) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
size := 1 + 6*n
for _, x := range a {
size += hex.EncodedLen(len(x))
}
b := make([]byte, size)
for i, s := 0, b; i < n; i++ {
o := copy(s, `,"\\x`)
o += hex.Encode(s[o:], a[i])
s[o] = '"'
s = s[o+1:]
}
b[0] = '{'
b[size-1] = '}'
return string(b), nil
}
return "{}", nil
}
// Float64 represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float64 []float64
// Scan implements the sql.Scanner interface.
func (a *Float64) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Float64", src)
}
func (a *Float64) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float64")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float64, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Float64) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// Float32 represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float32 []float32
// Scan implements the sql.Scanner interface.
func (a *Float32) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Float32", src)
}
func (a *Float32) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float32")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float32, len(elems))
for i, v := range elems {
var x float64
if x, err = strconv.ParseFloat(string(v), 32); err != nil {
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
}
b[i] = float32(x)
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Float32) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// Generic implements the driver.Valuer and sql.Scanner interfaces for
// an array or slice of any dimension.
type Generic struct{ A any }
func (Generic) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
var assign func([]byte, reflect.Value) error
var del = ","
// TODO calculate the assign function for other types
// TODO repeat this section on the element type of arrays or slices (multidimensional)
{
if reflect.PointerTo(rt).Implements(typeSQLScanner) {
// dest is always addressable because it is an element of a slice.
assign = func(src []byte, dest reflect.Value) (err error) {
ss := dest.Addr().Interface().(sql.Scanner)
if src == nil {
err = ss.Scan(nil)
} else {
err = ss.Scan(src)
}
return
}
goto FoundType
}
assign = func([]byte, reflect.Value) error {
return fmt.Errorf("database: scanning to %s is not implemented; only sql.Scanner", rt)
}
}
FoundType:
if ad, ok := reflect.Zero(rt).Interface().(Delimiter); ok {
del = ad.Delimiter()
}
return rt, assign, del
}
// Scan implements the sql.Scanner interface.
func (a Generic) Scan(src any) error {
dpv := reflect.ValueOf(a.A)
switch {
case dpv.Kind() != reflect.Ptr:
return fmt.Errorf("database: destination %T is not a pointer to array or slice", a.A)
case dpv.IsNil():
return fmt.Errorf("database: destination %T is nil", a.A)
}
dv := dpv.Elem()
switch dv.Kind() {
case reflect.Slice:
case reflect.Array:
default:
return fmt.Errorf("database: destination %T is not a pointer to array or slice", a.A)
}
switch src := src.(type) {
case []byte:
return a.scanBytes(src, dv)
case string:
return a.scanBytes([]byte(src), dv)
case nil:
if dv.Kind() == reflect.Slice {
dv.Set(reflect.Zero(dv.Type()))
return nil
}
}
return fmt.Errorf("database: cannot convert %T to %s", src, dv.Type())
}
func (a Generic) scanBytes(src []byte, dv reflect.Value) error {
dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
dims, elems, err := parseArray(src, []byte(del))
if err != nil {
return err
}
// TODO allow multidimensional
if len(dims) > 1 {
return fmt.Errorf("database: scanning from multidimensional ARRAY%s is not implemented",
strings.Replace(fmt.Sprint(dims), " ", "][", -1))
}
// Treat a zero-dimensional array like an array with a single dimension of zero.
if len(dims) == 0 {
dims = append(dims, 0)
}
for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
switch rt.Kind() {
case reflect.Slice:
case reflect.Array:
if rt.Len() != dims[i] {
return fmt.Errorf("database: cannot convert ARRAY%s to %s",
strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
}
default:
// TODO handle multidimensional
}
}
values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
for i, e := range elems {
if err := assign(e, values.Index(i)); err != nil {
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
}
}
// TODO handle multidimensional
switch dv.Kind() {
case reflect.Slice:
dv.Set(values.Slice(0, dims[0]))
case reflect.Array:
for i := 0; i < dims[0]; i++ {
dv.Index(i).Set(values.Index(i))
}
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Generic) Value() (driver.Value, error) {
if a.A == nil {
return nil, nil
}
rv := reflect.ValueOf(a.A)
switch rv.Kind() {
case reflect.Slice:
if rv.IsNil() {
return nil, nil
}
case reflect.Array:
default:
return nil, fmt.Errorf("database: Unable to convert %T to array", a.A)
}
if n := rv.Len(); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 0, 1+2*n)
b, _, err := appendArray(b, rv, n)
return string(b), err
}
return "{}", nil
}
// Int64 represents a one-dimensional array of the PostgreSQL integer types.
type Int64 []int64
// Scan implements the sql.Scanner interface.
func (a *Int64) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Int64", src)
}
func (a *Int64) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int64")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int64, len(elems))
for i, v := range elems {
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Int64) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendInt(b, a[0], 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, a[i], 10)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// Int32 represents a one-dimensional array of the PostgreSQL integer types.
type Int32 []int32
// Scan implements the sql.Scanner interface.
func (a *Int32) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to Int32", src)
}
func (a *Int32) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int32")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int32, len(elems))
for i, v := range elems {
x, err := strconv.ParseInt(string(v), 10, 32)
if err != nil {
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
}
b[i] = int32(x)
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a Int32) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'
b = strconv.AppendInt(b, int64(a[0]), 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, int64(a[i]), 10)
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// String represents a one-dimensional array of the PostgreSQL character types.
type String []string
// Scan implements the sql.Scanner interface.
func (a *String) Scan(src any) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}
return fmt.Errorf("database: cannot convert %T to String", src)
}
func (a *String) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "String")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(String, len(elems))
for i, v := range elems {
if b[i] = string(v); v == nil {
return fmt.Errorf("database: parsing array element index %d: cannot convert nil to string", i)
}
}
*a = b
}
return nil
}
// Value implements the driver.Valuer interface.
func (a String) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
if n := len(a); n > 0 {
// There will be at least two curly brackets, 2*N bytes of quotes,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+3*n)
b[0] = '{'
b = appendArrayQuotedBytes(b, []byte(a[0]))
for i := 1; i < n; i++ {
b = append(b, ',')
b = appendArrayQuotedBytes(b, []byte(a[i]))
}
return string(append(b, '}')), nil
}
return "{}", nil
}
// appendArray appends rv to the buffer, returning the extended buffer and
// the delimiter used between elements.
//
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
var del string
var err error
b = append(b, '{')
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
return b, del, err
}
for i := 1; i < n; i++ {
b = append(b, del...)
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
return b, del, err
}
}
return append(b, '}'), del, nil
}
// appendArrayElement appends rv to the buffer, returning the extended buffer
// and the delimiter to use before the next element.
//
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
// using driver.DefaultParameterConverter and the resulting []byte or string
// is double-quoted.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
if n := rv.Len(); n > 0 {
return appendArray(b, rv, n)
}
return b, "", nil
}
}
var del = ","
var err error
var iv = rv.Interface()
if ad, ok := iv.(Delimiter); ok {
del = ad.Delimiter()
}
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
return b, del, err
}
switch v := iv.(type) {
case nil:
return append(b, "NULL"...), del, nil
case []byte:
return appendArrayQuotedBytes(b, v), del, nil
case string:
return appendArrayQuotedBytes(b, []byte(v)), del, nil
}
b, err = appendValue(b, iv)
return b, del, err
}
func appendArrayQuotedBytes(b, v []byte) []byte {
b = append(b, '"')
for {
i := bytes.IndexAny(v, `"\`)
if i < 0 {
b = append(b, v...)
break
}
if i > 0 {
b = append(b, v[:i]...)
}
b = append(b, '\\', v[i])
v = v[i+1:]
}
return append(b, '"')
}
func appendValue(b []byte, v driver.Value) ([]byte, error) {
return append(b, encode(nil, v, 0)...), nil
}
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
var depth, i int
if len(src) < 1 || src[0] != '{' {
return nil, nil, fmt.Errorf("database: unable to parse array; expected %q at offset %d", '{', 0)
}
Open:
for i < len(src) {
switch src[i] {
case '{':
depth++
i++
case '}':
elems = make([][]byte, 0)
goto Close
default:
break Open
}
}
dims = make([]int, i)
Element:
for i < len(src) {
switch src[i] {
case '{':
if depth == len(dims) {
break Element
}
depth++
dims[depth-1] = 0
i++
case '"':
var elem = []byte{}
var escape bool
for i++; i < len(src); i++ {
if escape {
elem = append(elem, src[i])
escape = false
} else {
switch src[i] {
default:
elem = append(elem, src[i])
case '\\':
escape = true
case '"':
elems = append(elems, elem)
i++
break Element
}
}
}
default:
for start := i; i < len(src); i++ {
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
elem := src[start:i]
if len(elem) == 0 {
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
}
if bytes.Equal(elem, []byte("NULL")) {
elem = nil
}
elems = append(elems, elem)
break Element
}
}
}
}
for i < len(src) {
if bytes.HasPrefix(src[i:], del) && depth > 0 {
dims[depth-1]++
i += len(del)
goto Element
} else if src[i] == '}' && depth > 0 {
dims[depth-1]++
depth--
i++
} else {
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
Close:
for i < len(src) {
if src[i] == '}' && depth > 0 {
depth--
i++
} else {
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
if depth > 0 {
err = fmt.Errorf("database: unable to parse array; expected %q at offset %d", '}', i)
}
if err == nil {
for _, d := range dims {
if (len(elems) % d) != 0 {
err = fmt.Errorf("database: multidimensional arrays must have elements with matching dimensions")
}
}
}
return
}
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
dims, elems, err := parseArray(src, del)
if err != nil {
return nil, err
}
if len(dims) > 1 {
return nil, fmt.Errorf("database: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
}
return elems, err
}

View File

@@ -0,0 +1,235 @@
/*
Code taken from https://github.com/lib/pq
Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so, subject
to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package dbarray
import (
"bytes"
"encoding/hex"
"fmt"
"strconv"
"time"
)
const (
infinityTSEnabledAlready = "database: infinity timestamp enabled already"
infinityTSNegativeMustBeSmaller = "database: infinity timestamp: negative value must be smaller (before) than positive"
)
var infinityTSEnabled = false
var infinityTSNegative time.Time
var infinityTSPositive time.Time
type parameterStatus struct {
// server version in the same format as server_version_num, or 0 if unavailable.
serverVersion int
}
// EnableInfinityTS controls the handling of Postgres' "-infinity" and
// "infinity" "timestamp"s.
//
// If EnableInfinityTS is not called, "-infinity" and "infinity" will return
// []byte("-infinity") and []byte("infinity") respectively, and potentially
// cause error "sql: Scan error on column index 0: unsupported driver -> Scan
// pair: []uint8 -> *time.Time", when scanning into a time.Time value.
//
// Once EnableInfinityTS has been called, all connections created using this
// driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
// "timestamp with time zone" and "date" types to the predefined minimum and
// maximum times, respectively. When encoding time.Time values, any time which
// equals or precedes the predefined minimum time will be encoded to
// "-infinity". Any values at or past the maximum time will similarly be
// encoded to "infinity".
//
// If EnableInfinityTS is called with negative >= positive, it will panic.
// Calling EnableInfinityTS after a connection has been established results in
// undefined behavior. If EnableInfinityTS is called more than once, it will
// panic.
func EnableInfinityTS(negative time.Time, positive time.Time) {
if infinityTSEnabled {
panic(infinityTSEnabledAlready)
}
if !negative.Before(positive) {
panic(infinityTSNegativeMustBeSmaller)
}
infinityTSEnabled = true
infinityTSNegative = negative
infinityTSPositive = positive
}
func encode(parameterStatus *parameterStatus, x any, oid int) []byte {
const oidBytea = 17
switch v := x.(type) {
case int64:
return strconv.AppendInt(nil, v, 10)
case float64:
return strconv.AppendFloat(nil, v, 'f', -1, 64)
case []byte:
if oid == oidBytea {
return encodeBytea(parameterStatus.serverVersion, v)
}
return v
case string:
if oid == oidBytea {
return encodeBytea(parameterStatus.serverVersion, []byte(v))
}
return []byte(v)
case bool:
return strconv.AppendBool(nil, v)
case time.Time:
return formatTS(v)
default:
errorf("encode: unknown type for %T", v)
}
panic("not reached")
}
// formatTS formats t into a format postgres understands.
func formatTS(t time.Time) []byte {
if infinityTSEnabled {
// t <= -infinity : ! (t > -infinity)
if !t.After(infinityTSNegative) {
return []byte("-infinity")
}
// t >= infinity : ! (!t < infinity)
if !t.Before(infinityTSPositive) {
return []byte("infinity")
}
}
return formatTimestamp(t)
}
// formatTimestamp formats t into Postgres' text format for timestamps.
func formatTimestamp(t time.Time) []byte {
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
// minus sign preferred by Go.
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
bc := false
if t.Year() <= 0 {
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
t = t.AddDate((-t.Year())*2+1, 0, 0)
bc = true
}
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
_, offset := t.Zone()
offset %= 60
if offset != 0 {
// RFC3339Nano already printed the minus sign
if offset < 0 {
offset = -offset
}
b = append(b, ':')
if offset < 10 {
b = append(b, '0')
}
b = strconv.AppendInt(b, int64(offset), 10)
}
if bc {
b = append(b, " BC"...)
}
return b
}
func errorf(s string, args ...any) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
// Parse a bytea value received from the server. Both "hex" and the legacy
// "escape" format are supported.
func parseBytea(s []byte) (result []byte, err error) {
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
// bytea_output = hex
s = s[2:] // trim off leading "\\x"
result = make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(result, s)
if err != nil {
return nil, err
}
} else {
// bytea_output = escape
for len(s) > 0 {
if s[0] == '\\' {
// escaped '\\'
if len(s) >= 2 && s[1] == '\\' {
result = append(result, '\\')
s = s[2:]
continue
}
// '\\' followed by an octal number
if len(s) < 4 {
return nil, fmt.Errorf("invalid bytea sequence %v", s)
}
r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
if err != nil {
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
}
result = append(result, byte(r))
s = s[4:]
} else {
// We hit an unescaped, raw byte. Try to read in as many as
// possible in one go.
i := bytes.IndexByte(s, '\\')
if i == -1 {
result = append(result, s...)
break
}
result = append(result, s[:i]...)
s = s[i:]
}
}
}
return result, nil
}
func encodeBytea(serverVersion int, v []byte) (result []byte) {
if serverVersion >= 90000 {
// Use the hex format if we know that the server supports it
result = make([]byte, 2+hex.EncodedLen(len(v)))
result[0] = '\\'
result[1] = 'x'
hex.Encode(result[2:], v)
} else {
// .. or resort to "escape"
for _, b := range v {
if b == '\\' {
result = append(result, '\\', '\\')
} else if b < 0x20 || b > 0x7e {
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
} else {
result = append(result, b)
}
}
}
return result
}

View File

@@ -1,3 +1,4 @@
// Package sqldb provides support for access the database.
package sqldb
import (
@@ -17,17 +18,64 @@ import (
"go.uber.org/zap"
)
// lib/pq errorCodeNames
// https://github.com/lib/pq/blob/master/error.go#L178
const (
uniqueViolation = "23505"
undefinedTable = "42P01"
)
// Set of error variables for CRUD operations.
var (
ErrDBNotFound = sql.ErrNoRows
ErrDBDuplicatedEntry = errors.New("duplicated entry")
ErrUndefinedTable = errors.New("undefined table")
)
// Config is the required properties to use the database.
//type Config struct {
// User string
// Password string
// Host string
// Name string
// Schema string
// MaxIdleConns int
// MaxOpenConns int
// DisableTLS bool
//}
// Open knows how to open a database connection based on the configuration.
//func Open(cfg Config) (*sqlx.DB, error) {
// sslMode := "require"
// if cfg.DisableTLS {
// sslMode = "disable"
// }
//
// q := make(url.Values)
// q.Set("sslmode", sslMode)
// q.Set("timezone", "utc")
// if cfg.Schema != "" {
// q.Set("search_path", cfg.Schema)
// }
//
// u := url.URL{
// Scheme: "postgres",
// User: url.UserPassword(cfg.User, cfg.Password),
// Host: cfg.Host,
// Path: cfg.Name,
// RawQuery: q.Encode(),
// }
//
// db, err := sqlx.Open("pgx", u.String())
// if err != nil {
// return nil, err
// }
// db.SetMaxIdleConns(cfg.MaxIdleConns)
// db.SetMaxOpenConns(cfg.MaxOpenConns)
//
// return db, nil
//}
type Config struct {
User string
Password string
@@ -82,26 +130,66 @@ func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error)
return db, cleanup, nil
}
// StatusCheck returns nil if it can successfully talk to the database. It
// returns a non-nil error otherwise.
func StatusCheck(ctx context.Context, db *sqlx.DB) error {
// If the user doesn't give us a deadline set 1 second.
if _, ok := ctx.Deadline(); !ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Second)
defer cancel()
}
for attempts := 1; ; attempts++ {
if err := db.Ping(); err == nil {
break
}
time.Sleep(time.Duration(attempts) * 100 * time.Millisecond)
if ctx.Err() != nil {
return ctx.Err()
}
}
if ctx.Err() != nil {
return ctx.Err()
}
// Run a simple query to determine connectivity.
// Running this query forces a round trip through the database.
const q = `SELECT TRUE`
var tmp bool
return db.QueryRowContext(ctx, q).Scan(&tmp)
}
// ExecContext is a helper function to execute a CUD operation with
// logging and tracing.
func ExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string) error {
return NamedExecContext(ctx, log, db, query, struct{}{})
}
// NamedExecContext is a helper function to execute a CUD operation with
// logging and tracing where field replacement is necessary.
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
q := queryString(query, data)
defer func() {
if err != nil {
switch data.(type) {
case struct{}:
log.Error("database.NamedExecContext (data is struct)", err,
zap.String("query", query),
zap.Any("ERROR", err))
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
default:
log.Error("database.NamedExecContext", err,
zap.String("query", query),
zap.Any("ERROR", err))
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 5), zap.Error(err))
}
}
}()
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
var pgError *pgconn.PgError
if errors.As(err, &pgError) {
switch pgError.Code {
var pqerr *pgconn.PgError
if errors.As(err, &pqerr) {
switch pqerr.Code {
case undefinedTable:
return ErrUndefinedTable
case uniqueViolation:
@@ -114,71 +202,73 @@ func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContex
return nil
}
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
q := queryString(query, data)
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
if err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
return ErrUndefinedTable
}
log.Error("NamedQueryStruct NamedQueryContext error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
defer func(rows *sqlx.Rows) {
err := rows.Close()
if err != nil {
log.Error("rows close error", err)
}
}(rows)
if !rows.Next() {
return ErrDBNotFound
}
if err := rows.StructScan(dest); err != nil {
log.Error("NamedQueryStruct StructScan error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
return nil
// QuerySlice is a helper function for executing queries that return a
// collection of data to be unmarshalled into a slice.
func QuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, struct{}{}, dest, false)
}
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
// NamedQuerySlice is a helper function for executing queries that return a
// collection of data to be unmarshalled into a slice where field replacement is
// necessary.
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, data, dest, false)
}
// NamedQuerySliceUsingIn is a helper function for executing queries that return
// a collection of data to be unmarshalled into a slice where field replacement
// is necessary. Use this if the query has an IN clause.
func NamedQuerySliceUsingIn[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
return namedQuerySlice(ctx, log, db, query, data, dest, true)
}
func namedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T, withIn bool) (err error) {
q := queryString(query, data)
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
defer func() {
if err != nil {
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
}
}()
var rows *sqlx.Rows
switch withIn {
case true:
rows, err = func() (*sqlx.Rows, error) {
named, args, err := sqlx.Named(query, data)
if err != nil {
return nil, err
}
query, args, err := sqlx.In(named, args...)
if err != nil {
return nil, err
}
query = db.Rebind(query)
return db.QueryxContext(ctx, query, args...)
}()
default:
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
}
if err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
return ErrUndefinedTable
}
log.Error("NamedQueryStruct NamedQueryContext error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
defer func(rows *sqlx.Rows) {
err := rows.Close()
if err != nil {
log.Error("rows close error", err)
}
_ = rows.Close()
}(rows)
var slice []T
for rows.Next() {
v := new(T)
if err := rows.StructScan(v); err != nil {
log.Error("NamedQuerySlice StructScan error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
slice = append(slice, *v)
@@ -188,6 +278,80 @@ func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.Ext
return nil
}
// QueryStruct is a helper function for executing queries that return a
// single value to be unmarshalled into a struct type where field replacement is necessary.
func QueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest any) error {
return namedQueryStruct(ctx, log, db, query, struct{}{}, dest, false)
}
// NamedQueryStruct is a helper function for executing queries that return a
// single value to be unmarshalled into a struct type where field replacement is necessary.
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
return namedQueryStruct(ctx, log, db, query, data, dest, false)
}
// NamedQueryStructUsingIn is a helper function for executing queries that return
// a single value to be unmarshalled into a struct type where field replacement
// is necessary. Use this if the query has an IN clause.
func NamedQueryStructUsingIn(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
return namedQueryStruct(ctx, log, db, query, data, dest, true)
}
func namedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any, withIn bool) (err error) {
q := queryString(query, data)
defer func() {
if err != nil {
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
}
}()
var rows *sqlx.Rows
switch withIn {
case true:
rows, err = func() (*sqlx.Rows, error) {
named, args, err := sqlx.Named(query, data)
if err != nil {
return nil, err
}
query, args, err := sqlx.In(named, args...)
if err != nil {
return nil, err
}
query = db.Rebind(query)
return db.QueryxContext(ctx, query, args...)
}()
default:
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
}
if err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
return ErrUndefinedTable
}
return err
}
defer func(rows *sqlx.Rows) {
_ = rows.Close()
}(rows)
if !rows.Next() {
return ErrDBNotFound
}
if err := rows.StructScan(dest); err != nil {
return err
}
return nil
}
// queryString provides a pretty print version of the query and parameters.
func queryString(query string, args any) string {
query, params, err := sqlx.Named(query, args)
if err != nil {