2025-03-21 11:05:42 +08:00

127 lines
2.7 KiB
Go

package redis
import (
	"bytes"
	"context"
	"encoding/gob"
	"errors"
	"fmt"
	"sort"
	"time"

	"github.com/redis/go-redis/v9"

	"management/internal/config"
)
// engine is the package-wide Redis client. It is assigned by Init and is
// used by the Redis helper functions in this package.
var engine *redis.Client

// ErrRedisKeyNotFound is the sentinel returned by Get and GetBytes when the
// requested key does not exist.
var ErrRedisKeyNotFound = errors.New("redis key not found")
// func GetRedis() *redis.Client {
// return rd
// }
func Init() error {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.File.Redis.Host, config.File.Redis.Port),
Password: config.File.Redis.Password,
DB: config.File.Redis.DB,
})
_, err := rdb.Ping(context.Background()).Result()
if err != nil {
return err
}
engine = rdb
return nil
}
// Encode serializes a with encoding/gob and returns the encoded bytes.
// On an encoding error it returns a nil slice together with the gob error.
func Encode(a any) ([]byte, error) {
	buf := new(bytes.Buffer)
	enc := gob.NewEncoder(buf)
	if err := enc.Encode(a); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// Set stores value under key with the given expiration and returns the error,
// if any, reported by the Redis client.
// NOTE(review): the meaning of a zero/negative expiration is defined by
// go-redis (presumably "no TTL" for 0) — confirm before relying on it.
func Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
	cmd := engine.Set(ctx, key, value, expiration)
	return cmd.Err()
}
// Del deletes the given keys and returns any error from the Redis client.
//
// Calling it with no keys is a no-op that returns nil. Redis's DEL command
// requires at least one key and would otherwise fail with an arity error, so
// no request is sent in that case.
func Del(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	return engine.Del(ctx, keys...).Err()
}
// Get returns the string value stored at key.
//
// It returns ErrRedisKeyNotFound when the key does not exist (go-redis
// reports that case as redis.Nil). Any other failure is wrapped with %w and
// the key, so callers can still inspect the underlying error with errors.Is /
// errors.As.
func Get(ctx context.Context, key string) (string, error) {
	val, err := engine.Get(ctx, key).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return "", ErrRedisKeyNotFound
	case err != nil:
		return "", fmt.Errorf("cannot get value with:[%s]: %w", key, err)
	}
	return val, nil
}
// GetBytes returns the raw bytes stored at key.
//
// It returns ErrRedisKeyNotFound when the key does not exist (go-redis
// reports that case as redis.Nil). Any other failure is wrapped with %w and
// the key, so callers can still inspect the underlying error with errors.Is /
// errors.As.
func GetBytes(ctx context.Context, key string) ([]byte, error) {
	val, err := engine.Get(ctx, key).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, ErrRedisKeyNotFound
	case err != nil:
		return nil, fmt.Errorf("cannot get value with:[%s]: %w", key, err)
	}
	return val, nil
}
// Scan performs a single SCAN step starting at cursor, filtered by match and
// using count as the COUNT argument. It returns the command unevaluated, so
// the caller reads the batch of keys and the next cursor via Result().
func Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
	cmd := engine.Scan(ctx, cursor, match, count)
	return cmd
}
// Keys returns every key matching pattern, using the Redis KEYS command.
// NOTE(review): KEYS examines the whole keyspace in one call; on large
// databases prefer Scan.
func Keys(ctx context.Context, pattern string) ([]string, error) {
	matched, err := engine.Keys(ctx, pattern).Result()
	return matched, err
}
// ListKeys returns one page of the keys matching pattern, together with the
// total number of matching keys.
//
// Keys are collected with the incremental SCAN command (pageSize is passed as
// the COUNT argument) rather than a blocking KEYS call. They are de-duplicated,
// because SCAN may return the same key more than once, and then sorted so
// that consecutive pages are stable and disjoint across calls.
//
// pageID is 1-based. If nothing matches, the result is (nil, 0, nil). A page
// past the end, a pageID < 1 or a pageSize <= 0 yields a nil page with the
// total count and no error. On a Redis error the keys and count are
// meaningless.
func ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error) {
	keys, err := scanAllKeys(ctx, pattern, int64(pageSize))
	if err != nil {
		return nil, 0, err
	}
	count := len(keys)
	if count == 0 {
		return nil, 0, nil
	}
	return pageOf(keys, pageID, pageSize), count, nil
}

// scanAllKeys walks the keyspace with SCAN and returns the unique keys
// matching pattern in ascending order.
func scanAllKeys(ctx context.Context, pattern string, countHint int64) ([]string, error) {
	seen := make(map[string]struct{})
	var cursor uint64
	for {
		batch, next, err := engine.Scan(ctx, cursor, pattern, countHint).Result()
		if err != nil {
			return nil, err
		}
		for _, k := range batch {
			seen[k] = struct{}{}
		}
		// A returned cursor of 0 means the full iteration has completed.
		if next == 0 {
			break
		}
		cursor = next
	}
	keys := make([]string, 0, len(seen))
	for k := range seen {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys, nil
}

// pageOf returns the 1-based page pageID of size pageSize from keys, or nil
// when the arguments are not positive or the page lies past the end.
func pageOf(keys []string, pageID, pageSize int) []string {
	if pageID < 1 || pageSize <= 0 {
		return nil
	}
	// Check the range before multiplying so a huge pageID cannot overflow int.
	if pageID-1 > len(keys)/pageSize {
		return nil
	}
	start := (pageID - 1) * pageSize
	if start >= len(keys) {
		return nil
	}
	end := len(keys)
	if pageSize < end-start {
		end = start + pageSize
	}
	return keys[start:end]
}