From 4186cd0caf871d66d93e40a2ddb0e506426dfa40 Mon Sep 17 00:00:00 2001 From: kenneth <1185230223@qq.com> Date: Wed, 25 Jun 2025 16:11:03 +0800 Subject: [PATCH] update --- cmd/build/build.go | 3 + internal/erpserver/handler/route.go | 4 + .../erpserver/handler/system/auth/auth.go | 9 +- .../erpserver/handler/system/auth/route.go | 4 +- .../model/{dto => system}/authorize_user.go | 17 +- internal/erpserver/model/system/login_log.go | 1 + .../repository/system/loginlog/login_log.go | 42 ++ internal/erpserver/service/v1/auth/auth.go | 571 ++++++++++++++++++ internal/erpserver/service/v1/service.go | 4 + .../erpserver/service/v1/system/login_log.go | 9 + internal/erpserver/service/v1/system/user.go | 22 +- internal/erpserver/templ/auth/login.templ | 1 + internal/erpserver/templ/auth/login_templ.go | 2 +- internal/pkg/mid/audit_v3.go | 2 +- internal/pkg/mid/mid.go | 9 +- internal/pkg/session/session.go | 16 +- 16 files changed, 690 insertions(+), 26 deletions(-) rename internal/erpserver/model/{dto => system}/authorize_user.go (54%) create mode 100644 internal/erpserver/service/v1/auth/auth.go diff --git a/cmd/build/build.go b/cmd/build/build.go index 9cde3f3..a22c3d0 100644 --- a/cmd/build/build.go +++ b/cmd/build/build.go @@ -15,6 +15,7 @@ import ( "management/internal/erpserver/repository/system/rolemenu" "management/internal/erpserver/repository/system/user" v1 "management/internal/erpserver/service/v1" + "management/internal/erpserver/service/v1/auth" "management/internal/erpserver/service/v1/common" system2 "management/internal/erpserver/service/v1/system" "management/internal/pkg/cache" @@ -84,6 +85,7 @@ func Initialize(conf *config.Config, log *logger.Logger) (http.Handler, func(), configService := system2.NewConfigService(service, configRepository) captchaService := common.NewCaptchaService() auditLogService := system2.NewAuditLogService(service, auditLogRepository) + authService := auth.NewAuth(log, rd, sm, userService, roleService, loginLogService) // ================================================================================================================= // task @@ -117,6 +119,7 @@ func Initialize(conf *config.Config, log *logger.Logger) (http.Handler, func(), Render: rdr, TaskDistributor: taskDistributor, CaptchaService: captchaService, + AuthService: authService, UserService: userService, RoleService: roleService, DepartmentService: departmentService, diff --git a/internal/erpserver/handler/route.go b/internal/erpserver/handler/route.go index f5f235f..2c102bd 100644 --- a/internal/erpserver/handler/route.go +++ b/internal/erpserver/handler/route.go @@ -16,6 +16,7 @@ import ( "management/internal/erpserver/handler/system/user" "management/internal/erpserver/handler/upload" v1 "management/internal/erpserver/service/v1" + authv1 "management/internal/erpserver/service/v1/auth" "management/internal/pkg/config" "management/internal/pkg/mid" "management/internal/pkg/render" @@ -34,6 +35,7 @@ type Config struct { Render render.Renderer TaskDistributor tasks.TaskDistributor CaptchaService v1.CaptchaService + AuthService *authv1.Auth UserService v1.UserService RoleService v1.RoleService DepartmentService v1.DepartmentService @@ -90,6 +92,7 @@ func WebApp(cfg Config) http.Handler { Sm: cfg.Sm, Render: cfg.Render, CaptchaService: cfg.CaptchaService, + AuthService: cfg.AuthService, UserService: cfg.UserService, MenuService: cfg.MenuService, }) @@ -130,6 +133,7 @@ func WebApp(cfg Config) http.Handler { Log: cfg.Log, Sm: cfg.Sm, Render: cfg.Render, + TaskDistributor: cfg.TaskDistributor, MenuService: cfg.MenuService, DepartmentService: cfg.DepartmentService, }) diff --git a/internal/erpserver/handler/system/auth/auth.go b/internal/erpserver/handler/system/auth/auth.go index b1dd1b9..9e2f392 100644 --- a/internal/erpserver/handler/system/auth/auth.go +++ b/internal/erpserver/handler/system/auth/auth.go @@ -7,6 +7,7 @@ import ( "management/internal/erpserver/model/form" v1 "management/internal/erpserver/service/v1" + authv1 "management/internal/erpserver/service/v1/auth" "management/internal/erpserver/templ/auth" "management/internal/pkg/binding" "management/internal/pkg/mid" @@ -22,6 +23,7 @@ type app struct { render render.Renderer captchaService v1.CaptchaService userService v1.UserService + authService *authv1.Auth } func newApp( @@ -30,6 +32,7 @@ func newApp( render render.Renderer, captchaService v1.CaptchaService, userService v1.UserService, + authService *authv1.Auth, ) *app { return &app{ log: log, @@ -37,6 +40,7 @@ func newApp( render: render, captchaService: captchaService, userService: userService, + authService: authService, } } @@ -75,13 +79,16 @@ func (a *app) login(w http.ResponseWriter, r *http.Request) { } req = req.SetAttributes(r) - err := a.userService.Login(ctx, &req) + //err := a.userService.Login(ctx, &req) + risk, err := a.authService.Authenticate(ctx, req) if err != nil { log.Println(err) a.render.JSONErr(w, err.Error()) return } + log.Println(risk) + a.render.JSONOk(w, "login successfully") default: http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) diff --git a/internal/erpserver/handler/system/auth/route.go b/internal/erpserver/handler/system/auth/route.go index 2b7ed1c..1114bba 100644 --- a/internal/erpserver/handler/system/auth/route.go +++ b/internal/erpserver/handler/system/auth/route.go @@ -2,6 +2,7 @@ package auth import ( v1 "management/internal/erpserver/service/v1" + "management/internal/erpserver/service/v1/auth" "management/internal/pkg/mid" "management/internal/pkg/render" "management/internal/pkg/session" @@ -15,12 +16,13 @@ type Config struct { Sm session.Manager Render render.Renderer CaptchaService v1.CaptchaService + AuthService *auth.Auth UserService v1.UserService MenuService v1.MenuService } func Routes(r chi.Router, cfg Config) { - app := newApp(cfg.Log, cfg.Sm, cfg.Render, cfg.CaptchaService, cfg.UserService) + app := newApp(cfg.Log, cfg.Sm, cfg.Render, cfg.CaptchaService, cfg.UserService, cfg.AuthService) r.Get("/", app.login) r.Post("/login", app.login) diff --git a/internal/erpserver/model/dto/authorize_user.go b/internal/erpserver/model/system/authorize_user.go similarity index 54% rename from internal/erpserver/model/dto/authorize_user.go rename to internal/erpserver/model/system/authorize_user.go index 7997904..6c332ca 100644 --- a/internal/erpserver/model/dto/authorize_user.go +++ b/internal/erpserver/model/system/authorize_user.go @@ -1,4 +1,4 @@ -package dto +package system import ( "github.com/google/uuid" @@ -16,3 +16,18 @@ type AuthorizeUser struct { IP string `json:"ip"` Browser string `json:"browser"` } + +func NewAuthorizeUser(user *User, ip, os, browser string) AuthorizeUser { + return AuthorizeUser{ + ID: user.ID, + Uuid: user.Uuid, + Email: user.Email, + Username: user.Username, + Avatar: user.Avatar, + RoleID: user.Role.ID, + RoleName: user.Role.DisplayName, + OS: os, + IP: ip, + Browser: browser, + } +} diff --git a/internal/erpserver/model/system/login_log.go b/internal/erpserver/model/system/login_log.go index e788002..21ed7b6 100644 --- a/internal/erpserver/model/system/login_log.go +++ b/internal/erpserver/model/system/login_log.go @@ -12,6 +12,7 @@ type LoginLogRepository interface { GetLatest(ctx context.Context, email string) ([]*LoginLog, error) Count(ctx context.Context, filter dto.SearchDto) (int64, error) List(ctx context.Context, filter dto.SearchDto) ([]*LoginLog, error) + HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*LoginLog, error) } type LoginLog struct { diff --git a/internal/erpserver/repository/system/loginlog/login_log.go b/internal/erpserver/repository/system/loginlog/login_log.go index 3e1cd63..50c4d34 100644 --- a/internal/erpserver/repository/system/loginlog/login_log.go +++ b/internal/erpserver/repository/system/loginlog/login_log.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "fmt" + "strings" + "time" "management/internal/erpserver/model/dto" "management/internal/erpserver/model/system" @@ -114,6 +116,46 @@ func (s *store) List(ctx context.Context, filter dto.SearchDto) ([]*system.Login return toPointer(logs), nil } +func (s *store) HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*system.LoginLog, error) { + //goland:noinspection ALL + const q = ` + SELECT + id, created_at, email, is_success, message, referer_url, url, os, ip, browser + FROM + sys_user_login_log` + + data := map[string]any{ + "rows_per_page": 20, + } + + buf := bytes.NewBufferString(q) + + var wc []string + + data["start_at"] = createdAt.Format(time.DateTime) + wc = append(wc, "created_at > :start_at") + + if email != "" { + data["email"] = email + wc = append(wc, "email = :email") + } + + if len(wc) > 0 { + buf.WriteString(" WHERE ") + buf.WriteString(strings.Join(wc, " AND ")) + } + + buf.WriteString(" ORDER BY created_at DESC ") + buf.WriteString(" LIMIT :rows_per_page") + + var logs []system.LoginLog + err := sqldb.NamedQuerySlice(ctx, s.log, s.db.DB(ctx), buf.String(), data, &logs) + if err != nil { + return nil, err + } + return toPointer(logs), nil +} + func toPointer(data []system.LoginLog) []*system.LoginLog { var res []*system.LoginLog for _, v := range data { diff --git a/internal/erpserver/service/v1/auth/auth.go b/internal/erpserver/service/v1/auth/auth.go new file mode 100644 index 0000000..6e35bb7 --- /dev/null +++ b/internal/erpserver/service/v1/auth/auth.go @@ -0,0 +1,571 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + "management/internal/erpserver/model/form" + "management/internal/erpserver/model/system" + v1 "management/internal/erpserver/service/v1" + "management/internal/pkg/crypto" + "management/internal/pkg/know" + "management/internal/pkg/session" + + "github.com/drhin/logger" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// 安全配置常量 +const ( + // MaxLoginAttempts 最大登录尝试次数 + MaxLoginAttempts = 5 + // LoginLockoutDuration 锁定时间 + LoginLockoutDuration = 30 * time.Minute + // RiskCheckWindow 风险检查时间窗口 + RiskCheckWindow = 24 * time.Hour + // SessionTimeout 会话超时时间 + SessionTimeout = 2 * time.Hour +) + +// Redis键前缀 +const ( + LoginAttemptsPrefix = "login_attempts:" + LoginLockPrefix = "login_lock:" + SessionPrefix = "session:" + RiskSessionPrefix = "risk_session:" +) + +// RiskLevel 风险等级 +type RiskLevel string + +const ( + RiskLow RiskLevel = "low" + RiskMedium RiskLevel = "medium" + RiskHigh RiskLevel = "high" +) + +// RiskCheckResult 风险检查结果 +type RiskCheckResult struct { + IsRisky bool `json:"is_risky"` + RiskLevel RiskLevel `json:"risk_level"` + RiskReasons []string `json:"risk_reasons"` + RequiresMFA bool `json:"requires_mfa"` +} + +// LoginEnvironment 登录环境信息 +type LoginEnvironment struct { + IP string `json:"ip"` + OS string `json:"os"` + Browser string `json:"browser"` + Location string `json:"location,omitempty"` +} + +// Auth 安全管理器 +type Auth struct { + log *logger.Logger + redis *redis.Client + sm session.Manager + + userService v1.UserService + roleService v1.RoleService + loginLogService v1.LoginLogService +} + +// NewAuth 创建安全管理器 +func NewAuth( + log *logger.Logger, + redis *redis.Client, + sm session.Manager, + userService v1.UserService, + roleService v1.RoleService, + loginLogService v1.LoginLogService, +) *Auth { + return &Auth{ + log: log, + redis: redis, + sm: sm, + userService: userService, + roleService: roleService, + loginLogService: loginLogService, + } +} + +func (a *Auth) Authenticate(ctx context.Context, req form.Login) (*RiskCheckResult, error) { + l := system.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer) + + locked, duration, err := a.isAccountLocked(ctx, req.Email) + if err != nil { + return nil, err + } + + if locked { + return nil, fmt.Errorf("账户已被锁定,请在 %v 后重试", duration.Round(time.Minute)) + } + + user, err := a.validateUser(ctx, req.Email, req.Password) + if err != nil { + // 记录登录失败 + err = a.recordLoginFailure(ctx, l.SetMessage("用户名或密码错误")) + if err != nil { + a.log.Error(err.Error(), err, zap.Any("login_log", l)) + } + + // 获取剩余尝试次数 + remaining, err := a.getRemainingAttempts(ctx, req.Email) + if err == nil && remaining > 0 { + return nil, fmt.Errorf("用户名或密码错误,还有 %d 次尝试机会", remaining) + } + + return nil, errors.New("用户名或密码错误") + } + + if err := a.successfulLogin(ctx, user.Uuid, l.SetOk("校验成功")); err != nil { + return nil, err + } + + // 获取风险评估结果 + risk, err := a.GetUserRisk(ctx, user.Uuid.String()) + if err != nil { + return nil, err + } + + // 如果存在风险,在响应中包含风险信息 + if risk != nil && risk.IsRisky { + if risk.RequiresMFA { + risk.RiskReasons = append(risk.RiskReasons, "检测到异常登录环境,建议启用多重身份验证") + } else { + risk.RiskReasons = append(risk.RiskReasons, "登录成功,检测到新的登录环境") + } + } + + // 设置会话Cookie + au := system.NewAuthorizeUser(user, req.Os, req.Ip, req.Browser) + if err := a.sm.PutUser(ctx, know.StoreName, au); err != nil { + return nil, err + } + + return risk, nil +} + +func (a *Auth) validateUser(ctx context.Context, email, password string) (*system.User, error) { + user, err := a.userService.GetByEmail(ctx, email) + if err != nil { + return nil, err + } + + if err := crypto.BcryptComparePassword(user.HashedPassword, password+user.Salt); err != nil { + return nil, errors.New("账号或密码错误") + } + + user.Role, err = a.roleService.Get(ctx, user.RoleID) + if err != nil { + return nil, err + } + + if user.Role == nil || user.Role.ID == 0 { + return nil, errors.New("账号没有配置角色, 请联系管理员") + } + + return user, nil +} + +// 1. 登录失败次数限制功能 + +// isAccountLocked 检查是否被锁定 +func (a *Auth) isAccountLocked(ctx context.Context, email string) (bool, time.Duration, error) { + lockKey := LoginLockPrefix + email + + // 检查是否存在锁定记录 + ttl, err := a.redis.TTL(ctx, lockKey).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return false, 0, nil // 没有锁定记录 + } + return false, 0, fmt.Errorf("检查锁定状态失败: %w", err) + } + + if ttl > 0 { + return true, ttl, nil // 仍在锁定期间 + } + + return false, 0, nil +} + +// recordLoginFailure 记录登录失败 +func (a *Auth) recordLoginFailure(ctx context.Context, log *system.LoginLog) error { + attemptsKey := LoginAttemptsPrefix + log.Email + + // 增加失败次数 + attempts, err := a.redis.Incr(ctx, attemptsKey).Result() + if err != nil { + return fmt.Errorf("记录登录失败次数失败: %w", err) + } + + // 设置过期时间(首次失败时) + if attempts == 1 { + a.redis.Expire(ctx, attemptsKey, LoginLockoutDuration) + } + + // 记录登录日志 + if err := a.loginLogService.Create(ctx, log); err != nil { + // 日志记录失败不影响主流程,只记录错误 + fmt.Printf("记录登录日志失败: %v\n", err) + } + + // 如果达到最大尝试次数,锁定账户 + if attempts >= MaxLoginAttempts { + lockKey := LoginLockPrefix + log.Email + err := a.redis.Set(ctx, lockKey, "locked", LoginLockoutDuration).Err() + if err != nil { + return fmt.Errorf("锁定账户失败: %w", err) + } + + // 发送安全警告通知(这里可以扩展为邮件通知等) + go a.sendSecurityAlert(log.Email, "账户被锁定", fmt.Sprintf("由于连续%d次登录失败,您的账户已被锁定%v", MaxLoginAttempts, LoginLockoutDuration)) + } + + return nil +} + +// clearLoginFailures 清除登录失败记录 +func (a *Auth) clearLoginFailures(ctx context.Context, email string) error { + attemptsKey := LoginAttemptsPrefix + email + return a.redis.Del(ctx, attemptsKey).Err() +} + +// getRemainingAttempts 获取剩余登录尝试次数 +func (a *Auth) getRemainingAttempts(ctx context.Context, email string) (int, error) { + attemptsKey := LoginAttemptsPrefix + email + attempts, err := a.redis.Get(ctx, attemptsKey).Int() + if err != nil { + if errors.Is(err, redis.Nil) { + return MaxLoginAttempts, nil // 没有失败记录 + } + return 0, fmt.Errorf("获取登录尝试次数失败: %w", err) + } + + remaining := MaxLoginAttempts - attempts + if remaining < 0 { + remaining = 0 + } + + return remaining, nil +} + +// 2. 登录风险检测功能 + +// checkLoginRisk 检查登录风险 +func (a *Auth) checkLoginRisk(ctx context.Context, email, ip, os, browser string) (*RiskCheckResult, error) { + result := &RiskCheckResult{ + IsRisky: false, + RiskLevel: RiskLow, + RiskReasons: []string{}, + RequiresMFA: false, + } + + // 获取用户历史登录环境 + historicalEnvs, err := a.getHistoricalLoginEnvironments(ctx, email) + if err != nil { + return nil, fmt.Errorf("获取历史登录环境失败: %w", err) + } + + // 如果是首次登录,风险较低 + if len(historicalEnvs) == 0 { + result.RiskLevel = RiskMedium + result.RiskReasons = append(result.RiskReasons, "首次登录") + return result, nil + } + + // 检查IP风险 + if a.isNewIP(ip, historicalEnvs) { + result.IsRisky = true + result.RiskReasons = append(result.RiskReasons, "检测到新的IP地址") + + // 检查IP地理位置变化 + if a.isSignificantLocationChange(ip, historicalEnvs) { + result.RiskLevel = RiskHigh + result.RiskReasons = append(result.RiskReasons, "IP地理位置发生重大变化") + result.RequiresMFA = true + } else { + result.RiskLevel = RiskMedium + } + } + + // 检查浏览器/设备风险 + if a.isNewBrowser(browser, historicalEnvs) { + result.IsRisky = true + result.RiskReasons = append(result.RiskReasons, "检测到新的浏览器或设备") + if result.RiskLevel == RiskLow { + result.RiskLevel = RiskMedium + } + } + + // 检查操作系统风险 + if a.isNewOS(os, historicalEnvs) { + result.IsRisky = true + result.RiskReasons = append(result.RiskReasons, "检测到新的操作系统") + if result.RiskLevel == RiskLow { + result.RiskLevel = RiskMedium + } + } + + // 如果存在多个风险因素,提升风险等级 + if len(result.RiskReasons) >= 2 { + result.RiskLevel = RiskHigh + result.RequiresMFA = true + } + + return result, nil +} + +// successfulLogin 登录成功处理 +func (a *Auth) successfulLogin(ctx context.Context, uuid uuid.UUID, log *system.LoginLog) error { + // 清除登录失败记录 + if err := a.clearLoginFailures(ctx, log.Email); err != nil { + return fmt.Errorf("清除登录失败记录失败: %w", err) + } + + // 记录登录成功日志 + if err := a.recordLoginLog(ctx, log); err != nil { + fmt.Printf("记录登录日志失败: %v\n", err) + } + + // 检查登录风险 + riskResult, err := a.checkLoginRisk(ctx, log.Email, log.Ip, log.Os, log.Browser) + if err != nil { + fmt.Printf("风险检查失败: %v\n", err) + return nil // 风险检查失败不影响正常登录 + } + + // 如果存在风险,记录风险会话 + if riskResult.IsRisky { + riskSessionKey := RiskSessionPrefix + uuid.String() + riskData, err := json.Marshal(riskResult) + if err != nil { + return err + } + + if err := a.redis.Set(ctx, riskSessionKey, riskData, SessionTimeout).Err(); err != nil { + return err + } + + // 发送安全提醒 + go a.sendSecurityAlert(log.Email, "检测到异常登录", + fmt.Sprintf("风险等级: %s, 风险原因: %s", riskResult.RiskLevel, strings.Join(riskResult.RiskReasons, ", "))) + } + + return nil +} + +// 3. 会话安全检查 + +// ValidateSession 验证会话并检查安全性 +func (a *Auth) ValidateSession(ctx context.Context, sessionID string, currentEnv LoginEnvironment) (*system.AuthorizeUser, *RiskCheckResult, error) { + sessionKey := SessionPrefix + sessionID + + // 获取会话数据 + sessionData, err := a.redis.Get(ctx, sessionKey).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil, fmt.Errorf("会话不存在或已过期") + } + return nil, nil, fmt.Errorf("获取会话数据失败: %w", err) + } + + var user system.AuthorizeUser + if err := json.Unmarshal([]byte(sessionData), &user); err != nil { + return nil, nil, fmt.Errorf("解析会话数据失败: %w", err) + } + + // 检查会话环境是否发生变化 + var riskResult *RiskCheckResult + if a.hasEnvironmentChanged(user, currentEnv) { + // 重新进行风险评估 + //riskResult, err = a.checkLoginRisk(ctx, user.Email, currentEnv) + //if err != nil { + // fmt.Printf("会话风险检查失败: %v\n", err) + //} else if riskResult.IsRisky { + // // 如果风险等级很高,可能需要重新验证 + // if riskResult.RiskLevel == RiskHigh { + // // 可以选择强制重新登录或要求额外验证 + // go a.sendSecurityAlert(user.Email, "会话环境异常", + // fmt.Sprintf("检测到会话环境发生变化: %s", strings.Join(riskResult.RiskReasons, ", "))) + // } + // + // // 更新风险会话信息 + // riskSessionKey := RiskSessionPrefix + sessionID + // riskData, _ := json.Marshal(riskResult) + // a.redis.Set(ctx, riskSessionKey, riskData, SessionTimeout) + //} + + // 更新用户环境信息 + user.IP = currentEnv.IP + user.OS = currentEnv.OS + user.Browser = currentEnv.Browser + + // 更新会话数据 + updatedSessionData, _ := json.Marshal(user) + a.redis.Set(ctx, sessionKey, updatedSessionData, SessionTimeout) + } + + // 延长会话有效期 + a.redis.Expire(ctx, sessionKey, SessionTimeout) + + return &user, riskResult, nil +} + +// 4. 辅助功能 + +// 从HTTP请求中提取登录环境信息 +//func ExtractLoginEnvironment(r *http.Request) LoginEnvironment { +// // 获取真实IP +// ip := getRealIP(r) +// +// // 解析User-Agent +// ua := useragent.Parse(r.UserAgent()) +// +// return LoginEnvironment{ +// IP: ip, +// OS: fmt.Sprintf("%s %s", ua.OS, ua.OSVersion), +// Browser: fmt.Sprintf("%s %s", ua.Name, ua.Version), +// } +//} + +// 获取真实IP地址 +func getRealIP(r *http.Request) string { + // 检查代理头 + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return strings.Split(ip, ",")[0] + } + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Forwarded-Proto"); ip != "" { + return ip + } + + // 从RemoteAddr中提取IP + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +// recordLoginLog 记录登录日志 +func (a *Auth) recordLoginLog(ctx context.Context, log *system.LoginLog) error { + return a.loginLogService.Create(ctx, log) +} + +// getHistoricalLoginEnvironments 获取历史登录环境 +func (a *Auth) getHistoricalLoginEnvironments(ctx context.Context, email string) ([]LoginEnvironment, error) { + rows, err := a.loginLogService.HistoricalLogin(ctx, email, time.Now().Add(-RiskCheckWindow)) + if err != nil { + return nil, err + } + + var envs []LoginEnvironment + for _, item := range rows { + envs = append(envs, LoginEnvironment{ + IP: item.Ip, + OS: item.Os, + Browser: item.Browser, + Location: "", + }) + } + + return envs, nil +} + +// isNewIP 检查是否为新IP +func (a *Auth) isNewIP(currentIP string, historicalEnvs []LoginEnvironment) bool { + for _, env := range historicalEnvs { + if env.IP == currentIP { + return false + } + } + return true +} + +// isNewBrowser 检查是否为新浏览器 +func (a *Auth) isNewBrowser(currentBrowser string, historicalEnvs []LoginEnvironment) bool { + for _, env := range historicalEnvs { + if env.Browser == currentBrowser { + return false + } + } + return true +} + +// isNewOS 检查是否为新操作系统 +func (a *Auth) isNewOS(currentOS string, historicalEnvs []LoginEnvironment) bool { + for _, env := range historicalEnvs { + if env.OS == currentOS { + return false + } + } + return true +} + +// isSignificantLocationChange 检查IP地理位置是否发生重大变化(简化实现) +func (a *Auth) isSignificantLocationChange(currentIP string, historicalEnvs []LoginEnvironment) bool { + // 这里可以集成IP地理位置服务,比如使用MaxMind GeoIP2 + // 简化实现:检查IP前两段是否相同 + currentSegments := strings.Split(currentIP, ".") + if len(currentSegments) < 2 { + return true + } + currentPrefix := currentSegments[0] + "." + currentSegments[1] + + for _, env := range historicalEnvs { + segments := strings.Split(env.IP, ".") + if len(segments) >= 2 { + prefix := segments[0] + "." + segments[1] + if prefix == currentPrefix { + return false // 找到相同网段的历史IP + } + } + } + + return true // 所有历史IP都与当前IP不在同一网段 +} + +// hasEnvironmentChanged 检查会话环境是否发生变化 +func (a *Auth) hasEnvironmentChanged(user system.AuthorizeUser, currentEnv LoginEnvironment) bool { + return user.IP != currentEnv.IP || + user.OS != currentEnv.OS || + user.Browser != currentEnv.Browser +} + +// sendSecurityAlert 发送安全警报(简化实现) +func (a *Auth) sendSecurityAlert(email, subject, message string) { + // 这里可以实现邮件发送、短信通知等 + fmt.Printf("安全警报 - 用户: %s, 主题: %s, 消息: %s\n", email, subject, message) +} + +// GetUserRisk 获取会话风险信息 +func (a *Auth) GetUserRisk(ctx context.Context, userID string) (*RiskCheckResult, error) { + riskSessionKey := RiskSessionPrefix + userID + riskData, err := a.redis.Get(ctx, riskSessionKey).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil // 没有风险记录 + } + return nil, fmt.Errorf("获取风险信息失败: %w", err) + } + + var result RiskCheckResult + if err := json.Unmarshal([]byte(riskData), &result); err != nil { + return nil, fmt.Errorf("解析风险信息失败: %w", err) + } + + return &result, nil +} diff --git a/internal/erpserver/service/v1/service.go b/internal/erpserver/service/v1/service.go index 4cce611..1e58cb7 100644 --- a/internal/erpserver/service/v1/service.go +++ b/internal/erpserver/service/v1/service.go @@ -2,6 +2,7 @@ package v1 import ( "context" + "time" "management/internal/erpserver/model/dto" "management/internal/erpserver/model/form" @@ -53,6 +54,7 @@ type UserService interface { All(ctx context.Context) ([]*system.User, error) List(ctx context.Context, q dto.SearchDto) ([]*system.User, int64, error) Get(ctx context.Context, id int32) (*system.User, error) + GetByEmail(ctx context.Context, email string) (*system.User, error) XmSelect(ctx context.Context) ([]*view.XmSelect, error) Login(ctx context.Context, req *form.Login) error @@ -64,6 +66,8 @@ type LoginLogService interface { LoginTime(ctx context.Context, email string) (dto.LoginTimeDto, error) LoginCount(ctx context.Context, email string) int64 + + HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*system.LoginLog, error) } type AuditLogService interface { diff --git a/internal/erpserver/service/v1/system/login_log.go b/internal/erpserver/service/v1/system/login_log.go index 78191b6..b3ef144 100644 --- a/internal/erpserver/service/v1/system/login_log.go +++ b/internal/erpserver/service/v1/system/login_log.go @@ -2,6 +2,7 @@ package system import ( "context" + "time" "management/internal/erpserver/model/dto" "management/internal/erpserver/model/system" @@ -63,3 +64,11 @@ func (s *loginLogService) LoginCount(ctx context.Context, email string) int64 { } return count } + +func (s *loginLogService) HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*system.LoginLog, error) { + logs, err := s.repo.HistoricalLogin(ctx, email, createdAt) + if err != nil { + return nil, err + } + return logs, nil +} diff --git a/internal/erpserver/service/v1/system/user.go b/internal/erpserver/service/v1/system/user.go index ff46cc4..826cea7 100644 --- a/internal/erpserver/service/v1/system/user.go +++ b/internal/erpserver/service/v1/system/user.go @@ -136,6 +136,10 @@ func (s *userService) Get(ctx context.Context, id int32) (*system.User, error) { return s.repo.Get(ctx, id) } +func (s *userService) GetByEmail(ctx context.Context, email string) (*system.User, error) { + return s.repo.GetByEmail(ctx, email) +} + func (s *userService) XmSelect(ctx context.Context) ([]*view.XmSelect, error) { all, err := s.repo.All(ctx) if err != nil || len(all) == 0 { @@ -154,7 +158,7 @@ func (s *userService) XmSelect(ctx context.Context) ([]*view.XmSelect, error) { func (s *userService) Login(ctx context.Context, req *form.Login) error { l := system.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer) - err := s.login(ctx, req) + _, err := s.login(ctx, req) if err != nil { if err := s.loginLogService.Create(ctx, l.SetMessage(err.Error())); err != nil { s.Log.Error(err.Error(), err, zap.Any("login_log", l)) @@ -168,36 +172,36 @@ func (s *userService) Login(ctx context.Context, req *form.Login) error { return nil } -func (s *userService) login(ctx context.Context, req *form.Login) error { +func (s *userService) login(ctx context.Context, req *form.Login) (*system.User, error) { user, err := s.repo.GetByEmail(ctx, req.Email) if err != nil { - return err + return nil, err } err = crypto.BcryptComparePassword(user.HashedPassword, req.Password+user.Salt) if err != nil { - return errors.New("账号或密码错误") + return nil, errors.New("账号或密码错误") } user.Role, err = s.roleService.Get(ctx, user.RoleID) if err != nil { - return err + return nil, err } if user.Role == nil || user.Role.ID == 0 { - return errors.New("账号没有配置角色, 请联系管理员") + return nil, errors.New("账号没有配置角色, 请联系管理员") } // 登陆成功 err = s.loginSuccess(ctx, user, req) if err != nil { - return err + return nil, err } - return nil + return user, nil } func (s *userService) loginSuccess(ctx context.Context, user *system.User, req *form.Login) error { - return s.Session.PutUser(ctx, know.StoreName, dto.AuthorizeUser{ + return s.Session.PutUser(ctx, know.StoreName, system.AuthorizeUser{ ID: user.ID, Uuid: user.Uuid, Email: user.Email, diff --git a/internal/erpserver/templ/auth/login.templ b/internal/erpserver/templ/auth/login.templ index 7c864f7..780e62e 100644 --- a/internal/erpserver/templ/auth/login.templ +++ b/internal/erpserver/templ/auth/login.templ @@ -130,6 +130,7 @@ templ Login(ctx context.Context) { }); }); } else { + $('#captcha').click(); loading.stop(function () { popup.failure(obj.msg); }); diff --git a/internal/erpserver/templ/auth/login_templ.go b/internal/erpserver/templ/auth/login_templ.go index 8bb794a..20cb16d 100644 --- a/internal/erpserver/templ/auth/login_templ.go +++ b/internal/erpserver/templ/auth/login_templ.go @@ -48,7 +48,7 @@ func Login(ctx context.Context) templ.Component { if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } - templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "';\n \t\t\t\t$.ajax({\n \t\t\t\t\turl: '/login',\n \t\t\t\t\ttype: 'post',\n \t\t\t\t\tdataType: 'json',\n \t\t\t\t\tdata: data.field,\n \t\t\t\t\tsuccess: function (obj) {\n \t\t\t\t\t\tif (obj.success) {\n \t\t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\t\tpopup.success(\"登录成功\", function () {\n \t\t\t\t\t\t\t\t\tlocation.href = \"/home.html\"\n \t\t\t\t\t\t\t\t});\n \t\t\t\t\t\t\t});\n \t\t\t\t\t\t} else {\n \t\t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\t\tpopup.failure(obj.msg);\n \t\t\t\t\t\t\t});\n \t\t\t\t\t\t}\n \t\t\t\t\t},\n \t\t\t\t\terror: function (ex) {\n \t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\tpopup.failure('网络异常,请刷新重试');\n \t\t\t\t\t\t});\n \t\t\t\t\t}\n \t\t\t\t});\n \t\t\t\treturn false;\n \t\t\t});\n \t\t})\n ") + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "';\n \t\t\t\t$.ajax({\n \t\t\t\t\turl: '/login',\n \t\t\t\t\ttype: 'post',\n \t\t\t\t\tdataType: 'json',\n \t\t\t\t\tdata: data.field,\n \t\t\t\t\tsuccess: function (obj) {\n \t\t\t\t\t\tif (obj.success) {\n \t\t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\t\tpopup.success(\"登录成功\", function () {\n \t\t\t\t\t\t\t\t\tlocation.href = \"/home.html\"\n \t\t\t\t\t\t\t\t});\n \t\t\t\t\t\t\t});\n \t\t\t\t\t\t} else {\n \t\t\t\t\t\t\t$('#captcha').click();\n \t\t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\t\tpopup.failure(obj.msg);\n \t\t\t\t\t\t\t});\n \t\t\t\t\t\t}\n \t\t\t\t\t},\n \t\t\t\t\terror: function (ex) {\n \t\t\t\t\t\tloading.stop(function () {\n \t\t\t\t\t\t\tpopup.failure('网络异常,请刷新重试');\n \t\t\t\t\t\t});\n \t\t\t\t\t}\n \t\t\t\t});\n \t\t\t\treturn false;\n \t\t\t});\n \t\t})\n ") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } diff --git a/internal/pkg/mid/audit_v3.go b/internal/pkg/mid/audit_v3.go index b38e4e3..9850e89 100644 --- a/internal/pkg/mid/audit_v3.go +++ b/internal/pkg/mid/audit_v3.go @@ -45,7 +45,7 @@ func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor) c, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - + if err := task.DistributeTaskConsumeAuditLog(c, payload, opts...); err != nil { log.Error("distribute task failed", err, zap.String("type", "audit"), diff --git a/internal/pkg/mid/mid.go b/internal/pkg/mid/mid.go index e09f30c..1244340 100644 --- a/internal/pkg/mid/mid.go +++ b/internal/pkg/mid/mid.go @@ -5,6 +5,7 @@ import ( "errors" "management/internal/erpserver/model/dto" + "management/internal/erpserver/model/system" "management/internal/pkg/sqldb" "github.com/a-h/templ" @@ -12,15 +13,15 @@ import ( type userKey struct{} -func setUser(ctx context.Context, usr dto.AuthorizeUser) context.Context { +func setUser(ctx context.Context, usr system.AuthorizeUser) context.Context { return context.WithValue(ctx, userKey{}, usr) } // GetUser returns the user from the context. -func GetUser(ctx context.Context) dto.AuthorizeUser { - v, ok := ctx.Value(userKey{}).(dto.AuthorizeUser) +func GetUser(ctx context.Context) system.AuthorizeUser { + v, ok := ctx.Value(userKey{}).(system.AuthorizeUser) if !ok { - return dto.AuthorizeUser{} + return system.AuthorizeUser{} } return v diff --git a/internal/pkg/session/session.go b/internal/pkg/session/session.go index 68bb118..77b8140 100644 --- a/internal/pkg/session/session.go +++ b/internal/pkg/session/session.go @@ -7,7 +7,7 @@ import ( "net/http" "time" - "management/internal/erpserver/model/dto" + "management/internal/erpserver/model/system" "management/internal/pkg/config" "github.com/alexedwards/scs/v2" @@ -19,8 +19,8 @@ var ErrNoSession = errors.New("session user not found") // Manager 抽象核心会话操作 type Manager interface { Load(next http.Handler) http.Handler - GetUser(ctx context.Context, key string) (dto.AuthorizeUser, error) - PutUser(ctx context.Context, key string, user dto.AuthorizeUser) error + GetUser(ctx context.Context, key string) (system.AuthorizeUser, error) + PutUser(ctx context.Context, key string, user system.AuthorizeUser) error RenewToken(ctx context.Context) error Destroy(ctx context.Context) error } @@ -60,20 +60,20 @@ func (s *SCSSession) Load(next http.Handler) http.Handler { return s.manager.LoadAndSave(next) } -func (s *SCSSession) GetUser(ctx context.Context, key string) (dto.AuthorizeUser, error) { +func (s *SCSSession) GetUser(ctx context.Context, key string) (system.AuthorizeUser, error) { data, ok := s.manager.Get(ctx, key).([]byte) if !ok || len(data) == 0 { - return dto.AuthorizeUser{}, ErrNoSession + return system.AuthorizeUser{}, ErrNoSession } - var user dto.AuthorizeUser + var user system.AuthorizeUser if err := json.Unmarshal(data, &user); err != nil { - return dto.AuthorizeUser{}, err + return system.AuthorizeUser{}, err } return user, nil } -func (s *SCSSession) PutUser(ctx context.Context, key string, user dto.AuthorizeUser) error { +func (s *SCSSession) PutUser(ctx context.Context, key string, user system.AuthorizeUser) error { data, err := json.Marshal(&user) if err != nil { return err