2025-04-14 15:28:51 +08:00

254 lines
6.3 KiB
Go

package system
import (
"context"
"encoding/json"
"errors"
"strconv"
"time"
"management/internal/db/model/dto"
db "management/internal/db/sqlc"
"management/internal/erpserver/model/form"
systemmodel "management/internal/erpserver/model/system"
"management/internal/erpserver/model/view"
"management/internal/erpserver/store"
"management/internal/pkg/crypto"
"management/internal/pkg/know"
"management/internal/pkg/rand"
"management/internal/pkg/session"
"github.com/drhin/logger"
"github.com/google/uuid"
"go.uber.org/zap"
)
// UserExpansion groups user operations that go beyond plain CRUD; at the
// moment it only covers the login flow.
type UserExpansion interface {
	// Login checks the credentials carried by req and, on success, sets up
	// the caller's session.
	Login(ctx context.Context, req *form.Login) error
}

// UserBiz is the business-logic API for system users: creating and editing
// accounts, listing and looking them up, producing select-widget options,
// and (through UserExpansion) logging in.
type UserBiz interface {
	// UserExpansion contributes Login.
	UserExpansion

	// Create adds a new user built from the submitted form.
	Create(ctx context.Context, req *form.User) error
	// Update changes an existing user according to the submitted form.
	Update(ctx context.Context, req *form.User) error

	// All returns every user.
	All(ctx context.Context) ([]*db.SysUser, error)
	// List returns one page of users matching q plus the total match count.
	List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error)
	// Get returns the user with the given ID.
	Get(ctx context.Context, id int32) (*db.SysUser, error)
	// XmSelect returns all users as view.XmSelect options.
	XmSelect(ctx context.Context) ([]*view.XmSelect, error)
}
// userBiz is the implementation of the UserBiz interface.
type userBiz struct {
	// database is the repository-style store; the login flow uses its
	// User(), Role() and LoginLog() accessors.
	database store.IStore
	// store is the sqlc query store used for SysUser create/update/list/get.
	store db.Store
	// session receives the JSON-encoded authorized user after a successful login.
	session session.Session
	// log records failures to persist login-log entries.
	log *logger.Logger
}

// Compile-time check that *userBiz implements UserBiz.
var _ UserBiz = (*userBiz)(nil)
// NewUser builds a userBiz from its collaborators: the repository store
// (database), the sqlc query store (store), the HTTP session manager and the
// logger. No validation is performed; nil dependencies are stored as given.
func NewUser(database store.IStore, store db.Store, session session.Session, log *logger.Logger) *userBiz {
	b := new(userBiz)
	b.database = database
	b.store = store
	b.session = session
	b.log = log
	return b
}
// Create registers a new system user from req.
//
// A salt is produced by rand.String(10), and the stored password is the bcrypt
// hash of req.Password+salt. The user gets a fresh UUIDv7. ChangePasswordAt is
// initialised to 0001-01-01 00:00:00 in time.Local (presumably meaning "never
// changed" — confirm). CreatedAt and UpdatedAt share one timestamp.
//
// Errors: a nil req.Status is rejected with "用户状态不能为空"; a unique
// constraint violation from the store is reported as "用户已经存在"; any other
// failure (salt, hashing, UUID generation, insert) is returned unchanged.
func (b *userBiz) Create(ctx context.Context, req *form.User) error {
	// Status is a pointer on the form; dereferencing nil would panic, so
	// reject it before doing any expensive work (bcrypt).
	if req.Status == nil {
		return errors.New("用户状态不能为空")
	}

	salt, err := rand.String(10)
	if err != nil {
		return err
	}
	hashedPassword, err := crypto.BcryptHashPassword(req.Password + salt)
	if err != nil {
		return err
	}
	// uuid.Must would panic on a failed read from the random source; surface
	// that as an ordinary error instead of crashing the request.
	id, err := uuid.NewV7()
	if err != nil {
		return err
	}

	// Same value time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00",
	// time.Local) produced, without parsing a constant on every call.
	initTime := time.Date(1, time.January, 1, 0, 0, 0, 0, time.Local)
	// A single timestamp keeps CreatedAt and UpdatedAt identical for a new row.
	now := time.Now()

	arg := &db.CreateSysUserParams{
		Uuid:             id,
		Email:            req.Email,
		Username:         req.Username,
		HashedPassword:   hashedPassword,
		Salt:             salt,
		Avatar:           req.Avatar,
		Gender:           req.Gender,
		DepartmentID:     req.DepartmentID,
		RoleID:           req.RoleID,
		Status:           *req.Status,
		ChangePasswordAt: initTime,
		CreatedAt:        now,
		UpdatedAt:        now,
	}
	if _, err := b.store.CreateSysUser(ctx, arg); err != nil {
		if db.IsUniqueViolation(err) {
			return errors.New("用户已经存在")
		}
		return err
	}
	return nil
}
// Update modifies the existing system user identified by req.ID.
//
// Username, avatar, gender, department, role and status are taken from req.
// The stored password hash and ChangePasswordAt are kept unless
// req.ChangePassword == "on" (presumably a checkbox value — confirm). In that
// case req.Password is re-hashed with the user's existing salt, and
// ChangePasswordAt is set to the same timestamp as UpdatedAt. This method does
// not set Email or Salt on the update parameters.
//
// Errors: nil req.ID / req.Status are rejected up front; lookup, hashing and
// update errors are returned unchanged.
func (b *userBiz) Update(ctx context.Context, req *form.User) error {
	// Both fields are pointers on the form; dereferencing nil would panic.
	if req.ID == nil {
		return errors.New("用户ID不能为空")
	}
	if req.Status == nil {
		return errors.New("用户状态不能为空")
	}

	user, err := b.store.GetSysUser(ctx, *req.ID)
	if err != nil {
		return err
	}

	now := time.Now()
	arg := &db.UpdateSysUserParams{
		ID:               user.ID,
		Username:         req.Username,
		HashedPassword:   user.HashedPassword,
		Avatar:           req.Avatar,
		Gender:           req.Gender,
		DepartmentID:     req.DepartmentID,
		RoleID:           req.RoleID,
		Status:           *req.Status,
		ChangePasswordAt: user.ChangePasswordAt,
		UpdatedAt:        now,
	}
	if req.ChangePassword == "on" {
		hashed, hashErr := crypto.BcryptHashPassword(req.Password + user.Salt)
		if hashErr != nil {
			return hashErr
		}
		arg.HashedPassword = hashed
		arg.ChangePasswordAt = now
	}

	_, err = b.store.UpdateSysUser(ctx, arg)
	return err
}
// All returns every system user as produced by the store's ListSysUser
// query; errors from the store are passed through unchanged.
func (b *userBiz) All(ctx context.Context) ([]*db.SysUser, error) {
	users, err := b.store.ListSysUser(ctx)
	return users, err
}
func (b *userBiz) List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error) {
count, err := b.store.CountSysUserCondition(ctx, &db.CountSysUserConditionParams{
IsStatus: q.SearchStatus != 9999,
Status: int32(q.SearchStatus),
IsID: q.SearchID != 0,
ID: int32(q.SearchID),
Username: q.SearchName,
Email: q.SearchEmail,
})
if err != nil {
return nil, 0, err
}
users, err := b.store.ListSysUserCondition(ctx, &db.ListSysUserConditionParams{
IsStatus: q.SearchStatus != 9999,
Status: int32(q.SearchStatus),
IsID: q.SearchID != 0,
ID: int32(q.SearchID),
Username: q.SearchName,
Email: q.SearchEmail,
Skip: (int32(q.Page) - 1) * int32(q.Rows),
Size: int32(q.Rows),
})
if err != nil {
return nil, 0, err
}
return users, count, nil
}
// Get fetches the system user with the given ID from the store; the store's
// result and error are returned unchanged.
func (b *userBiz) Get(ctx context.Context, id int32) (*db.SysUser, error) {
	user, err := b.store.GetSysUser(ctx, id)
	return user, err
}
func (b *userBiz) XmSelect(ctx context.Context) ([]*view.XmSelect, error) {
all, err := b.store.ListSysUser(ctx)
if err != nil || len(all) == 0 {
return nil, err
}
var res []*view.XmSelect
for _, user := range all {
res = append(res, &view.XmSelect{
Name: user.Username,
Value: strconv.Itoa(int(user.ID)),
})
}
return res, nil
}
func (b *userBiz) Login(ctx context.Context, req *form.Login) error {
l := systemmodel.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer)
err := b.login(ctx, req)
if err != nil {
if err := b.database.LoginLog().Create(ctx, l.SetMessage(err.Error())); err != nil {
b.log.Error(err.Error(), err, zap.Any("login_log", l))
}
return err
}
if err := b.database.LoginLog().Create(ctx, l.SetOk("登录成功")); err != nil {
b.log.Error(err.Error(), err, zap.Any("login_log", l))
}
return nil
}
// login performs the credential check and session setup behind Login.
//
// Steps:
//  1. Load the user by email.
//  2. Verify the bcrypt hash of password+salt; a mismatch is reported as
//     "账号或密码错误".
//  3. Load the user's role; a nil role or role ID 0 is rejected.
//  4. Store the authorized identity in the session via loginSuccess.
//
// NOTE(review): the GetByEmail error is returned unchanged, while a wrong
// password yields the generic message. If that error differs for unknown
// emails, callers can tell whether an account exists — confirm how the
// handler renders it.
// NOTE(review): no account-status check is visible here — confirm disabled
// users are rejected elsewhere.
func (b *userBiz) login(ctx context.Context, req *form.Login) error {
	user, err := b.database.User().GetByEmail(ctx, req.Email)
	if err != nil {
		return err
	}

	if err = crypto.BcryptComparePassword(user.HashedPassword, req.Password+user.Salt); err != nil {
		return errors.New("账号或密码错误")
	}

	if user.Role, err = b.database.Role().Get(ctx, user.RoleID); err != nil {
		return err
	}
	if user.Role == nil || user.Role.ID == 0 {
		return errors.New("账号没有配置角色, 请联系管理员")
	}

	return b.loginSuccess(ctx, user, req)
}
// loginSuccess stores the authorized user's identity in the session.
//
// The user's ID, UUID, email, username, role ID and role name, plus the
// request's OS, IP and browser, are JSON-encoded as a dto.AuthorizeUser and
// put into the session under know.StoreName. Only a JSON encoding failure is
// reported as an error.
//
// NOTE(review): the session token is not renewed here. If the session
// implementation supports it, renewing on login guards against session
// fixation — confirm.
func (b *userBiz) loginSuccess(ctx context.Context, user *systemmodel.User, req *form.Login) error {
	payload, err := json.Marshal(dto.AuthorizeUser{
		ID:       user.ID,
		Uuid:     user.Uuid,
		Email:    user.Email,
		Username: user.Username,
		RoleID:   user.Role.ID,
		RoleName: user.Role.Name,
		OS:       req.Os,
		IP:       req.Ip,
		Browser:  req.Browser,
	})
	if err != nil {
		return err
	}

	b.session.Put(ctx, know.StoreName, payload)
	return nil
}