package system import ( "context" "encoding/json" "errors" "strconv" "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" "management/internal/erpserver/model/form" systemmodel "management/internal/erpserver/model/system" "management/internal/erpserver/model/view" "management/internal/erpserver/store" "management/internal/pkg/crypto" "management/internal/pkg/know" "management/internal/pkg/rand" "management/internal/pkg/session" "github.com/drhin/logger" "github.com/google/uuid" "go.uber.org/zap" ) // UserBiz 定义处理用户请求所需的方法. type UserBiz interface { Create(ctx context.Context, req *form.User) error Update(ctx context.Context, req *form.User) error All(ctx context.Context) ([]*db.SysUser, error) List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error) Get(ctx context.Context, id int32) (*db.SysUser, error) XmSelect(ctx context.Context) ([]*view.XmSelect, error) UserExpansion } // UserExpansion 定义用户操作的扩展方法. type UserExpansion interface { Login(ctx context.Context, req *form.Login) error } // userBiz 是 UserBiz 接口的实现. type userBiz struct { database store.IStore store db.Store session session.Session log *logger.Logger } // 确保 userBiz 实现了 UserBiz 接口. var _ UserBiz = (*userBiz)(nil) func NewUser(database store.IStore, store db.Store, session session.Session, log *logger.Logger) *userBiz { return &userBiz{ database: database, store: store, session: session, log: log, } } func (b *userBiz) Create(ctx context.Context, req *form.User) error { salt, err := rand.String(10) if err != nil { return err } hashedPassword, err := crypto.BcryptHashPassword(req.Password + salt) if err != nil { return err } initTime, err := time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00", time.Local) if err != nil { return err } arg := &db.CreateSysUserParams{ Uuid: uuid.Must(uuid.NewV7()), Email: req.Email, Username: req.Username, HashedPassword: hashedPassword, Salt: salt, Avatar: req.Avatar, Gender: req.Gender, DepartmentID: req.DepartmentID, RoleID: req.RoleID, Status: *req.Status, ChangePasswordAt: initTime, CreatedAt: time.Now(), UpdatedAt: time.Now(), } _, err = b.store.CreateSysUser(ctx, arg) if err != nil { if db.IsUniqueViolation(err) { return errors.New("用户已经存在") } return err } return nil } func (b *userBiz) Update(ctx context.Context, req *form.User) error { user, err := b.store.GetSysUser(ctx, *req.ID) if err != nil { return err } arg := &db.UpdateSysUserParams{ ID: user.ID, Username: req.Username, HashedPassword: user.HashedPassword, Avatar: req.Avatar, Gender: req.Gender, DepartmentID: req.DepartmentID, RoleID: req.RoleID, Status: *req.Status, ChangePasswordAt: user.ChangePasswordAt, UpdatedAt: time.Now(), } if req.ChangePassword == "on" { hashedPassword, err := crypto.BcryptHashPassword(req.Password + user.Salt) if err != nil { return err } arg.HashedPassword = hashedPassword arg.ChangePasswordAt = time.Now() } _, err = b.store.UpdateSysUser(ctx, arg) return err } func (b *userBiz) All(ctx context.Context) ([]*db.SysUser, error) { return b.store.ListSysUser(ctx) } func (b *userBiz) List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error) { count, err := b.store.CountSysUserCondition(ctx, &db.CountSysUserConditionParams{ IsStatus: q.SearchStatus != 9999, Status: int32(q.SearchStatus), IsID: q.SearchID != 0, ID: int32(q.SearchID), Username: q.SearchName, Email: q.SearchEmail, }) if err != nil { return nil, 0, err } users, err := b.store.ListSysUserCondition(ctx, &db.ListSysUserConditionParams{ IsStatus: q.SearchStatus != 9999, Status: int32(q.SearchStatus), IsID: q.SearchID != 0, ID: int32(q.SearchID), Username: q.SearchName, Email: q.SearchEmail, Skip: (int32(q.Page) - 1) * int32(q.Rows), Size: int32(q.Rows), }) if err != nil { return nil, 0, err } return users, count, nil } func (b *userBiz) Get(ctx context.Context, id int32) (*db.SysUser, error) { return b.store.GetSysUser(ctx, id) } func (b *userBiz) XmSelect(ctx context.Context) ([]*view.XmSelect, error) { all, err := b.store.ListSysUser(ctx) if err != nil || len(all) == 0 { return nil, err } var res []*view.XmSelect for _, user := range all { res = append(res, &view.XmSelect{ Name: user.Username, Value: strconv.Itoa(int(user.ID)), }) } return res, nil } func (b *userBiz) Login(ctx context.Context, req *form.Login) error { l := systemmodel.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer) err := b.login(ctx, req) if err != nil { if err := b.database.LoginLog().Create(ctx, l.SetMessage(err.Error())); err != nil { b.log.Error(err.Error(), err, zap.Any("login_log", l)) } return err } if err := b.database.LoginLog().Create(ctx, l.SetOk("登录成功")); err != nil { b.log.Error(err.Error(), err, zap.Any("login_log", l)) } return nil } func (b *userBiz) login(ctx context.Context, req *form.Login) error { user, err := b.database.User().GetByEmail(ctx, req.Email) if err != nil { return err } err = crypto.BcryptComparePassword(user.HashedPassword, req.Password+user.Salt) if err != nil { return errors.New("账号或密码错误") } user.Role, err = b.database.Role().Get(ctx, user.RoleID) if err != nil { return err } if user.Role == nil || user.Role.ID == 0 { return errors.New("账号没有配置角色, 请联系管理员") } // 登陆成功 err = b.loginSuccess(ctx, user, req) if err != nil { return err } return nil } func (b *userBiz) loginSuccess(ctx context.Context, user *systemmodel.User, req *form.Login) error { auth := dto.AuthorizeUser{ ID: user.ID, Uuid: user.Uuid, Email: user.Email, Username: user.Username, RoleID: user.Role.ID, RoleName: user.Role.Name, OS: req.Os, IP: req.Ip, Browser: req.Browser, } gob, err := json.Marshal(auth) if err != nil { return err } b.session.Put(ctx, know.StoreName, gob) return nil }