This commit is contained in:
2025-06-25 16:11:03 +08:00
parent b48d14a6fb
commit 4186cd0caf
16 changed files with 690 additions and 26 deletions

View File

@@ -0,0 +1,571 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"management/internal/erpserver/model/form"
"management/internal/erpserver/model/system"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/crypto"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/drhin/logger"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// 安全配置常量
const (
// MaxLoginAttempts 最大登录尝试次数
MaxLoginAttempts = 5
// LoginLockoutDuration 锁定时间
LoginLockoutDuration = 30 * time.Minute
// RiskCheckWindow 风险检查时间窗口
RiskCheckWindow = 24 * time.Hour
// SessionTimeout 会话超时时间
SessionTimeout = 2 * time.Hour
)
// Redis键前缀
const (
LoginAttemptsPrefix = "login_attempts:"
LoginLockPrefix = "login_lock:"
SessionPrefix = "session:"
RiskSessionPrefix = "risk_session:"
)
// RiskLevel 风险等级
type RiskLevel string
const (
RiskLow RiskLevel = "low"
RiskMedium RiskLevel = "medium"
RiskHigh RiskLevel = "high"
)
// RiskCheckResult 风险检查结果
type RiskCheckResult struct {
IsRisky bool `json:"is_risky"`
RiskLevel RiskLevel `json:"risk_level"`
RiskReasons []string `json:"risk_reasons"`
RequiresMFA bool `json:"requires_mfa"`
}
// LoginEnvironment 登录环境信息
type LoginEnvironment struct {
IP string `json:"ip"`
OS string `json:"os"`
Browser string `json:"browser"`
Location string `json:"location,omitempty"`
}
// Auth 安全管理器
type Auth struct {
log *logger.Logger
redis *redis.Client
sm session.Manager
userService v1.UserService
roleService v1.RoleService
loginLogService v1.LoginLogService
}
// NewAuth 创建安全管理器
func NewAuth(
log *logger.Logger,
redis *redis.Client,
sm session.Manager,
userService v1.UserService,
roleService v1.RoleService,
loginLogService v1.LoginLogService,
) *Auth {
return &Auth{
log: log,
redis: redis,
sm: sm,
userService: userService,
roleService: roleService,
loginLogService: loginLogService,
}
}
func (a *Auth) Authenticate(ctx context.Context, req form.Login) (*RiskCheckResult, error) {
l := system.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer)
locked, duration, err := a.isAccountLocked(ctx, req.Email)
if err != nil {
return nil, err
}
if locked {
return nil, fmt.Errorf("账户已被锁定,请在 %v 后重试", duration.Round(time.Minute))
}
user, err := a.validateUser(ctx, req.Email, req.Password)
if err != nil {
// 记录登录失败
err = a.recordLoginFailure(ctx, l.SetMessage("用户名或密码错误"))
if err != nil {
a.log.Error(err.Error(), err, zap.Any("login_log", l))
}
// 获取剩余尝试次数
remaining, err := a.getRemainingAttempts(ctx, req.Email)
if err == nil && remaining > 0 {
return nil, fmt.Errorf("用户名或密码错误,还有 %d 次尝试机会", remaining)
}
return nil, errors.New("用户名或密码错误")
}
if err := a.successfulLogin(ctx, user.Uuid, l.SetOk("校验成功")); err != nil {
return nil, err
}
// 获取风险评估结果
risk, err := a.GetUserRisk(ctx, user.Uuid.String())
if err != nil {
return nil, err
}
// 如果存在风险,在响应中包含风险信息
if risk != nil && risk.IsRisky {
if risk.RequiresMFA {
risk.RiskReasons = append(risk.RiskReasons, "检测到异常登录环境,建议启用多重身份验证")
} else {
risk.RiskReasons = append(risk.RiskReasons, "登录成功,检测到新的登录环境")
}
}
// 设置会话Cookie
au := system.NewAuthorizeUser(user, req.Os, req.Ip, req.Browser)
if err := a.sm.PutUser(ctx, know.StoreName, au); err != nil {
return nil, err
}
return risk, nil
}
func (a *Auth) validateUser(ctx context.Context, email, password string) (*system.User, error) {
user, err := a.userService.GetByEmail(ctx, email)
if err != nil {
return nil, err
}
if err := crypto.BcryptComparePassword(user.HashedPassword, password+user.Salt); err != nil {
return nil, errors.New("账号或密码错误")
}
user.Role, err = a.roleService.Get(ctx, user.RoleID)
if err != nil {
return nil, err
}
if user.Role == nil || user.Role.ID == 0 {
return nil, errors.New("账号没有配置角色, 请联系管理员")
}
return user, nil
}
// 1. 登录失败次数限制功能
// isAccountLocked 检查是否被锁定
func (a *Auth) isAccountLocked(ctx context.Context, email string) (bool, time.Duration, error) {
lockKey := LoginLockPrefix + email
// 检查是否存在锁定记录
ttl, err := a.redis.TTL(ctx, lockKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return false, 0, nil // 没有锁定记录
}
return false, 0, fmt.Errorf("检查锁定状态失败: %w", err)
}
if ttl > 0 {
return true, ttl, nil // 仍在锁定期间
}
return false, 0, nil
}
// recordLoginFailure 记录登录失败
func (a *Auth) recordLoginFailure(ctx context.Context, log *system.LoginLog) error {
attemptsKey := LoginAttemptsPrefix + log.Email
// 增加失败次数
attempts, err := a.redis.Incr(ctx, attemptsKey).Result()
if err != nil {
return fmt.Errorf("记录登录失败次数失败: %w", err)
}
// 设置过期时间(首次失败时)
if attempts == 1 {
a.redis.Expire(ctx, attemptsKey, LoginLockoutDuration)
}
// 记录登录日志
if err := a.loginLogService.Create(ctx, log); err != nil {
// 日志记录失败不影响主流程,只记录错误
fmt.Printf("记录登录日志失败: %v\n", err)
}
// 如果达到最大尝试次数,锁定账户
if attempts >= MaxLoginAttempts {
lockKey := LoginLockPrefix + log.Email
err := a.redis.Set(ctx, lockKey, "locked", LoginLockoutDuration).Err()
if err != nil {
return fmt.Errorf("锁定账户失败: %w", err)
}
// 发送安全警告通知(这里可以扩展为邮件通知等)
go a.sendSecurityAlert(log.Email, "账户被锁定", fmt.Sprintf("由于连续%d次登录失败您的账户已被锁定%v", MaxLoginAttempts, LoginLockoutDuration))
}
return nil
}
// clearLoginFailures 清除登录失败记录
func (a *Auth) clearLoginFailures(ctx context.Context, email string) error {
attemptsKey := LoginAttemptsPrefix + email
return a.redis.Del(ctx, attemptsKey).Err()
}
// getRemainingAttempts 获取剩余登录尝试次数
func (a *Auth) getRemainingAttempts(ctx context.Context, email string) (int, error) {
attemptsKey := LoginAttemptsPrefix + email
attempts, err := a.redis.Get(ctx, attemptsKey).Int()
if err != nil {
if errors.Is(err, redis.Nil) {
return MaxLoginAttempts, nil // 没有失败记录
}
return 0, fmt.Errorf("获取登录尝试次数失败: %w", err)
}
remaining := MaxLoginAttempts - attempts
if remaining < 0 {
remaining = 0
}
return remaining, nil
}
// 2. 登录风险检测功能
// checkLoginRisk 检查登录风险
func (a *Auth) checkLoginRisk(ctx context.Context, email, ip, os, browser string) (*RiskCheckResult, error) {
result := &RiskCheckResult{
IsRisky: false,
RiskLevel: RiskLow,
RiskReasons: []string{},
RequiresMFA: false,
}
// 获取用户历史登录环境
historicalEnvs, err := a.getHistoricalLoginEnvironments(ctx, email)
if err != nil {
return nil, fmt.Errorf("获取历史登录环境失败: %w", err)
}
// 如果是首次登录,风险较低
if len(historicalEnvs) == 0 {
result.RiskLevel = RiskMedium
result.RiskReasons = append(result.RiskReasons, "首次登录")
return result, nil
}
// 检查IP风险
if a.isNewIP(ip, historicalEnvs) {
result.IsRisky = true
result.RiskReasons = append(result.RiskReasons, "检测到新的IP地址")
// 检查IP地理位置变化
if a.isSignificantLocationChange(ip, historicalEnvs) {
result.RiskLevel = RiskHigh
result.RiskReasons = append(result.RiskReasons, "IP地理位置发生重大变化")
result.RequiresMFA = true
} else {
result.RiskLevel = RiskMedium
}
}
// 检查浏览器/设备风险
if a.isNewBrowser(browser, historicalEnvs) {
result.IsRisky = true
result.RiskReasons = append(result.RiskReasons, "检测到新的浏览器或设备")
if result.RiskLevel == RiskLow {
result.RiskLevel = RiskMedium
}
}
// 检查操作系统风险
if a.isNewOS(os, historicalEnvs) {
result.IsRisky = true
result.RiskReasons = append(result.RiskReasons, "检测到新的操作系统")
if result.RiskLevel == RiskLow {
result.RiskLevel = RiskMedium
}
}
// 如果存在多个风险因素,提升风险等级
if len(result.RiskReasons) >= 2 {
result.RiskLevel = RiskHigh
result.RequiresMFA = true
}
return result, nil
}
// successfulLogin 登录成功处理
func (a *Auth) successfulLogin(ctx context.Context, uuid uuid.UUID, log *system.LoginLog) error {
// 清除登录失败记录
if err := a.clearLoginFailures(ctx, log.Email); err != nil {
return fmt.Errorf("清除登录失败记录失败: %w", err)
}
// 记录登录成功日志
if err := a.recordLoginLog(ctx, log); err != nil {
fmt.Printf("记录登录日志失败: %v\n", err)
}
// 检查登录风险
riskResult, err := a.checkLoginRisk(ctx, log.Email, log.Ip, log.Os, log.Browser)
if err != nil {
fmt.Printf("风险检查失败: %v\n", err)
return nil // 风险检查失败不影响正常登录
}
// 如果存在风险,记录风险会话
if riskResult.IsRisky {
riskSessionKey := RiskSessionPrefix + uuid.String()
riskData, err := json.Marshal(riskResult)
if err != nil {
return err
}
if err := a.redis.Set(ctx, riskSessionKey, riskData, SessionTimeout).Err(); err != nil {
return err
}
// 发送安全提醒
go a.sendSecurityAlert(log.Email, "检测到异常登录",
fmt.Sprintf("风险等级: %s, 风险原因: %s", riskResult.RiskLevel, strings.Join(riskResult.RiskReasons, ", ")))
}
return nil
}
// 3. 会话安全检查
// ValidateSession 验证会话并检查安全性
func (a *Auth) ValidateSession(ctx context.Context, sessionID string, currentEnv LoginEnvironment) (*system.AuthorizeUser, *RiskCheckResult, error) {
sessionKey := SessionPrefix + sessionID
// 获取会话数据
sessionData, err := a.redis.Get(ctx, sessionKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil, fmt.Errorf("会话不存在或已过期")
}
return nil, nil, fmt.Errorf("获取会话数据失败: %w", err)
}
var user system.AuthorizeUser
if err := json.Unmarshal([]byte(sessionData), &user); err != nil {
return nil, nil, fmt.Errorf("解析会话数据失败: %w", err)
}
// 检查会话环境是否发生变化
var riskResult *RiskCheckResult
if a.hasEnvironmentChanged(user, currentEnv) {
// 重新进行风险评估
//riskResult, err = a.checkLoginRisk(ctx, user.Email, currentEnv)
//if err != nil {
// fmt.Printf("会话风险检查失败: %v\n", err)
//} else if riskResult.IsRisky {
// // 如果风险等级很高,可能需要重新验证
// if riskResult.RiskLevel == RiskHigh {
// // 可以选择强制重新登录或要求额外验证
// go a.sendSecurityAlert(user.Email, "会话环境异常",
// fmt.Sprintf("检测到会话环境发生变化: %s", strings.Join(riskResult.RiskReasons, ", ")))
// }
//
// // 更新风险会话信息
// riskSessionKey := RiskSessionPrefix + sessionID
// riskData, _ := json.Marshal(riskResult)
// a.redis.Set(ctx, riskSessionKey, riskData, SessionTimeout)
//}
// 更新用户环境信息
user.IP = currentEnv.IP
user.OS = currentEnv.OS
user.Browser = currentEnv.Browser
// 更新会话数据
updatedSessionData, _ := json.Marshal(user)
a.redis.Set(ctx, sessionKey, updatedSessionData, SessionTimeout)
}
// 延长会话有效期
a.redis.Expire(ctx, sessionKey, SessionTimeout)
return &user, riskResult, nil
}
// 4. 辅助功能
// 从HTTP请求中提取登录环境信息
//func ExtractLoginEnvironment(r *http.Request) LoginEnvironment {
// // 获取真实IP
// ip := getRealIP(r)
//
// // 解析User-Agent
// ua := useragent.Parse(r.UserAgent())
//
// return LoginEnvironment{
// IP: ip,
// OS: fmt.Sprintf("%s %s", ua.OS, ua.OSVersion),
// Browser: fmt.Sprintf("%s %s", ua.Name, ua.Version),
// }
//}
// 获取真实IP地址
func getRealIP(r *http.Request) string {
// 检查代理头
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-Proto"); ip != "" {
return ip
}
// 从RemoteAddr中提取IP
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
// recordLoginLog 记录登录日志
func (a *Auth) recordLoginLog(ctx context.Context, log *system.LoginLog) error {
return a.loginLogService.Create(ctx, log)
}
// getHistoricalLoginEnvironments 获取历史登录环境
func (a *Auth) getHistoricalLoginEnvironments(ctx context.Context, email string) ([]LoginEnvironment, error) {
rows, err := a.loginLogService.HistoricalLogin(ctx, email, time.Now().Add(-RiskCheckWindow))
if err != nil {
return nil, err
}
var envs []LoginEnvironment
for _, item := range rows {
envs = append(envs, LoginEnvironment{
IP: item.Ip,
OS: item.Os,
Browser: item.Browser,
Location: "",
})
}
return envs, nil
}
// isNewIP 检查是否为新IP
func (a *Auth) isNewIP(currentIP string, historicalEnvs []LoginEnvironment) bool {
for _, env := range historicalEnvs {
if env.IP == currentIP {
return false
}
}
return true
}
// isNewBrowser 检查是否为新浏览器
func (a *Auth) isNewBrowser(currentBrowser string, historicalEnvs []LoginEnvironment) bool {
for _, env := range historicalEnvs {
if env.Browser == currentBrowser {
return false
}
}
return true
}
// isNewOS 检查是否为新操作系统
func (a *Auth) isNewOS(currentOS string, historicalEnvs []LoginEnvironment) bool {
for _, env := range historicalEnvs {
if env.OS == currentOS {
return false
}
}
return true
}
// isSignificantLocationChange 检查IP地理位置是否发生重大变化简化实现
func (a *Auth) isSignificantLocationChange(currentIP string, historicalEnvs []LoginEnvironment) bool {
// 这里可以集成IP地理位置服务比如使用MaxMind GeoIP2
// 简化实现检查IP前两段是否相同
currentSegments := strings.Split(currentIP, ".")
if len(currentSegments) < 2 {
return true
}
currentPrefix := currentSegments[0] + "." + currentSegments[1]
for _, env := range historicalEnvs {
segments := strings.Split(env.IP, ".")
if len(segments) >= 2 {
prefix := segments[0] + "." + segments[1]
if prefix == currentPrefix {
return false // 找到相同网段的历史IP
}
}
}
return true // 所有历史IP都与当前IP不在同一网段
}
// hasEnvironmentChanged 检查会话环境是否发生变化
func (a *Auth) hasEnvironmentChanged(user system.AuthorizeUser, currentEnv LoginEnvironment) bool {
return user.IP != currentEnv.IP ||
user.OS != currentEnv.OS ||
user.Browser != currentEnv.Browser
}
// sendSecurityAlert 发送安全警报(简化实现)
func (a *Auth) sendSecurityAlert(email, subject, message string) {
// 这里可以实现邮件发送、短信通知等
fmt.Printf("安全警报 - 用户: %s, 主题: %s, 消息: %s\n", email, subject, message)
}
// GetUserRisk 获取会话风险信息
func (a *Auth) GetUserRisk(ctx context.Context, userID string) (*RiskCheckResult, error) {
riskSessionKey := RiskSessionPrefix + userID
riskData, err := a.redis.Get(ctx, riskSessionKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil // 没有风险记录
}
return nil, fmt.Errorf("获取风险信息失败: %w", err)
}
var result RiskCheckResult
if err := json.Unmarshal([]byte(riskData), &result); err != nil {
return nil, fmt.Errorf("解析风险信息失败: %w", err)
}
return &result, nil
}

View File

@@ -2,6 +2,7 @@ package v1
import (
"context"
"time"
"management/internal/erpserver/model/dto"
"management/internal/erpserver/model/form"
@@ -53,6 +54,7 @@ type UserService interface {
All(ctx context.Context) ([]*system.User, error)
List(ctx context.Context, q dto.SearchDto) ([]*system.User, int64, error)
Get(ctx context.Context, id int32) (*system.User, error)
GetByEmail(ctx context.Context, email string) (*system.User, error)
XmSelect(ctx context.Context) ([]*view.XmSelect, error)
Login(ctx context.Context, req *form.Login) error
@@ -64,6 +66,8 @@ type LoginLogService interface {
LoginTime(ctx context.Context, email string) (dto.LoginTimeDto, error)
LoginCount(ctx context.Context, email string) int64
HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*system.LoginLog, error)
}
type AuditLogService interface {

View File

@@ -2,6 +2,7 @@ package system
import (
"context"
"time"
"management/internal/erpserver/model/dto"
"management/internal/erpserver/model/system"
@@ -63,3 +64,11 @@ func (s *loginLogService) LoginCount(ctx context.Context, email string) int64 {
}
return count
}
func (s *loginLogService) HistoricalLogin(ctx context.Context, email string, createdAt time.Time) ([]*system.LoginLog, error) {
logs, err := s.repo.HistoricalLogin(ctx, email, createdAt)
if err != nil {
return nil, err
}
return logs, nil
}

View File

@@ -136,6 +136,10 @@ func (s *userService) Get(ctx context.Context, id int32) (*system.User, error) {
return s.repo.Get(ctx, id)
}
func (s *userService) GetByEmail(ctx context.Context, email string) (*system.User, error) {
return s.repo.GetByEmail(ctx, email)
}
func (s *userService) XmSelect(ctx context.Context) ([]*view.XmSelect, error) {
all, err := s.repo.All(ctx)
if err != nil || len(all) == 0 {
@@ -154,7 +158,7 @@ func (s *userService) XmSelect(ctx context.Context) ([]*view.XmSelect, error) {
func (s *userService) Login(ctx context.Context, req *form.Login) error {
l := system.NewLoginLog(req.Email, req.Os, req.Ip, req.Browser, req.Url, req.Referrer)
err := s.login(ctx, req)
_, err := s.login(ctx, req)
if err != nil {
if err := s.loginLogService.Create(ctx, l.SetMessage(err.Error())); err != nil {
s.Log.Error(err.Error(), err, zap.Any("login_log", l))
@@ -168,36 +172,36 @@ func (s *userService) Login(ctx context.Context, req *form.Login) error {
return nil
}
func (s *userService) login(ctx context.Context, req *form.Login) error {
func (s *userService) login(ctx context.Context, req *form.Login) (*system.User, error) {
user, err := s.repo.GetByEmail(ctx, req.Email)
if err != nil {
return err
return nil, err
}
err = crypto.BcryptComparePassword(user.HashedPassword, req.Password+user.Salt)
if err != nil {
return errors.New("账号或密码错误")
return nil, errors.New("账号或密码错误")
}
user.Role, err = s.roleService.Get(ctx, user.RoleID)
if err != nil {
return err
return nil, err
}
if user.Role == nil || user.Role.ID == 0 {
return errors.New("账号没有配置角色, 请联系管理员")
return nil, errors.New("账号没有配置角色, 请联系管理员")
}
// 登陆成功
err = s.loginSuccess(ctx, user, req)
if err != nil {
return err
return nil, err
}
return nil
return user, nil
}
func (s *userService) loginSuccess(ctx context.Context, user *system.User, req *form.Login) error {
return s.Session.PutUser(ctx, know.StoreName, dto.AuthorizeUser{
return s.Session.PutUser(ctx, know.StoreName, system.AuthorizeUser{
ID: user.ID,
Uuid: user.Uuid,
Email: user.Email,