This commit is contained in:
2025-06-18 17:44:49 +08:00
parent b171122a32
commit 0878a4e6de
66 changed files with 2841 additions and 1423 deletions

View File

@@ -1,196 +1,198 @@
package mid
import (
"context"
"errors"
"net/http"
"sync"
"time"
//import (
// "context"
// "errors"
// "net/http"
// "sync"
// "time"
//
// systemmodel "management/internal/erpserver/model/system"
// v1 "management/internal/erpserver/service/v1"
// "management/internal/pkg/know"
// "management/internal/pkg/session"
//
// "github.com/drhin/logger"
// "go.uber.org/zap"
//)
//
//// AuditBuffer 审计日志缓冲器
//type AuditBuffer struct {
// auditLogService v1.AuditLogService
// log *logger.Logger
// buffer chan *systemmodel.AuditLog
// stopCh chan struct{}
// wg sync.WaitGroup
// batchSize int
// flushInterval time.Duration
//}
//
//// NewAuditBuffer 创建审计日志缓冲器
//func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
// return &AuditBuffer{
// auditLogService: auditLogService,
// log: log,
// buffer: make(chan *systemmodel.AuditLog, 10000), // 缓冲区大小
// stopCh: make(chan struct{}),
// batchSize: 50, // 批量大小
// flushInterval: 3 * time.Second, // 刷新间隔
// }
//}
//
//// Start 启动缓冲器
//func (ab *AuditBuffer) Start() {
// ab.wg.Add(1)
// go ab.processBuffer()
//}
//
//// Stop 停止缓冲器
//func (ab *AuditBuffer) Stop() {
// close(ab.stopCh)
// ab.wg.Wait()
// close(ab.buffer)
//}
//
//// Add 添加审计日志到缓冲区
//func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
// select {
// case ab.buffer <- auditLog:
// // 成功添加到缓冲区
// default:
// // 缓冲区满,记录警告但不阻塞
// ab.log.Warn("审计日志缓冲区已满,丢弃日志")
// }
//}
//
//// processBuffer 处理缓冲区中的日志
//func (ab *AuditBuffer) processBuffer() {
// defer ab.wg.Done()
//
// ticker := time.NewTicker(ab.flushInterval)
// defer ticker.Stop()
//
// batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)
//
// flushBatch := func() {
// if len(batch) == 0 {
// return
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// defer cancel()
//
// // 批量插入
// if err := ab.batchInsert(ctx, batch); err != nil {
// ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
// } else {
// ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
// }
//
// // 清空批次
// batch = batch[:0]
// }
//
// for {
// select {
// case <-ab.stopCh:
// // 停止信号,处理剩余的日志
// for len(ab.buffer) > 0 {
// select {
// case auditLog := <-ab.buffer:
// batch = append(batch, auditLog)
// if len(batch) >= ab.batchSize {
// flushBatch()
// }
// default:
// break
// }
// }
// flushBatch() // 处理最后一批
// return
//
// case <-ticker.C:
// // 定时刷新
// flushBatch()
//
// case auditLog := <-ab.buffer:
// // 收到新的审计日志
// batch = append(batch, auditLog)
// if len(batch) >= ab.batchSize {
// flushBatch()
// }
// }
// }
//}
//
//// batchInsert 批量插入数据库
//func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
// maxRetries := 3
// for i := 0; i < maxRetries; i++ {
// // 假设你的服务有批量创建方法,如果没有,需要添加
// if err := ab.auditLogService.BatchCreate(ctx, auditLogs); err != nil {
// if i == maxRetries-1 {
// return err
// }
// ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", i+1))
// time.Sleep(time.Duration(i+1) * time.Second)
// continue
// }
// return nil
// }
// return nil
//}
//
//// 全局缓冲器实例
//var globalAuditBuffer *AuditBuffer
//
//// InitAuditBuffer 初始化全局缓冲器
//func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
// globalAuditBuffer = NewAuditBuffer(auditLogService, log)
// globalAuditBuffer.Start()
//}
//
//// StopAuditBuffer 停止全局缓冲器
//func StopAuditBuffer() {
// if globalAuditBuffer != nil {
// globalAuditBuffer.Stop()
// }
//}
//
//// Audit 优化后的中间件
//func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// start := time.Now()
//
// // 提前获取用户信息
// user, err := sess.GetUser(r.Context(), know.StoreName)
// if err != nil {
// log.Error("获取用户会话失败", err)
// next.ServeHTTP(w, r)
// return
// }
//
// // 处理请求
// next.ServeHTTP(w, r)
//
// // 异步添加到缓冲区
// go func() {
// if user.ID == 0 {
// log.Error("用户信息为空", errors.New("user is empty"))
// return
// }
//
// auditLog := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
//
// // 添加到缓冲区,不会阻塞
// if globalAuditBuffer != nil {
// globalAuditBuffer.Add(auditLog)
// }
// }()
// })
// }
//}
systemmodel "management/internal/erpserver/model/system"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/drhin/logger"
"go.uber.org/zap"
)
// AuditBuffer 审计日志缓冲器
type AuditBuffer struct {
auditLogService v1.AuditLogService
log *logger.Logger
buffer chan *systemmodel.AuditLog
stopCh chan struct{}
wg sync.WaitGroup
batchSize int
flushInterval time.Duration
}
// NewAuditBuffer 创建审计日志缓冲器
func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
return &AuditBuffer{
auditLogService: auditLogService,
log: log,
buffer: make(chan *systemmodel.AuditLog, 10000), // 缓冲区大小
stopCh: make(chan struct{}),
batchSize: 50, // 批量大小
flushInterval: 3 * time.Second, // 刷新间隔
}
}
// Start 启动缓冲器
func (ab *AuditBuffer) Start() {
ab.wg.Add(1)
go ab.processBuffer()
}
// Stop 停止缓冲器
func (ab *AuditBuffer) Stop() {
close(ab.stopCh)
ab.wg.Wait()
close(ab.buffer)
}
// Add 添加审计日志到缓冲区
func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
select {
case ab.buffer <- auditLog:
// 成功添加到缓冲区
default:
// 缓冲区满,记录警告但不阻塞
ab.log.Warn("审计日志缓冲区已满,丢弃日志")
}
}
// processBuffer 处理缓冲区中的日志
func (ab *AuditBuffer) processBuffer() {
defer ab.wg.Done()
ticker := time.NewTicker(ab.flushInterval)
defer ticker.Stop()
batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)
flushBatch := func() {
if len(batch) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 批量插入
if err := ab.batchInsert(ctx, batch); err != nil {
ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
} else {
ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
}
// 清空批次
batch = batch[:0]
}
for {
select {
case <-ab.stopCh:
// 停止信号,处理剩余的日志
for len(ab.buffer) > 0 {
select {
case auditLog := <-ab.buffer:
batch = append(batch, auditLog)
if len(batch) >= ab.batchSize {
flushBatch()
}
default:
break
}
}
flushBatch() // 处理最后一批
return
case <-ticker.C:
// 定时刷新
flushBatch()
case auditLog := <-ab.buffer:
// 收到新的审计日志
batch = append(batch, auditLog)
if len(batch) >= ab.batchSize {
flushBatch()
}
}
}
}
// batchInsert 批量插入数据库
func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
maxRetries := 3
for i := 0; i < maxRetries; i++ {
// 假设你的服务有批量创建方法,如果没有,需要添加
if err := ab.auditLogService.BatchCreate(ctx, auditLogs); err != nil {
if i == maxRetries-1 {
return err
}
ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", i+1))
time.Sleep(time.Duration(i+1) * time.Second)
continue
}
return nil
}
return nil
}
// 全局缓冲器实例
var globalAuditBuffer *AuditBuffer
// InitAuditBuffer 初始化全局缓冲器
func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
globalAuditBuffer = NewAuditBuffer(auditLogService, log)
globalAuditBuffer.Start()
}
// StopAuditBuffer 停止全局缓冲器
func StopAuditBuffer() {
if globalAuditBuffer != nil {
globalAuditBuffer.Stop()
}
}
// Audit 优化后的中间件
func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 提前获取用户信息
user, err := sess.GetUser(r.Context(), know.StoreName)
if err != nil {
log.Error("获取用户会话失败", err)
next.ServeHTTP(w, r)
return
}
// 处理请求
next.ServeHTTP(w, r)
// 异步添加到缓冲区
go func() {
if user.ID == 0 {
log.Error("用户信息为空", errors.New("user is empty"))
return
}
auditLog := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
// 添加到缓冲区,不会阻塞
if globalAuditBuffer != nil {
globalAuditBuffer.Add(auditLog)
}
}()
})
}
}
// ======================================================
// 如果你的AuditLogService没有BatchCreate方法需要添加这个接口
// 在你的service接口中添加

View File

@@ -0,0 +1,57 @@
package mid
import (
"context"
"net/http"
"time"
"management/internal/erpserver/model/system"
"management/internal/pkg/know"
"management/internal/pkg/session"
"management/internal/tasks"
"github.com/drhin/logger"
"github.com/hibiken/asynq"
"go.uber.org/zap"
)
// Audit 改造后的中间件
func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
start := time.Now()
user, err := sess.GetUser(ctx, know.StoreName)
if err != nil {
log.Error("获取用户会话失败", err)
}
next.ServeHTTP(w, r)
if user.ID == 0 {
return
}
payload := &tasks.PayloadConsumeAuditLog{
AuditLog: system.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now()),
}
opts := []asynq.Option{
asynq.MaxRetry(10),
asynq.ProcessIn(1 * time.Second),
asynq.Queue(tasks.QueueCritical),
}
c, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
if err := task.DistributeTaskConsumeAuditLog(c, payload, opts...); err != nil {
log.Error("distribute task failed", err,
zap.String("type", "audit"),
zap.Any("payload", payload),
)
}
})
}
}

View File

@@ -7,41 +7,39 @@ import (
"net/http"
"sync"
"time"
"unsafe"
"management/internal/erpserver/model/dto"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/json-iterator/go"
"github.com/patrickmn/go-cache"
"golang.org/x/sync/singleflight"
)
// 高性能JSON库全局初始化
var json = jsoniter.ConfigFastest
// 使用jsoniter优化菜单结构体序列化
func init() {
jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
m := (*dto.OwnerMenuDto)(ptr)
stream.WriteObjectStart()
stream.WriteObjectField("id")
stream.WriteUint(uint(m.ID))
stream.WriteMore()
stream.WriteObjectField("url")
stream.WriteString(m.Url)
stream.WriteMore()
stream.WriteObjectField("parentId")
stream.WriteUint(uint(m.ParentID))
stream.WriteMore()
stream.WriteObjectField("isList")
stream.WriteBool(m.IsList)
stream.WriteObjectEnd()
}, nil)
}
//// 高性能JSON库全局初始化
//var json = jsoniter.ConfigFastest
//
//// 使用jsoniter优化菜单结构体序列化
//func init() {
// jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
// func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
// m := (*dto.OwnerMenuDto)(ptr)
// stream.WriteObjectStart()
// stream.WriteObjectField("id")
// stream.WriteUint(uint(m.ID))
// stream.WriteMore()
// stream.WriteObjectField("url")
// stream.WriteString(m.Url)
// stream.WriteMore()
// stream.WriteObjectField("parentId")
// stream.WriteUint(uint(m.ParentID))
// stream.WriteMore()
// stream.WriteObjectField("isList")
// stream.WriteBool(m.IsList)
// stream.WriteObjectEnd()
// }, nil)
//}
var publicRoutes = map[string]bool{
"/home.html": true,

View File

@@ -0,0 +1,95 @@
package mid
//var publicRoutes = map[string]bool{
// "/home.html": true,
// "/dashboard": true,
// "/system/menus": true,
// "/upload/img": true,
// "/upload/file": true,
// "/upload/multi_files": true,
// "/system/pear.json": true,
// "/logout": true,
//}
//
//var m sync.Map
//
//func Authorize(
// sess session.Manager,
// menuService v1.MenuService,
//) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// path := r.URL.Path
//
// // 登陆检查
// user, err := sess.GetUser(ctx, know.StoreName)
// if err != nil || user.ID == 0 {
// http.Redirect(w, r, "/", http.StatusFound)
// return
// }
//
// // 公共路由放行
// if publicRoutes[path] {
// ctx = setUser(ctx, user)
// next.ServeHTTP(w, r.WithContext(ctx))
// return
// }
//
// n1 := time.Now()
// // 权限检查
// var menus map[string]*dto.OwnerMenuDto
// cacheKey := fmt.Sprintf("user_menus:%d", user.RoleID)
// if value, ok := m.Load(cacheKey); ok {
// menus = value.(map[string]*dto.OwnerMenuDto)
// log.Printf("map (from cache): %s", time.Since(n1).String())
// } else {
// menus, err = menuService.ListByRoleIDToMap(ctx, user.RoleID)
// if err == nil {
// m.Store(cacheKey, menus)
// }
// log.Printf("map (from DB, then cached): %s", time.Since(n1).String())
// }
//
// if !hasPermission(menus, path) {
// http.Error(w, "Forbidden", http.StatusForbidden)
// return
// }
//
// cur := getCurrentMenus(menus, path)
//
// ctx = setUser(ctx, user)
// ctx = setCurMenus(ctx, cur)
//
// next.ServeHTTP(w, r.WithContext(ctx))
// })
// }
//}
//
//func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
// _, ok := menus[path]
// return ok
//}
//
//func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
// var res []dto.OwnerMenuDto
//
// menu, ok := data[path]
// if !ok {
// return res
// }
//
// for _, item := range data {
// if menu.IsList {
// if item.ParentID == menu.ID || item.ID == menu.ID {
// res = append(res, *item)
// }
// } else {
// if item.ParentID == menu.ParentID {
// res = append(res, *item)
// }
// }
// }
//
// return res
//}

View File

@@ -2,8 +2,10 @@ package mid
import (
"context"
"errors"
"management/internal/erpserver/model/dto"
"management/internal/pkg/sqldb"
"github.com/a-h/templ"
)
@@ -73,3 +75,19 @@ func GetHtmlCsrfToken(ctx context.Context) templ.Component {
return v
}
type trkey struct{}
func setTran(ctx context.Context, tx sqldb.CommitRollbacker) context.Context {
return context.WithValue(ctx, trkey{}, tx)
}
// GetTran retrieves the value that can manage a transaction.
func GetTran(ctx context.Context) (sqldb.CommitRollbacker, error) {
v, ok := ctx.Value(trkey{}).(sqldb.CommitRollbacker)
if !ok {
return nil, errors.New("transaction not found in context")
}
return v, nil
}

View File

@@ -0,0 +1,57 @@
package mid
import (
"database/sql"
"errors"
"net/http"
"management/internal/pkg/sqldb"
"github.com/drhin/logger"
"go.uber.org/zap"
)
func BeginCommitRollback(log *logger.Logger, bgn sqldb.Beginner) func(http.Handler) http.Handler {
m := func(next http.Handler) http.Handler {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hasCommitted := false
log.Info("BEGIN TRANSACTION")
tx, err := bgn.Begin()
if err != nil {
log.Error("BEGIN TRANSACTION", err)
return
}
defer func() {
if !hasCommitted {
log.Info("ROLLBACK TRANSACTION")
}
if err := tx.Rollback(); err != nil {
if errors.Is(err, sql.ErrTxDone) {
return
}
log.Info("ROLLBACK TRANSACTION", zap.Error(err))
}
}()
ctx := r.Context()
ctx = setTran(ctx, tx)
next.ServeHTTP(w, r.WithContext(ctx))
log.Info("COMMIT TRANSACTION")
if err := tx.Commit(); err != nil {
log.Error("COMMIT TRANSACTION", err)
return
}
hasCommitted = true
})
return h
}
return m
}

214
internal/pkg/sqldb/sqldb.go Normal file
View File

@@ -0,0 +1,214 @@
package sqldb
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"management/internal/pkg/config"
"github.com/drhin/logger"
"github.com/jackc/pgx/v5/pgconn"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
)
const (
uniqueViolation = "23505"
undefinedTable = "42P01"
)
var (
ErrDBNotFound = sql.ErrNoRows
ErrDBDuplicatedEntry = errors.New("duplicated entry")
ErrUndefinedTable = errors.New("undefined table")
)
type Config struct {
User string
Password string
Host string
Port int
Name string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
config.DB.Username,
config.DB.Password,
config.DB.Host,
config.DB.Port,
config.DB.DBName,
)
db, err := sqlx.Open("pgx", dsn)
if err != nil {
return nil, nil, fmt.Errorf("sqlx open db: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, nil, err
}
// 设置最大空闲连接数(默认 2)
db.SetMaxIdleConns(config.DB.MaxIdleConns)
// 设置最大打开连接数(默认 0 无限制)
db.SetMaxOpenConns(config.DB.MaxOpenConns)
// 设置连接最大存活时间
db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
// 设置连接最大空闲时间
db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
cleanup := func() {
if err := db.Close(); err != nil {
log.Error("sql db close error", err)
}
}
return db, cleanup, nil
}
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
defer func() {
if err != nil {
switch data.(type) {
case struct{}:
log.Error("database.NamedExecContext (data is struct)", err,
zap.String("query", query),
zap.Any("ERROR", err))
default:
log.Error("database.NamedExecContext", err,
zap.String("query", query),
zap.Any("ERROR", err))
}
}
}()
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
var pgError *pgconn.PgError
if errors.As(err, &pgError) {
switch pgError.Code {
case undefinedTable:
return ErrUndefinedTable
case uniqueViolation:
return ErrDBDuplicatedEntry
}
}
return err
}
return nil
}
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
q := queryString(query, data)
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
if err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
return ErrUndefinedTable
}
log.Error("NamedQueryStruct NamedQueryContext error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
defer func(rows *sqlx.Rows) {
err := rows.Close()
if err != nil {
log.Error("rows close error", err)
}
}(rows)
if !rows.Next() {
return ErrDBNotFound
}
if err := rows.StructScan(dest); err != nil {
log.Error("NamedQueryStruct StructScan error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
return nil
}
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
q := queryString(query, data)
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
if err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
return ErrUndefinedTable
}
log.Error("NamedQueryStruct NamedQueryContext error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
defer func(rows *sqlx.Rows) {
err := rows.Close()
if err != nil {
log.Error("rows close error", err)
}
}(rows)
var slice []T
for rows.Next() {
v := new(T)
if err := rows.StructScan(v); err != nil {
log.Error("NamedQuerySlice StructScan error", err,
zap.String("query", q),
zap.Any("data", data),
)
return err
}
slice = append(slice, *v)
}
*dest = slice
return nil
}
func queryString(query string, args any) string {
query, params, err := sqlx.Named(query, args)
if err != nil {
return err.Error()
}
for _, param := range params {
var value string
switch v := param.(type) {
case string:
value = fmt.Sprintf("'%s'", v)
case []byte:
value = fmt.Sprintf("'%s'", string(v))
default:
value = fmt.Sprintf("%v", v)
}
query = strings.Replace(query, "?", value, 1)
}
query = strings.ReplaceAll(query, "\t", "")
query = strings.ReplaceAll(query, "\n", " ")
return strings.Trim(query, " ")
}

View File

@@ -0,0 +1,49 @@
package sqldb
import (
"fmt"
"github.com/jmoiron/sqlx"
)
// Beginner represents a value that can begin a transaction.
type Beginner interface {
Begin() (CommitRollbacker, error)
}
// CommitRollbacker represents a value that can commit or rollback a transaction.
type CommitRollbacker interface {
Commit() error
Rollback() error
}
// =============================================================================
// DBBeginner implements the Beginner interface,
type DBBeginner struct {
sqlxDB *sqlx.DB
}
// NewBeginner constructs a value that implements the beginner interface.
func NewBeginner(sqlxDB *sqlx.DB) *DBBeginner {
return &DBBeginner{
sqlxDB: sqlxDB,
}
}
// Begin implements the Beginner interface and returns a concrete value that
// implements the CommitRollbacker interface.
func (db *DBBeginner) Begin() (CommitRollbacker, error) {
return db.sqlxDB.Beginx()
}
// GetExtContext is a helper function that extracts the sqlx value
// from the domain transactor interface for transactional use.
func GetExtContext(tx CommitRollbacker) (sqlx.ExtContext, error) {
ec, ok := tx.(sqlx.ExtContext)
if !ok {
return nil, fmt.Errorf("Transactor(%T) not of a type *sql.Tx", tx)
}
return ec, nil
}