update
This commit is contained in:
@@ -1,15 +1,48 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
// import (
|
||||
// "database/sql"
|
||||
// "errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
// "github.com/jackc/pgx/v5/pgconn"
|
||||
// "gorm.io/gorm"
|
||||
// )
|
||||
|
||||
func IsUniqueViolation(err error) bool {
|
||||
return errors.Is(err, gorm.ErrDuplicatedKey)
|
||||
}
|
||||
// func IsGORMUniqueViolation(err error) bool {
|
||||
// return errors.Is(err, gorm.ErrDuplicatedKey)
|
||||
// }
|
||||
|
||||
func IsNoRows(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
// func IsGORMNoRows(err error) bool {
|
||||
// return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
// }
|
||||
|
||||
// // ****************** errors ******************
|
||||
|
||||
// const (
|
||||
// foreignKeyViolation = "23503"
|
||||
// uniqueViolation = "23505"
|
||||
// )
|
||||
|
||||
// // var ErrUniqueViolation = &pgconn.PgError{
|
||||
// // Code: UniqueViolation,
|
||||
// // }
|
||||
|
||||
// func ErrorCode(err error) string {
|
||||
// var pgErr *pgconn.PgError
|
||||
// if errors.As(err, &pgErr) {
|
||||
// return pgErr.Code
|
||||
// }
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// func IsUniqueViolation(err error) bool {
|
||||
// var pgErr *pgconn.PgError
|
||||
// if errors.As(err, &pgErr) {
|
||||
// return pgErr.Code == uniqueViolation
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
|
||||
// func IsNoRows(err error) bool {
|
||||
// return errors.Is(err, sql.ErrNoRows)
|
||||
// }
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/h2non/filetype"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
)
|
||||
@@ -20,82 +21,94 @@ const (
|
||||
|
||||
var ErrUnsupported = errors.New("文件格式不支持")
|
||||
|
||||
type FileType int
|
||||
type Type int
|
||||
|
||||
const (
|
||||
ALL FileType = 0
|
||||
IMG FileType = 1
|
||||
ALL Type = 0
|
||||
IMG Type = 1
|
||||
)
|
||||
|
||||
func UploadFilename(filepath string, t FileType) (string, error) {
|
||||
fileOpen, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer fileOpen.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(fileOpen)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to read file")
|
||||
}
|
||||
//func UploadFilename(filepath string, t Type) (string, error) {
|
||||
// fileOpen, err := os.Open(filepath)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// defer func(fileOpen *os.File) {
|
||||
// _ = fileOpen.Close()
|
||||
// }(fileOpen)
|
||||
//
|
||||
// fileBytes, err := io.ReadAll(fileOpen)
|
||||
// if err != nil {
|
||||
// return "", errors.New("failed to read file")
|
||||
// }
|
||||
//
|
||||
// if t == IMG {
|
||||
// // 判断是不是图片
|
||||
// if !filetype.IsImage(fileBytes) {
|
||||
// return "", ErrUnsupported
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// kind, err := filetype.Match(fileBytes)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
//
|
||||
// if kind == filetype.Unknown {
|
||||
// return "", ErrUnsupported
|
||||
// }
|
||||
//
|
||||
// // 使用 filetype 判断类型后已经去读了一些bytes了
|
||||
// // 要恢复文件读取位置
|
||||
// _, err = fileOpen.Seek(0, io.SeekStart)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
//
|
||||
// dir := GetPath()
|
||||
// exist, _ := Exists(dir)
|
||||
// if !exist {
|
||||
// if err := Mkdir(dir); err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// filename := GenFilename(kind.Extension)
|
||||
// fullPath := path.Join(dir, filename)
|
||||
// f, err := os.Create(fullPath)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// defer func(f *os.File) {
|
||||
// _ = f.Close()
|
||||
// }(f)
|
||||
//
|
||||
// _, err = io.Copy(f, fileOpen)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
//
|
||||
// return "/" + fullPath, nil
|
||||
//}
|
||||
|
||||
func UploadFile(file *multipart.FileHeader, t Type) (string, error) {
|
||||
if t == IMG {
|
||||
// 判断是不是图片
|
||||
if !filetype.IsImage(fileBytes) {
|
||||
return "", ErrUnsupported
|
||||
if file.Size > MaxImageSize {
|
||||
return "", errors.New("failed to receive images too large")
|
||||
}
|
||||
}
|
||||
|
||||
kind, err := filetype.Match(fileBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if kind == filetype.Unknown {
|
||||
return "", ErrUnsupported
|
||||
}
|
||||
|
||||
// 使用 filetype 判断类型后已经去读了一些bytes了
|
||||
// 要恢复文件读取位置
|
||||
_, err = fileOpen.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dir := GetPath()
|
||||
exist, _ := Exists(dir)
|
||||
if !exist {
|
||||
if err := Mkdir(dir); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
if file.Size > MaxFileSize {
|
||||
return "", errors.New("failed to receive file too large")
|
||||
}
|
||||
}
|
||||
|
||||
filename := GenFilename(kind.Extension)
|
||||
path := path.Join(dir, filename)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, fileOpen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "/" + path, nil
|
||||
}
|
||||
|
||||
func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
|
||||
if file.Size > MaxFileSize {
|
||||
return "", errors.New("failed to receive file too large")
|
||||
}
|
||||
|
||||
fileOpen, err := file.Open()
|
||||
if err != nil {
|
||||
return "", errors.New("fialed to open file")
|
||||
return "", errors.New("failed to open file")
|
||||
}
|
||||
defer fileOpen.Close()
|
||||
defer func(fileOpen multipart.File) {
|
||||
_ = fileOpen.Close()
|
||||
}(fileOpen)
|
||||
|
||||
fileBytes, err := io.ReadAll(fileOpen)
|
||||
if err != nil {
|
||||
@@ -134,19 +147,21 @@ func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
|
||||
}
|
||||
|
||||
filename := GenFilename(kind.Extension)
|
||||
path := path.Join(dir, filename)
|
||||
f, err := os.Create(path)
|
||||
fullPath := path.Join(dir, filename)
|
||||
f, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func(f *os.File) {
|
||||
_ = f.Close()
|
||||
}(f)
|
||||
|
||||
_, err = io.Copy(f, fileOpen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "/" + path, nil
|
||||
return "/" + fullPath, nil
|
||||
}
|
||||
|
||||
func GetPath() string {
|
||||
@@ -154,6 +169,9 @@ func GetPath() string {
|
||||
}
|
||||
|
||||
func GenFilename(ext string) string {
|
||||
id, _ := gonanoid.New()
|
||||
id, err := gonanoid.New()
|
||||
if err != nil {
|
||||
return uuid.New().String()
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", id, ext)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor)
|
||||
opts := []asynq.Option{
|
||||
asynq.MaxRetry(10),
|
||||
asynq.ProcessIn(1 * time.Second),
|
||||
asynq.Queue(tasks.QueueCritical),
|
||||
asynq.Queue(tasks.QueueDefault),
|
||||
}
|
||||
|
||||
c, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
|
||||
920
internal/pkg/sqldb/dbarray/dbarray.go
Normal file
920
internal/pkg/sqldb/dbarray/dbarray.go
Normal file
@@ -0,0 +1,920 @@
|
||||
/*
|
||||
Code taken from https://github.com/lib/pq
|
||||
|
||||
Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to use,
|
||||
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
|
||||
Software, and to permit persons to whom the Software is furnished to do so, subject
|
||||
to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
// Package dbarray provides support for database array types.
|
||||
package dbarray
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var typeByteSlice = reflect.TypeOf([]byte{})
|
||||
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
|
||||
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
|
||||
// slice of any dimension.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
|
||||
//
|
||||
// var x []sql.NullInt64
|
||||
// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x))
|
||||
//
|
||||
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
|
||||
// bound is not one (such as `[0:0]={1}') are not supported.
|
||||
func Array(a any) interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
} {
|
||||
switch a := a.(type) {
|
||||
case []bool:
|
||||
return (*Bool)(&a)
|
||||
case []float64:
|
||||
return (*Float64)(&a)
|
||||
case []float32:
|
||||
return (*Float32)(&a)
|
||||
case []int64:
|
||||
return (*Int64)(&a)
|
||||
case []int32:
|
||||
return (*Int32)(&a)
|
||||
case []string:
|
||||
return (*String)(&a)
|
||||
case [][]byte:
|
||||
return (*Bytea)(&a)
|
||||
|
||||
case *[]bool:
|
||||
return (*Bool)(a)
|
||||
case *[]float64:
|
||||
return (*Float64)(a)
|
||||
case *[]float32:
|
||||
return (*Float32)(a)
|
||||
case *[]int64:
|
||||
return (*Int64)(a)
|
||||
case *[]int32:
|
||||
return (*Int32)(a)
|
||||
case *[]string:
|
||||
return (*String)(a)
|
||||
case *[][]byte:
|
||||
return (*Bytea)(a)
|
||||
}
|
||||
|
||||
return Generic{a}
|
||||
}
|
||||
|
||||
// Delimiter may be optionally implemented by driver.Valuer or sql.Scanner
|
||||
// to override the array delimiter used by Generic.
|
||||
type Delimiter interface {
|
||||
// Delimiter returns the delimiter character(s) for this element's type.
|
||||
Delimiter() string
|
||||
}
|
||||
|
||||
// Bool represents a one-dimensional array of the PostgreSQL boolean type.
|
||||
type Bool []bool
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Bool) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Bool", src)
|
||||
}
|
||||
|
||||
func (a *Bool) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Bool")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Bool, len(elems))
|
||||
for i, v := range elems {
|
||||
if len(v) != 1 {
|
||||
return fmt.Errorf("database: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
switch v[0] {
|
||||
case 't':
|
||||
b[i] = true
|
||||
case 'f':
|
||||
b[i] = false
|
||||
default:
|
||||
return fmt.Errorf("database: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Bool) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be exactly two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1+2*n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
b[2*i] = ','
|
||||
if a[i] {
|
||||
b[1+2*i] = 't'
|
||||
} else {
|
||||
b[1+2*i] = 'f'
|
||||
}
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[2*n] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Bytea represents a one-dimensional array of the PostgreSQL bytea type.
|
||||
type Bytea [][]byte
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Bytea) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Bytea", src)
|
||||
}
|
||||
|
||||
func (a *Bytea) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Bytea")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Bytea, len(elems))
|
||||
for i, v := range elems {
|
||||
b[i], err = parseBytea(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface. It uses the "hex" format which
|
||||
// is only supported on PostgreSQL 9.0 or newer.
|
||||
func (a Bytea) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
|
||||
size := 1 + 6*n
|
||||
for _, x := range a {
|
||||
size += hex.EncodedLen(len(x))
|
||||
}
|
||||
|
||||
b := make([]byte, size)
|
||||
|
||||
for i, s := 0, b; i < n; i++ {
|
||||
o := copy(s, `,"\\x`)
|
||||
o += hex.Encode(s[o:], a[i])
|
||||
s[o] = '"'
|
||||
s = s[o+1:]
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[size-1] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Float64 represents a one-dimensional array of the PostgreSQL double
|
||||
// precision type.
|
||||
type Float64 []float64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Float64) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Float64", src)
|
||||
}
|
||||
|
||||
func (a *Float64) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Float64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Float64, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Float64) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Float32 represents a one-dimensional array of the PostgreSQL double
|
||||
// precision type.
|
||||
type Float32 []float32
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Float32) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Float32", src)
|
||||
}
|
||||
|
||||
func (a *Float32) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Float32")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Float32, len(elems))
|
||||
for i, v := range elems {
|
||||
var x float64
|
||||
if x, err = strconv.ParseFloat(string(v), 32); err != nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
b[i] = float32(x)
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Float32) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Generic implements the driver.Valuer and sql.Scanner interfaces for
|
||||
// an array or slice of any dimension.
|
||||
type Generic struct{ A any }
|
||||
|
||||
func (Generic) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
|
||||
var assign func([]byte, reflect.Value) error
|
||||
var del = ","
|
||||
|
||||
// TODO calculate the assign function for other types
|
||||
// TODO repeat this section on the element type of arrays or slices (multidimensional)
|
||||
{
|
||||
if reflect.PointerTo(rt).Implements(typeSQLScanner) {
|
||||
// dest is always addressable because it is an element of a slice.
|
||||
assign = func(src []byte, dest reflect.Value) (err error) {
|
||||
ss := dest.Addr().Interface().(sql.Scanner)
|
||||
if src == nil {
|
||||
err = ss.Scan(nil)
|
||||
} else {
|
||||
err = ss.Scan(src)
|
||||
}
|
||||
return
|
||||
}
|
||||
goto FoundType
|
||||
}
|
||||
|
||||
assign = func([]byte, reflect.Value) error {
|
||||
return fmt.Errorf("database: scanning to %s is not implemented; only sql.Scanner", rt)
|
||||
}
|
||||
}
|
||||
|
||||
FoundType:
|
||||
|
||||
if ad, ok := reflect.Zero(rt).Interface().(Delimiter); ok {
|
||||
del = ad.Delimiter()
|
||||
}
|
||||
|
||||
return rt, assign, del
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a Generic) Scan(src any) error {
|
||||
dpv := reflect.ValueOf(a.A)
|
||||
switch {
|
||||
case dpv.Kind() != reflect.Ptr:
|
||||
return fmt.Errorf("database: destination %T is not a pointer to array or slice", a.A)
|
||||
case dpv.IsNil():
|
||||
return fmt.Errorf("database: destination %T is nil", a.A)
|
||||
}
|
||||
|
||||
dv := dpv.Elem()
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
default:
|
||||
return fmt.Errorf("database: destination %T is not a pointer to array or slice", a.A)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src, dv)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src), dv)
|
||||
case nil:
|
||||
if dv.Kind() == reflect.Slice {
|
||||
dv.Set(reflect.Zero(dv.Type()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to %s", src, dv.Type())
|
||||
}
|
||||
|
||||
func (a Generic) scanBytes(src []byte, dv reflect.Value) error {
|
||||
dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
|
||||
dims, elems, err := parseArray(src, []byte(del))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO allow multidimensional
|
||||
|
||||
if len(dims) > 1 {
|
||||
return fmt.Errorf("database: scanning from multidimensional ARRAY%s is not implemented",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1))
|
||||
}
|
||||
|
||||
// Treat a zero-dimensional array like an array with a single dimension of zero.
|
||||
if len(dims) == 0 {
|
||||
dims = append(dims, 0)
|
||||
}
|
||||
|
||||
for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
|
||||
switch rt.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
if rt.Len() != dims[i] {
|
||||
return fmt.Errorf("database: cannot convert ARRAY%s to %s",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
|
||||
}
|
||||
default:
|
||||
// TODO handle multidimensional
|
||||
}
|
||||
}
|
||||
|
||||
values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
|
||||
for i, e := range elems {
|
||||
if err := assign(e, values.Index(i)); err != nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO handle multidimensional
|
||||
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
dv.Set(values.Slice(0, dims[0]))
|
||||
case reflect.Array:
|
||||
for i := 0; i < dims[0]; i++ {
|
||||
dv.Index(i).Set(values.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Generic) Value() (driver.Value, error) {
|
||||
if a.A == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(a.A)
|
||||
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice:
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
case reflect.Array:
|
||||
default:
|
||||
return nil, fmt.Errorf("database: Unable to convert %T to array", a.A)
|
||||
}
|
||||
|
||||
if n := rv.Len(); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 0, 1+2*n)
|
||||
|
||||
b, _, err := appendArray(b, rv, n)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Int64 represents a one-dimensional array of the PostgreSQL integer types.
|
||||
type Int64 []int64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Int64) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Int64", src)
|
||||
}
|
||||
|
||||
func (a *Int64) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Int64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Int64, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Int64) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendInt(b, a[0], 10)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendInt(b, a[i], 10)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Int32 represents a one-dimensional array of the PostgreSQL integer types.
|
||||
type Int32 []int32
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Int32) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to Int32", src)
|
||||
}
|
||||
|
||||
func (a *Int32) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Int32")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Int32, len(elems))
|
||||
for i, v := range elems {
|
||||
x, err := strconv.ParseInt(string(v), 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
b[i] = int32(x)
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Int32) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendInt(b, int64(a[0]), 10)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendInt(b, int64(a[i]), 10)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// String represents a one-dimensional array of the PostgreSQL character types.
|
||||
type String []string
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *String) Scan(src any) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("database: cannot convert %T to String", src)
|
||||
}
|
||||
|
||||
func (a *String) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "String")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(String, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i] = string(v); v == nil {
|
||||
return fmt.Errorf("database: parsing array element index %d: cannot convert nil to string", i)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a String) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+3*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = appendArrayQuotedBytes(b, []byte(a[0]))
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = appendArrayQuotedBytes(b, []byte(a[i]))
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// appendArray appends rv to the buffer, returning the extended buffer and
|
||||
// the delimiter used between elements.
|
||||
//
|
||||
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
|
||||
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
|
||||
var del string
|
||||
var err error
|
||||
|
||||
b = append(b, '{')
|
||||
|
||||
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, del...)
|
||||
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
}
|
||||
|
||||
return append(b, '}'), del, nil
|
||||
}
|
||||
|
||||
// appendArrayElement appends rv to the buffer, returning the extended buffer
|
||||
// and the delimiter to use before the next element.
|
||||
//
|
||||
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
|
||||
// using driver.DefaultParameterConverter and the resulting []byte or string
|
||||
// is double-quoted.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
|
||||
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
|
||||
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
|
||||
if n := rv.Len(); n > 0 {
|
||||
return appendArray(b, rv, n)
|
||||
}
|
||||
|
||||
return b, "", nil
|
||||
}
|
||||
}
|
||||
|
||||
var del = ","
|
||||
var err error
|
||||
var iv = rv.Interface()
|
||||
|
||||
if ad, ok := iv.(Delimiter); ok {
|
||||
del = ad.Delimiter()
|
||||
}
|
||||
|
||||
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
switch v := iv.(type) {
|
||||
case nil:
|
||||
return append(b, "NULL"...), del, nil
|
||||
case []byte:
|
||||
return appendArrayQuotedBytes(b, v), del, nil
|
||||
case string:
|
||||
return appendArrayQuotedBytes(b, []byte(v)), del, nil
|
||||
}
|
||||
|
||||
b, err = appendValue(b, iv)
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
func appendArrayQuotedBytes(b, v []byte) []byte {
|
||||
b = append(b, '"')
|
||||
for {
|
||||
i := bytes.IndexAny(v, `"\`)
|
||||
if i < 0 {
|
||||
b = append(b, v...)
|
||||
break
|
||||
}
|
||||
if i > 0 {
|
||||
b = append(b, v[:i]...)
|
||||
}
|
||||
b = append(b, '\\', v[i])
|
||||
v = v[i+1:]
|
||||
}
|
||||
return append(b, '"')
|
||||
}
|
||||
|
||||
func appendValue(b []byte, v driver.Value) ([]byte, error) {
|
||||
return append(b, encode(nil, v, 0)...), nil
|
||||
}
|
||||
|
||||
// parseArray extracts the dimensions and elements of an array represented in
|
||||
// text format. Only representations emitted by the backend are supported.
|
||||
// Notably, whitespace around brackets and delimiters is significant, and NULL
|
||||
// is case-sensitive.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
|
||||
var depth, i int
|
||||
|
||||
if len(src) < 1 || src[0] != '{' {
|
||||
return nil, nil, fmt.Errorf("database: unable to parse array; expected %q at offset %d", '{', 0)
|
||||
}
|
||||
|
||||
Open:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
depth++
|
||||
i++
|
||||
case '}':
|
||||
elems = make([][]byte, 0)
|
||||
goto Close
|
||||
default:
|
||||
break Open
|
||||
}
|
||||
}
|
||||
dims = make([]int, i)
|
||||
|
||||
Element:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
if depth == len(dims) {
|
||||
break Element
|
||||
}
|
||||
depth++
|
||||
dims[depth-1] = 0
|
||||
i++
|
||||
case '"':
|
||||
var elem = []byte{}
|
||||
var escape bool
|
||||
for i++; i < len(src); i++ {
|
||||
if escape {
|
||||
elem = append(elem, src[i])
|
||||
escape = false
|
||||
} else {
|
||||
switch src[i] {
|
||||
default:
|
||||
elem = append(elem, src[i])
|
||||
case '\\':
|
||||
escape = true
|
||||
case '"':
|
||||
elems = append(elems, elem)
|
||||
i++
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
for start := i; i < len(src); i++ {
|
||||
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
|
||||
elem := src[start:i]
|
||||
if len(elem) == 0 {
|
||||
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
if bytes.Equal(elem, []byte("NULL")) {
|
||||
elem = nil
|
||||
}
|
||||
elems = append(elems, elem)
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i < len(src) {
|
||||
if bytes.HasPrefix(src[i:], del) && depth > 0 {
|
||||
dims[depth-1]++
|
||||
i += len(del)
|
||||
goto Element
|
||||
} else if src[i] == '}' && depth > 0 {
|
||||
dims[depth-1]++
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
|
||||
Close:
|
||||
for i < len(src) {
|
||||
if src[i] == '}' && depth > 0 {
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("database: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
if depth > 0 {
|
||||
err = fmt.Errorf("database: unable to parse array; expected %q at offset %d", '}', i)
|
||||
}
|
||||
if err == nil {
|
||||
for _, d := range dims {
|
||||
if (len(elems) % d) != 0 {
|
||||
err = fmt.Errorf("database: multidimensional arrays must have elements with matching dimensions")
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
|
||||
dims, elems, err := parseArray(src, del)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) > 1 {
|
||||
return nil, fmt.Errorf("database: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
|
||||
}
|
||||
return elems, err
|
||||
}
|
||||
235
internal/pkg/sqldb/dbarray/encode.go
Normal file
235
internal/pkg/sqldb/dbarray/encode.go
Normal file
@@ -0,0 +1,235 @@
|
||||
/*
|
||||
Code taken from https://github.com/lib/pq
|
||||
|
||||
Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to use,
|
||||
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
|
||||
Software, and to permit persons to whom the Software is furnished to do so, subject
|
||||
to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package dbarray
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
infinityTSEnabledAlready = "database: infinity timestamp enabled already"
|
||||
infinityTSNegativeMustBeSmaller = "database: infinity timestamp: negative value must be smaller (before) than positive"
|
||||
)
|
||||
|
||||
var infinityTSEnabled = false
|
||||
var infinityTSNegative time.Time
|
||||
var infinityTSPositive time.Time
|
||||
|
||||
type parameterStatus struct {
|
||||
// server version in the same format as server_version_num, or 0 if unavailable.
|
||||
serverVersion int
|
||||
}
|
||||
|
||||
// EnableInfinityTS controls the handling of Postgres' "-infinity" and
|
||||
// "infinity" "timestamp"s.
|
||||
//
|
||||
// If EnableInfinityTS is not called, "-infinity" and "infinity" will return
|
||||
// []byte("-infinity") and []byte("infinity") respectively, and potentially
|
||||
// cause error "sql: Scan error on column index 0: unsupported driver -> Scan
|
||||
// pair: []uint8 -> *time.Time", when scanning into a time.Time value.
|
||||
//
|
||||
// Once EnableInfinityTS has been called, all connections created using this
|
||||
// driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
|
||||
// "timestamp with time zone" and "date" types to the predefined minimum and
|
||||
// maximum times, respectively. When encoding time.Time values, any time which
|
||||
// equals or precedes the predefined minimum time will be encoded to
|
||||
// "-infinity". Any values at or past the maximum time will similarly be
|
||||
// encoded to "infinity".
|
||||
//
|
||||
// If EnableInfinityTS is called with negative >= positive, it will panic.
|
||||
// Calling EnableInfinityTS after a connection has been established results in
|
||||
// undefined behavior. If EnableInfinityTS is called more than once, it will
|
||||
// panic.
|
||||
func EnableInfinityTS(negative time.Time, positive time.Time) {
|
||||
if infinityTSEnabled {
|
||||
panic(infinityTSEnabledAlready)
|
||||
}
|
||||
if !negative.Before(positive) {
|
||||
panic(infinityTSNegativeMustBeSmaller)
|
||||
}
|
||||
infinityTSEnabled = true
|
||||
infinityTSNegative = negative
|
||||
infinityTSPositive = positive
|
||||
}
|
||||
|
||||
func encode(parameterStatus *parameterStatus, x any, oid int) []byte {
|
||||
const oidBytea = 17
|
||||
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
return strconv.AppendFloat(nil, v, 'f', -1, 64)
|
||||
case []byte:
|
||||
if oid == oidBytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, v)
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if oid == oidBytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, []byte(v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return strconv.AppendBool(nil, v)
|
||||
case time.Time:
|
||||
return formatTS(v)
|
||||
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
// formatTS formats t into a format postgres understands.
|
||||
func formatTS(t time.Time) []byte {
|
||||
if infinityTSEnabled {
|
||||
// t <= -infinity : ! (t > -infinity)
|
||||
if !t.After(infinityTSNegative) {
|
||||
return []byte("-infinity")
|
||||
}
|
||||
// t >= infinity : ! (!t < infinity)
|
||||
if !t.Before(infinityTSPositive) {
|
||||
return []byte("infinity")
|
||||
}
|
||||
}
|
||||
return formatTimestamp(t)
|
||||
}
|
||||
|
||||
// formatTimestamp formats t into Postgres' text format for timestamps.
|
||||
func formatTimestamp(t time.Time) []byte {
|
||||
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
|
||||
// minus sign preferred by Go.
|
||||
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
|
||||
bc := false
|
||||
if t.Year() <= 0 {
|
||||
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
|
||||
t = t.AddDate((-t.Year())*2+1, 0, 0)
|
||||
bc = true
|
||||
}
|
||||
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
|
||||
|
||||
_, offset := t.Zone()
|
||||
offset %= 60
|
||||
if offset != 0 {
|
||||
// RFC3339Nano already printed the minus sign
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
b = append(b, ':')
|
||||
if offset < 10 {
|
||||
b = append(b, '0')
|
||||
}
|
||||
b = strconv.AppendInt(b, int64(offset), 10)
|
||||
}
|
||||
|
||||
if bc {
|
||||
b = append(b, " BC"...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func errorf(s string, args ...any) {
|
||||
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
|
||||
}
|
||||
|
||||
// Parse a bytea value received from the server. Both "hex" and the legacy
|
||||
// "escape" format are supported.
|
||||
func parseBytea(s []byte) (result []byte, err error) {
|
||||
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
|
||||
// bytea_output = hex
|
||||
s = s[2:] // trim off leading "\\x"
|
||||
result = make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(result, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// bytea_output = escape
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\\' {
|
||||
// escaped '\\'
|
||||
if len(s) >= 2 && s[1] == '\\' {
|
||||
result = append(result, '\\')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
|
||||
// '\\' followed by an octal number
|
||||
if len(s) < 4 {
|
||||
return nil, fmt.Errorf("invalid bytea sequence %v", s)
|
||||
}
|
||||
r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
|
||||
}
|
||||
result = append(result, byte(r))
|
||||
s = s[4:]
|
||||
} else {
|
||||
// We hit an unescaped, raw byte. Try to read in as many as
|
||||
// possible in one go.
|
||||
i := bytes.IndexByte(s, '\\')
|
||||
if i == -1 {
|
||||
result = append(result, s...)
|
||||
break
|
||||
}
|
||||
result = append(result, s[:i]...)
|
||||
s = s[i:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func encodeBytea(serverVersion int, v []byte) (result []byte) {
|
||||
if serverVersion >= 90000 {
|
||||
// Use the hex format if we know that the server supports it
|
||||
result = make([]byte, 2+hex.EncodedLen(len(v)))
|
||||
result[0] = '\\'
|
||||
result[1] = 'x'
|
||||
hex.Encode(result[2:], v)
|
||||
} else {
|
||||
// .. or resort to "escape"
|
||||
for _, b := range v {
|
||||
if b == '\\' {
|
||||
result = append(result, '\\', '\\')
|
||||
} else if b < 0x20 || b > 0x7e {
|
||||
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
|
||||
} else {
|
||||
result = append(result, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package sqldb provides support for access the database.
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
@@ -17,17 +18,64 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// lib/pq errorCodeNames
|
||||
// https://github.com/lib/pq/blob/master/error.go#L178
|
||||
const (
|
||||
uniqueViolation = "23505"
|
||||
undefinedTable = "42P01"
|
||||
)
|
||||
|
||||
// Set of error variables for CRUD operations.
|
||||
var (
|
||||
ErrDBNotFound = sql.ErrNoRows
|
||||
ErrDBDuplicatedEntry = errors.New("duplicated entry")
|
||||
ErrUndefinedTable = errors.New("undefined table")
|
||||
)
|
||||
|
||||
// Config is the required properties to use the database.
|
||||
//type Config struct {
|
||||
// User string
|
||||
// Password string
|
||||
// Host string
|
||||
// Name string
|
||||
// Schema string
|
||||
// MaxIdleConns int
|
||||
// MaxOpenConns int
|
||||
// DisableTLS bool
|
||||
//}
|
||||
|
||||
// Open knows how to open a database connection based on the configuration.
|
||||
//func Open(cfg Config) (*sqlx.DB, error) {
|
||||
// sslMode := "require"
|
||||
// if cfg.DisableTLS {
|
||||
// sslMode = "disable"
|
||||
// }
|
||||
//
|
||||
// q := make(url.Values)
|
||||
// q.Set("sslmode", sslMode)
|
||||
// q.Set("timezone", "utc")
|
||||
// if cfg.Schema != "" {
|
||||
// q.Set("search_path", cfg.Schema)
|
||||
// }
|
||||
//
|
||||
// u := url.URL{
|
||||
// Scheme: "postgres",
|
||||
// User: url.UserPassword(cfg.User, cfg.Password),
|
||||
// Host: cfg.Host,
|
||||
// Path: cfg.Name,
|
||||
// RawQuery: q.Encode(),
|
||||
// }
|
||||
//
|
||||
// db, err := sqlx.Open("pgx", u.String())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
// db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
//
|
||||
// return db, nil
|
||||
//}
|
||||
|
||||
type Config struct {
|
||||
User string
|
||||
Password string
|
||||
@@ -82,26 +130,66 @@ func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error)
|
||||
return db, cleanup, nil
|
||||
}
|
||||
|
||||
// StatusCheck returns nil if it can successfully talk to the database. It
|
||||
// returns a non-nil error otherwise.
|
||||
func StatusCheck(ctx context.Context, db *sqlx.DB) error {
|
||||
|
||||
// If the user doesn't give us a deadline set 1 second.
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
for attempts := 1; ; attempts++ {
|
||||
if err := db.Ping(); err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(attempts) * 100 * time.Millisecond)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Run a simple query to determine connectivity.
|
||||
// Running this query forces a round trip through the database.
|
||||
const q = `SELECT TRUE`
|
||||
var tmp bool
|
||||
return db.QueryRowContext(ctx, q).Scan(&tmp)
|
||||
}
|
||||
|
||||
// ExecContext is a helper function to execute a CUD operation with
|
||||
// logging and tracing.
|
||||
func ExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string) error {
|
||||
return NamedExecContext(ctx, log, db, query, struct{}{})
|
||||
}
|
||||
|
||||
// NamedExecContext is a helper function to execute a CUD operation with
|
||||
// logging and tracing where field replacement is necessary.
|
||||
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
|
||||
q := queryString(query, data)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
switch data.(type) {
|
||||
case struct{}:
|
||||
log.Error("database.NamedExecContext (data is struct)", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
||||
default:
|
||||
log.Error("database.NamedExecContext", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
log.Info("database.NamedExecContext", zap.String("query", q), zap.Int("type", 5), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
|
||||
var pgError *pgconn.PgError
|
||||
if errors.As(err, &pgError) {
|
||||
switch pgError.Code {
|
||||
var pqerr *pgconn.PgError
|
||||
if errors.As(err, &pqerr) {
|
||||
switch pqerr.Code {
|
||||
case undefinedTable:
|
||||
return ErrUndefinedTable
|
||||
case uniqueViolation:
|
||||
@@ -114,71 +202,73 @@ func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContex
|
||||
return nil
|
||||
}
|
||||
|
||||
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
}(rows)
|
||||
|
||||
if !rows.Next() {
|
||||
return ErrDBNotFound
|
||||
}
|
||||
|
||||
if err := rows.StructScan(dest); err != nil {
|
||||
log.Error("NamedQueryStruct StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// QuerySlice is a helper function for executing queries that return a
|
||||
// collection of data to be unmarshalled into a slice.
|
||||
func QuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest *[]T) error {
|
||||
return namedQuerySlice(ctx, log, db, query, struct{}{}, dest, false)
|
||||
}
|
||||
|
||||
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
|
||||
// NamedQuerySlice is a helper function for executing queries that return a
|
||||
// collection of data to be unmarshalled into a slice where field replacement is
|
||||
// necessary.
|
||||
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
|
||||
return namedQuerySlice(ctx, log, db, query, data, dest, false)
|
||||
}
|
||||
|
||||
// NamedQuerySliceUsingIn is a helper function for executing queries that return
|
||||
// a collection of data to be unmarshalled into a slice where field replacement
|
||||
// is necessary. Use this if the query has an IN clause.
|
||||
func NamedQuerySliceUsingIn[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) error {
|
||||
return namedQuerySlice(ctx, log, db, query, data, dest, true)
|
||||
}
|
||||
|
||||
func namedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T, withIn bool) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
var rows *sqlx.Rows
|
||||
|
||||
switch withIn {
|
||||
case true:
|
||||
rows, err = func() (*sqlx.Rows, error) {
|
||||
named, args, err := sqlx.Named(query, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, args, err := sqlx.In(named, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query = db.Rebind(query)
|
||||
return db.QueryxContext(ctx, query, args...)
|
||||
}()
|
||||
|
||||
default:
|
||||
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var slice []T
|
||||
for rows.Next() {
|
||||
v := new(T)
|
||||
if err := rows.StructScan(v); err != nil {
|
||||
log.Error("NamedQuerySlice StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
slice = append(slice, *v)
|
||||
@@ -188,6 +278,80 @@ func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.Ext
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryStruct is a helper function for executing queries that return a
|
||||
// single value to be unmarshalled into a struct type where field replacement is necessary.
|
||||
func QueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, dest any) error {
|
||||
return namedQueryStruct(ctx, log, db, query, struct{}{}, dest, false)
|
||||
}
|
||||
|
||||
// NamedQueryStruct is a helper function for executing queries that return a
|
||||
// single value to be unmarshalled into a struct type where field replacement is necessary.
|
||||
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
|
||||
return namedQueryStruct(ctx, log, db, query, data, dest, false)
|
||||
}
|
||||
|
||||
// NamedQueryStructUsingIn is a helper function for executing queries that return
|
||||
// a single value to be unmarshalled into a struct type where field replacement
|
||||
// is necessary. Use this if the query has an IN clause.
|
||||
func NamedQueryStructUsingIn(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) error {
|
||||
return namedQueryStruct(ctx, log, db, query, data, dest, true)
|
||||
}
|
||||
|
||||
func namedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any, withIn bool) (err error) {
|
||||
q := queryString(query, data)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Info("database.NamedQuerySlice", zap.String("query", q), zap.Int("type", 6), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
var rows *sqlx.Rows
|
||||
|
||||
switch withIn {
|
||||
case true:
|
||||
rows, err = func() (*sqlx.Rows, error) {
|
||||
named, args, err := sqlx.Named(query, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, args, err := sqlx.In(named, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query = db.Rebind(query)
|
||||
return db.QueryxContext(ctx, query, args...)
|
||||
}()
|
||||
|
||||
default:
|
||||
rows, err = sqlx.NamedQueryContext(ctx, db, query, data)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
if !rows.Next() {
|
||||
return ErrDBNotFound
|
||||
}
|
||||
|
||||
if err := rows.StructScan(dest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryString provides a pretty print version of the query and parameters.
|
||||
func queryString(query string, args any) string {
|
||||
query, params, err := sqlx.Named(query, args)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user