v3_1
This commit is contained in:
@@ -56,11 +56,22 @@ WHERE (NOT @is_status::Boolean OR status = @status)
|
||||
AND (@email::text = '' OR email ILIKE '%' || @email || '%');
|
||||
|
||||
-- name: ListSysUserCondition :many
|
||||
SELECT id, uuid, email, username, avatar, gender, department_id, role_id, status, change_password_at, created_at, updated_at,
|
||||
(SELECT name FROM sys_department WHERE ID = sys_user.department_id) AS department_name,
|
||||
(SELECT display_name
|
||||
FROM sys_role
|
||||
WHERE id = sys_user.role_id) AS role_name
|
||||
SELECT id,
|
||||
uuid,
|
||||
email,
|
||||
username,
|
||||
avatar,
|
||||
gender,
|
||||
department_id,
|
||||
role_id,
|
||||
status,
|
||||
change_password_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
COALESCE((SELECT name FROM sys_department WHERE ID = sys_user.department_id), '') AS department_name,
|
||||
COALESCE((SELECT display_name
|
||||
FROM sys_role
|
||||
WHERE id = sys_user.role_id), '') AS role_name
|
||||
FROM sys_user
|
||||
WHERE (NOT @is_status::Boolean OR sys_user.status = @status)
|
||||
AND (NOT @is_id::Boolean OR sys_user.id = @id)
|
||||
|
||||
@@ -276,11 +276,22 @@ func (q *Queries) ListSysUserByIds(ctx context.Context, dollar_1 []int32) ([]*Sy
|
||||
}
|
||||
|
||||
const listSysUserCondition = `-- name: ListSysUserCondition :many
|
||||
SELECT id, uuid, email, username, avatar, gender, department_id, role_id, status, change_password_at, created_at, updated_at,
|
||||
(SELECT name FROM sys_department WHERE ID = sys_user.department_id) AS department_name,
|
||||
(SELECT display_name
|
||||
FROM sys_role
|
||||
WHERE id = sys_user.role_id) AS role_name
|
||||
SELECT id,
|
||||
uuid,
|
||||
email,
|
||||
username,
|
||||
avatar,
|
||||
gender,
|
||||
department_id,
|
||||
role_id,
|
||||
status,
|
||||
change_password_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
COALESCE((SELECT name FROM sys_department WHERE ID = sys_user.department_id), '') AS department_name,
|
||||
COALESCE((SELECT display_name
|
||||
FROM sys_role
|
||||
WHERE id = sys_user.role_id), '') AS role_name
|
||||
FROM sys_user
|
||||
WHERE (NOT $1::Boolean OR sys_user.status = $2)
|
||||
AND (NOT $3::Boolean OR sys_user.id = $4)
|
||||
@@ -303,20 +314,20 @@ type ListSysUserConditionParams struct {
|
||||
}
|
||||
|
||||
type ListSysUserConditionRow struct {
|
||||
ID int32 `json:"id"`
|
||||
Uuid uuid.UUID `json:"uuid"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar"`
|
||||
Gender int32 `json:"gender"`
|
||||
DepartmentID int32 `json:"department_id"`
|
||||
RoleID int32 `json:"role_id"`
|
||||
Status int32 `json:"status"`
|
||||
ChangePasswordAt time.Time `json:"change_password_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DepartmentName string `json:"department_name"`
|
||||
RoleName string `json:"role_name"`
|
||||
ID int32 `json:"id"`
|
||||
Uuid uuid.UUID `json:"uuid"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar"`
|
||||
Gender int32 `json:"gender"`
|
||||
DepartmentID int32 `json:"department_id"`
|
||||
RoleID int32 `json:"role_id"`
|
||||
Status int32 `json:"status"`
|
||||
ChangePasswordAt time.Time `json:"change_password_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DepartmentName interface{} `json:"department_name"`
|
||||
RoleName interface{} `json:"role_name"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListSysUserCondition(ctx context.Context, arg *ListSysUserConditionParams) ([]*ListSysUserConditionRow, error) {
|
||||
|
||||
@@ -3,19 +3,23 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/erpserver/model/view"
|
||||
"management/internal/pkg/convertor"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/redis"
|
||||
)
|
||||
|
||||
type DepartmentBiz interface {
|
||||
Create(ctx context.Context, arg *db.CreateSysDepartmentParams) (*db.SysDepartment, error)
|
||||
Update(ctx context.Context, arg *db.UpdateSysDepartmentParams) (*db.SysDepartment, error)
|
||||
Create(ctx context.Context, req *form.Department) error
|
||||
Update(ctx context.Context, req *form.Department) error
|
||||
All(ctx context.Context) ([]*db.SysDepartment, error)
|
||||
List(ctx context.Context, q dto.SearchDto) ([]*db.SysDepartment, int64, error)
|
||||
Get(ctx context.Context, id int32) (*db.SysDepartment, error)
|
||||
@@ -96,12 +100,79 @@ func (b *departmentBiz) Get(ctx context.Context, id int32) (*db.SysDepartment, e
|
||||
return b.store.GetSysDepartment(ctx, id)
|
||||
}
|
||||
|
||||
func (b *departmentBiz) Create(ctx context.Context, arg *db.CreateSysDepartmentParams) (*db.SysDepartment, error) {
|
||||
return b.store.CreateSysDepartment(ctx, arg)
|
||||
func (b *departmentBiz) Create(ctx context.Context, req *form.Department) error {
|
||||
parent := &db.SysDepartment{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
if *req.ParentID > 0 {
|
||||
var err error
|
||||
parent, err = b.store.GetSysDepartment(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return errors.New("父级节点错误")
|
||||
}
|
||||
}
|
||||
|
||||
var order int32 = 6666
|
||||
if *req.Sort > 0 {
|
||||
order = *req.Sort
|
||||
}
|
||||
|
||||
arg := &db.CreateSysDepartmentParams{
|
||||
Name: req.Name,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Status: *req.Status,
|
||||
Sort: order,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err := b.store.CreateSysDepartment(ctx, arg)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return errors.New("部门已存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *departmentBiz) Update(ctx context.Context, arg *db.UpdateSysDepartmentParams) (*db.SysDepartment, error) {
|
||||
return b.store.UpdateSysDepartment(ctx, arg)
|
||||
func (b *departmentBiz) Update(ctx context.Context, req *form.Department) error {
|
||||
parent := &db.SysDepartment{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
if *req.ParentID > 0 {
|
||||
var err error
|
||||
parent, err = b.store.GetSysDepartment(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return errors.New("父级节点错误")
|
||||
}
|
||||
}
|
||||
|
||||
depart, err := b.store.GetSysDepartment(ctx, *req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var order int32 = 6666
|
||||
if *req.Sort > 0 {
|
||||
order = *req.Sort
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysDepartmentParams{
|
||||
ID: depart.ID,
|
||||
Name: req.Name,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Status: *req.Status,
|
||||
Sort: order,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = b.store.UpdateSysDepartment(ctx, arg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *departmentBiz) Refresh(ctx context.Context) ([]*db.SysDepartment, error) {
|
||||
|
||||
@@ -3,19 +3,24 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/erpserver/model/view"
|
||||
"management/internal/pkg/convertor"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/redis"
|
||||
)
|
||||
|
||||
type RoleBiz interface {
|
||||
Create(ctx context.Context, arg *db.CreateSysRoleParams) (*db.SysRole, error)
|
||||
Update(ctx context.Context, arg *db.UpdateSysRoleParams) (*db.SysRole, error)
|
||||
Create(ctx context.Context, req *form.Role) error
|
||||
Update(ctx context.Context, req *form.Role) error
|
||||
CreateOrUpdate(ctx context.Context, req *form.Role) error
|
||||
All(ctx context.Context) ([]*db.SysRole, error)
|
||||
List(ctx context.Context, q dto.SearchDto) ([]*db.SysRole, int64, error)
|
||||
Get(ctx context.Context, id int32) (*db.SysRole, error)
|
||||
@@ -44,12 +49,140 @@ func NewRole(store db.Store, redis redis.IRedis) *roleBiz {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *roleBiz) Create(ctx context.Context, arg *db.CreateSysRoleParams) (*db.SysRole, error) {
|
||||
return b.store.CreateSysRole(ctx, arg)
|
||||
func (b *roleBiz) Create(ctx context.Context, req *form.Role) error {
|
||||
parent := &db.SysRole{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
if *req.ParentID > 0 {
|
||||
var err error
|
||||
parent, err = b.store.GetSysRole(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return errors.New("父级节点错误")
|
||||
}
|
||||
}
|
||||
|
||||
var order int32 = 6666
|
||||
if *req.Sort > 0 {
|
||||
order = *req.Sort
|
||||
}
|
||||
|
||||
arg := &db.CreateSysRoleParams{
|
||||
Name: req.Name,
|
||||
DisplayName: req.DisplayName,
|
||||
Vip: false,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Status: *req.Status,
|
||||
Sort: order,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err := b.store.CreateSysRole(ctx, arg)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return errors.New("角色名称已存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *roleBiz) Update(ctx context.Context, arg *db.UpdateSysRoleParams) (*db.SysRole, error) {
|
||||
return b.store.UpdateSysRole(ctx, arg)
|
||||
func (b *roleBiz) Update(ctx context.Context, req *form.Role) error {
|
||||
parent := &db.SysRole{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
if *req.ParentID > 0 {
|
||||
var err error
|
||||
parent, err = b.store.GetSysRole(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return errors.New("父级节点错误")
|
||||
}
|
||||
}
|
||||
|
||||
role, err := b.store.GetSysRole(ctx, *req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var order int32 = 6666
|
||||
if *req.Sort > 0 {
|
||||
order = *req.Sort
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysRoleParams{
|
||||
ID: role.ID,
|
||||
DisplayName: req.DisplayName,
|
||||
Status: *req.Status,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Sort: order,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = b.store.UpdateSysRole(ctx, arg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *roleBiz) CreateOrUpdate(ctx context.Context, req *form.Role) error {
|
||||
parent := &db.SysRole{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
if *req.ParentID > 0 {
|
||||
var err error
|
||||
parent, err = b.store.GetSysRole(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return errors.New("父级节点错误")
|
||||
}
|
||||
}
|
||||
|
||||
var order int32 = 6666
|
||||
if *req.Sort > 0 {
|
||||
order = *req.Sort
|
||||
}
|
||||
|
||||
if *req.ID > 0 {
|
||||
role, err := b.store.GetSysRole(ctx, *req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysRoleParams{
|
||||
ID: role.ID,
|
||||
DisplayName: req.DisplayName,
|
||||
Status: *req.Status,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Sort: order,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = b.store.UpdateSysRole(ctx, arg)
|
||||
return err
|
||||
} else {
|
||||
arg := &db.CreateSysRoleParams{
|
||||
Name: req.Name,
|
||||
DisplayName: req.DisplayName,
|
||||
Vip: false,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)),
|
||||
Status: *req.Status,
|
||||
Sort: order,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err := b.store.CreateSysRole(ctx, arg)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return errors.New("角色名称已存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *roleBiz) All(ctx context.Context) ([]*db.SysRole, error) {
|
||||
|
||||
@@ -9,17 +9,20 @@ import (
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/model/req"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/erpserver/model/view"
|
||||
"management/internal/pkg/crypto"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/rand"
|
||||
"management/internal/pkg/session"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserBiz 定义处理用户请求所需的方法.
|
||||
type UserBiz interface {
|
||||
Create(ctx context.Context, req *db.CreateSysUserParams) (*db.SysUser, error)
|
||||
Update(ctx context.Context, req *db.UpdateSysUserParams) (*db.SysUser, error)
|
||||
Create(ctx context.Context, req *form.User) error
|
||||
Update(ctx context.Context, req *form.User) error
|
||||
All(ctx context.Context) ([]*db.SysUser, error)
|
||||
List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error)
|
||||
Get(ctx context.Context, id int32) (*db.SysUser, error)
|
||||
@@ -31,7 +34,7 @@ type UserBiz interface {
|
||||
|
||||
// UserExpansion 定义用户操作的扩展方法.
|
||||
type UserExpansion interface {
|
||||
Login(ctx context.Context, req *req.Login) error
|
||||
Login(ctx context.Context, req *form.Login) error
|
||||
}
|
||||
|
||||
// userBiz 是 UserBiz 接口的实现.
|
||||
@@ -50,12 +53,75 @@ func NewUser(store db.Store, session session.ISession) *userBiz {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *userBiz) Create(ctx context.Context, req *db.CreateSysUserParams) (*db.SysUser, error) {
|
||||
return b.store.CreateSysUser(ctx, req)
|
||||
func (b *userBiz) Create(ctx context.Context, req *form.User) error {
|
||||
salt, err := rand.String(10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hashedPassword, err := crypto.BcryptHashPassword(req.Password + salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
initTime, err := time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00", time.Local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arg := &db.CreateSysUserParams{
|
||||
Uuid: uuid.Must(uuid.NewV7()),
|
||||
Email: req.Email,
|
||||
Username: req.Username,
|
||||
HashedPassword: hashedPassword,
|
||||
Salt: salt,
|
||||
Avatar: req.Avatar,
|
||||
Gender: req.Gender,
|
||||
DepartmentID: req.DepartmentID,
|
||||
RoleID: req.RoleID,
|
||||
Status: *req.Status,
|
||||
ChangePasswordAt: initTime,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = b.store.CreateSysUser(ctx, arg)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
return errors.New("用户已经存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *userBiz) Update(ctx context.Context, req *db.UpdateSysUserParams) (*db.SysUser, error) {
|
||||
return b.store.UpdateSysUser(ctx, req)
|
||||
func (b *userBiz) Update(ctx context.Context, req *form.User) error {
|
||||
user, err := b.store.GetSysUser(ctx, *req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysUserParams{
|
||||
ID: user.ID,
|
||||
Username: req.Username,
|
||||
HashedPassword: user.HashedPassword,
|
||||
Avatar: req.Avatar,
|
||||
Gender: req.Gender,
|
||||
DepartmentID: req.DepartmentID,
|
||||
RoleID: req.RoleID,
|
||||
Status: *req.Status,
|
||||
ChangePasswordAt: user.ChangePasswordAt,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if req.ChangePassword == "on" {
|
||||
hashedPassword, err := crypto.BcryptHashPassword(req.Password + user.Salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
arg.HashedPassword = hashedPassword
|
||||
arg.ChangePasswordAt = time.Now()
|
||||
}
|
||||
_, err = b.store.UpdateSysUser(ctx, arg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *userBiz) All(ctx context.Context) ([]*db.SysUser, error) {
|
||||
@@ -112,7 +178,7 @@ func (b *userBiz) XmSelect(ctx context.Context) ([]*view.XmSelect, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (b *userBiz) Login(ctx context.Context, req *req.Login) error {
|
||||
func (b *userBiz) Login(ctx context.Context, req *form.Login) error {
|
||||
log := &db.CreateSysUserLoginLogParams{
|
||||
CreatedAt: time.Now(),
|
||||
Email: req.Email,
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/biz"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/pkg/binding"
|
||||
"management/internal/pkg/convertor"
|
||||
"management/internal/pkg/tpl"
|
||||
)
|
||||
@@ -95,61 +95,22 @@ func (h *departmentHandler) Edit(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *departmentHandler) Save(w http.ResponseWriter, r *http.Request) {
|
||||
id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0)
|
||||
ParentID := convertor.ConvertInt[int32](r.PostFormValue("ParentID"), 0)
|
||||
name := r.PostFormValue("Name")
|
||||
sort := convertor.ConvertInt[int32](r.PostFormValue("Sort"), 6666)
|
||||
status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 9999)
|
||||
|
||||
ctx := r.Context()
|
||||
var parent *db.SysDepartment
|
||||
if ParentID > 0 {
|
||||
var err error
|
||||
parent, err = h.biz.SystemV1().DepartmentBiz().Get(ctx, ParentID)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, "父级节点错误")
|
||||
return
|
||||
}
|
||||
var req form.Department
|
||||
if err := binding.Form.Bind(r, &req); err != nil {
|
||||
h.render.JSONERR(w, binding.ValidatorErrors(err))
|
||||
return
|
||||
}
|
||||
|
||||
if id == 0 {
|
||||
arg := db.CreateSysDepartmentParams{
|
||||
Name: name,
|
||||
ParentID: ParentID,
|
||||
ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID),
|
||||
Status: status,
|
||||
Sort: sort,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err := h.biz.SystemV1().DepartmentBiz().Create(ctx, &arg)
|
||||
ctx := r.Context()
|
||||
if *req.ID == 0 {
|
||||
err := h.biz.SystemV1().DepartmentBiz().Create(ctx, &req)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
h.render.JSONERR(w, "部门名称已存在")
|
||||
return
|
||||
}
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.render.JSONOK(w, "添加成功")
|
||||
} else {
|
||||
res, err := h.biz.SystemV1().DepartmentBiz().Get(ctx, id)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysDepartmentParams{
|
||||
ID: res.ID,
|
||||
Name: name,
|
||||
ParentID: ParentID,
|
||||
ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID),
|
||||
Status: status,
|
||||
Sort: sort,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = h.biz.SystemV1().DepartmentBiz().Update(ctx, arg)
|
||||
err := h.biz.SystemV1().DepartmentBiz().Update(ctx, &req)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
|
||||
11
internal/erpserver/handler/system/home.go
Normal file
11
internal/erpserver/handler/system/home.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package system
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (h *systemHandler) Home(w http.ResponseWriter, r *http.Request) {
|
||||
h.render.HTML(w, r, "home/home.tmpl", nil)
|
||||
}
|
||||
|
||||
func (h *systemHandler) Dashboard(w http.ResponseWriter, r *http.Request) {
|
||||
h.render.HTML(w, r, "home/dashboard.tmpl", nil)
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/biz"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/pkg/binding"
|
||||
"management/internal/pkg/convertor"
|
||||
"management/internal/pkg/tpl"
|
||||
)
|
||||
@@ -98,69 +98,22 @@ func (h *roleHandler) Edit(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *roleHandler) Save(w http.ResponseWriter, r *http.Request) {
|
||||
id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0)
|
||||
name := r.PostFormValue("Name")
|
||||
parentID := convertor.ConvertInt[int32](r.PostFormValue("ParentID"), 0)
|
||||
displayName := r.PostFormValue("DisplayName")
|
||||
sort := convertor.ConvertInt[int32](r.PostFormValue("Sort"), 6666)
|
||||
status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 0)
|
||||
|
||||
ctx := r.Context()
|
||||
var parent *db.SysRole
|
||||
if parentID > 0 {
|
||||
var err error
|
||||
parent, err = h.biz.SystemV1().RoleBiz().Get(ctx, parentID)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, "父级节点错误")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
parent = &db.SysRole{
|
||||
ID: 0,
|
||||
ParentID: 0,
|
||||
ParentPath: ",0,",
|
||||
}
|
||||
var req form.Role
|
||||
if err := binding.Form.Bind(r, &req); err != nil {
|
||||
h.render.JSONERR(w, binding.ValidatorErrors(err))
|
||||
return
|
||||
}
|
||||
|
||||
if id == 0 {
|
||||
arg := &db.CreateSysRoleParams{
|
||||
Name: name,
|
||||
DisplayName: displayName,
|
||||
Vip: false,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID),
|
||||
Status: status,
|
||||
Sort: sort,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err := h.biz.SystemV1().RoleBiz().Create(ctx, arg)
|
||||
ctx := r.Context()
|
||||
if *req.ID == 0 {
|
||||
err := h.biz.SystemV1().RoleBiz().Create(ctx, &req)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
h.render.JSONERR(w, "角色名称已存在")
|
||||
return
|
||||
}
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.render.JSONOK(w, "添加成功")
|
||||
} else {
|
||||
res, err := h.biz.SystemV1().RoleBiz().Get(ctx, id)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
arg := &db.UpdateSysRoleParams{
|
||||
ID: res.ID,
|
||||
DisplayName: displayName,
|
||||
Status: status,
|
||||
ParentID: parent.ID,
|
||||
ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID),
|
||||
Sort: sort,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = h.biz.SystemV1().RoleBiz().Update(ctx, arg)
|
||||
err := h.biz.SystemV1().RoleBiz().Update(ctx, &req)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
type SystemHandler interface {
|
||||
Home(w http.ResponseWriter, req *http.Request)
|
||||
Dashboard(w http.ResponseWriter, req *http.Request)
|
||||
UserHandler() UserHandler
|
||||
MenuHandler() MenuHandler
|
||||
RoleHandler() RoleHandler
|
||||
@@ -42,10 +43,6 @@ func NewSystemHandler(render tpl.Renderer, redis redis.IRedis, session session.I
|
||||
}
|
||||
}
|
||||
|
||||
func (h *systemHandler) Home(w http.ResponseWriter, r *http.Request) {
|
||||
h.render.HTML(w, r, "home/home.tmpl", nil)
|
||||
}
|
||||
|
||||
func (h *systemHandler) UserHandler() UserHandler {
|
||||
return NewUserHandler(h.render, h.session, h.biz, h.mi)
|
||||
}
|
||||
|
||||
@@ -3,23 +3,19 @@ package system
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"management/internal/db/model/dto"
|
||||
db "management/internal/db/sqlc"
|
||||
"management/internal/erpserver/biz"
|
||||
"management/internal/erpserver/model/req"
|
||||
"management/internal/erpserver/model/form"
|
||||
"management/internal/pkg/binding"
|
||||
"management/internal/pkg/convertor"
|
||||
"management/internal/pkg/crypto"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/middleware"
|
||||
"management/internal/pkg/rand"
|
||||
"management/internal/pkg/session"
|
||||
"management/internal/pkg/tpl"
|
||||
"management/internal/pkg/tpl/html"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/zhang2092/browser"
|
||||
)
|
||||
|
||||
@@ -74,7 +70,7 @@ func (h *userHandler) Edit(w http.ResponseWriter, r *http.Request) {
|
||||
if id > 0 {
|
||||
ctx := r.Context()
|
||||
if user, err := h.biz.SystemV1().UserBiz().Get(ctx, id); err == nil {
|
||||
user.HashedPassword = nil
|
||||
user.HashedPassword = []byte("********")
|
||||
sysUser = user
|
||||
}
|
||||
}
|
||||
@@ -84,109 +80,36 @@ func (h *userHandler) Edit(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *userHandler) Save(w http.ResponseWriter, r *http.Request) {
|
||||
id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0)
|
||||
email := r.PostFormValue("Email")
|
||||
username := r.PostFormValue("Username")
|
||||
password := r.PostFormValue("Password")
|
||||
changePassword := r.PostFormValue("ChangePassword")
|
||||
gender := convertor.ConvertInt[int32](r.PostFormValue("Gender"), 0)
|
||||
avatar := r.PostFormValue("File")
|
||||
status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 0)
|
||||
var req form.User
|
||||
if err := binding.Form.Bind(r, &req); err != nil {
|
||||
h.render.JSONERR(w, binding.ValidatorErrors(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
departmentID := convertor.ConvertInt[int32](r.PostFormValue("DepartmentID"), 0)
|
||||
var department *db.SysDepartment
|
||||
var err error
|
||||
if departmentID > 0 {
|
||||
department, err = h.biz.SystemV1().DepartmentBiz().Get(ctx, departmentID)
|
||||
if err != nil {
|
||||
if req.DepartmentID > 0 {
|
||||
if _, err := h.biz.SystemV1().DepartmentBiz().Get(ctx, req.DepartmentID); err != nil {
|
||||
h.render.JSONERR(w, "部门数据错误")
|
||||
return
|
||||
}
|
||||
}
|
||||
var role *db.SysRole
|
||||
roleID := convertor.ConvertInt[int32](r.PostFormValue("RoleID"), 0)
|
||||
if roleID > 0 {
|
||||
role, err = h.biz.SystemV1().RoleBiz().Get(ctx, roleID)
|
||||
if err != nil {
|
||||
if req.RoleID > 0 {
|
||||
if _, err := h.biz.SystemV1().RoleBiz().Get(ctx, req.RoleID); err != nil {
|
||||
h.render.JSONERR(w, "角色数据错误")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if id == 0 {
|
||||
salt, err := rand.String(10)
|
||||
if *req.ID == 0 {
|
||||
err := h.biz.SystemV1().UserBiz().Create(ctx, &req)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := crypto.BcryptHashPassword(password + salt)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
initTime, err := time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00", time.Local)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
arg := &db.CreateSysUserParams{
|
||||
Uuid: uuid.Must(uuid.NewV7()),
|
||||
Email: email,
|
||||
Username: username,
|
||||
HashedPassword: hashedPassword,
|
||||
Salt: salt,
|
||||
Avatar: avatar,
|
||||
Gender: gender,
|
||||
DepartmentID: department.ID,
|
||||
RoleID: role.ID,
|
||||
Status: status,
|
||||
ChangePasswordAt: initTime,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = h.biz.SystemV1().UserBiz().Create(ctx, arg)
|
||||
if err != nil {
|
||||
if db.IsUniqueViolation(err) {
|
||||
h.render.JSONERR(w, "数据已存在")
|
||||
return
|
||||
}
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.render.JSONOK(w, "添加成功")
|
||||
} else {
|
||||
res, err := h.biz.SystemV1().UserBiz().Get(ctx, id)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
arg := &db.UpdateSysUserParams{
|
||||
ID: res.ID,
|
||||
Username: username,
|
||||
HashedPassword: res.HashedPassword,
|
||||
Avatar: avatar,
|
||||
Gender: int32(gender),
|
||||
DepartmentID: department.ID,
|
||||
RoleID: role.ID,
|
||||
Status: int32(status),
|
||||
ChangePasswordAt: res.ChangePasswordAt,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if changePassword == "on" {
|
||||
hashedPassword, err := crypto.BcryptHashPassword(password + res.Salt)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
}
|
||||
arg.HashedPassword = hashedPassword
|
||||
arg.ChangePasswordAt = time.Now()
|
||||
}
|
||||
_, err = h.biz.SystemV1().UserBiz().Update(ctx, arg)
|
||||
err := h.biz.SystemV1().UserBiz().Update(ctx, &req)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
@@ -261,7 +184,7 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var user dto.AuthorizeUser
|
||||
u := h.session.GetBytes(ctx, know.StoreName)
|
||||
if err := json.Unmarshal(u, &user); err == nil {
|
||||
// 判断租户是否一致, 一致则刷新令牌,跳转到首页
|
||||
// 判断用户是否登陆, 已经登陆则刷新令牌,跳转到首页
|
||||
if err := h.session.RenewToken(ctx); err == nil {
|
||||
h.session.Put(ctx, know.StoreName, u)
|
||||
http.Redirect(w, r, "/home.html", http.StatusFound)
|
||||
@@ -271,30 +194,16 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
h.session.Destroy(ctx)
|
||||
h.render.HTML(w, r, "oauth/login.tmpl", nil)
|
||||
case http.MethodPost:
|
||||
req := &req.Login{
|
||||
Email: strings.TrimSpace(r.PostFormValue("email")),
|
||||
Password: strings.TrimSpace(r.PostFormValue("password")),
|
||||
CaptchaID: strings.TrimSpace(r.PostFormValue("captcha_id")),
|
||||
Captcha: strings.TrimSpace(r.PostFormValue("captcha")),
|
||||
Ip: r.RemoteAddr,
|
||||
Referrer: r.Header.Get("Referer"),
|
||||
Url: r.URL.RequestURI(),
|
||||
defer r.Body.Close()
|
||||
var req form.Login
|
||||
if err := binding.Form.Bind(r, &req); err != nil {
|
||||
e := binding.ValidatorErrors(err)
|
||||
h.render.JSONERR(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Email) == 0 {
|
||||
h.render.JSON(w, tpl.Response{Success: false, Message: "请填写邮箱"})
|
||||
return
|
||||
}
|
||||
if len(req.Password) == 0 {
|
||||
h.render.JSON(w, tpl.Response{Success: false, Message: "请填写密码"})
|
||||
return
|
||||
}
|
||||
if len(req.Captcha) == 0 {
|
||||
h.render.JSON(w, tpl.Response{Success: false, Message: "请填写验证码"})
|
||||
return
|
||||
}
|
||||
if !h.biz.CommonV1().CaptchaBiz().Verify(req.CaptchaID, req.Captcha, true) {
|
||||
h.render.JSON(w, tpl.Response{Success: false, Message: "验证码错误"})
|
||||
h.render.JSONERR(w, "验证码错误")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -304,9 +213,12 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Ip = r.RemoteAddr
|
||||
req.Referrer = r.Header.Get("Referer")
|
||||
req.Url = r.URL.RequestURI()
|
||||
req.Os = br.Platform().Name()
|
||||
req.Browser = br.Name()
|
||||
err = h.biz.SystemV1().UserBiz().Login(ctx, req)
|
||||
err = h.biz.SystemV1().UserBiz().Login(ctx, &req)
|
||||
if err != nil {
|
||||
h.render.JSONERR(w, err.Error())
|
||||
return
|
||||
|
||||
@@ -38,7 +38,8 @@ func NewRouter(handler handler.IHandler, mw mw.IMiddleware) *chi.Mux {
|
||||
r.With(mw.Authorize, mw.Audit).Post("/upload/file", handler.CommonHandler().UploadHandler().File)
|
||||
r.With(mw.Authorize, mw.Audit).Post("/upload/mutilfile", handler.CommonHandler().UploadHandler().MutilFiles)
|
||||
|
||||
r.With(mw.Authorize, mw.Audit).Get("/home.html", handler.SystemHandler().Home)
|
||||
r.With(mw.Authorize).Get("/home.html", handler.SystemHandler().Home)
|
||||
r.With(mw.Authorize).Get("/dashboard", handler.SystemHandler().Dashboard)
|
||||
r.With(mw.Authorize).Get("/pear.json", handler.SystemHandler().ConfigHandler().Pear)
|
||||
|
||||
r.Route("/system", func(r chi.Router) {
|
||||
|
||||
9
internal/erpserver/model/form/department.go
Normal file
9
internal/erpserver/model/form/department.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package form
|
||||
|
||||
type Department struct {
|
||||
ID *int32 `form:"id" binding:"required"`
|
||||
Name string `form:"name" binding:"required"`
|
||||
ParentID *int32 `form:"parent_id" binding:"required"`
|
||||
Sort *int32 `form:"sort"`
|
||||
Status *int32 `form:"status" binding:"required"`
|
||||
}
|
||||
10
internal/erpserver/model/form/role.go
Normal file
10
internal/erpserver/model/form/role.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package form
|
||||
|
||||
type Role struct {
|
||||
ID *int32 `form:"id" binding:"required"`
|
||||
Name string `form:"name" binding:"required"`
|
||||
ParentID *int32 `form:"parent_id" binding:"required"`
|
||||
DisplayName string `form:"display_name" binding:"required"`
|
||||
Sort *int32 `form:"sort"`
|
||||
Status *int32 `form:"status" binding:"required"`
|
||||
}
|
||||
28
internal/erpserver/model/form/user.go
Normal file
28
internal/erpserver/model/form/user.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package form
|
||||
|
||||
type Login struct {
|
||||
Email string `form:"email" binding:"required,email"`
|
||||
Password string `form:"password" binding:"required,min=6"`
|
||||
Captcha string `form:"captcha" binding:"required"`
|
||||
CaptchaID string `form:"captcha_id" binding:"required"`
|
||||
|
||||
// 平台信息
|
||||
Os string
|
||||
Ip string
|
||||
Browser string
|
||||
Referrer string
|
||||
Url string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID *int32 `form:"id" binding:"required"`
|
||||
Email string `form:"email" binding:"required,email"`
|
||||
Username string `form:"username" binding:"required"`
|
||||
Password string `form:"password" binding:"required,min=6"`
|
||||
ChangePassword string `form:"change_password"`
|
||||
Avatar string `form:"File"`
|
||||
Gender int32 `form:"gender"`
|
||||
DepartmentID int32 `form:"department_id"`
|
||||
RoleID int32 `form:"role_id"`
|
||||
Status *int32 `form:"status" binding:"required"`
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package req
|
||||
|
||||
type Login struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Captcha string `json:"captcha"`
|
||||
CaptchaID string `json:"captcha_id"`
|
||||
|
||||
// 平台信息
|
||||
Os string `json:"os"`
|
||||
Ip string `json:"ip"`
|
||||
Browser string `json:"browser"`
|
||||
Referrer string `json:"referrer"`
|
||||
Url string `json:"url"`
|
||||
}
|
||||
139
internal/pkg/binding/binding.go
Normal file
139
internal/pkg/binding/binding.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
MIMEJSON = "application/json"
|
||||
MIMEHTML = "text/html"
|
||||
MIMEXML = "application/xml"
|
||||
MIMEXML2 = "text/xml"
|
||||
MIMEPlain = "text/plain"
|
||||
MIMEPOSTForm = "application/x-www-form-urlencoded"
|
||||
MIMEMultipartPOSTForm = "multipart/form-data"
|
||||
MIMEPROTOBUF = "application/x-protobuf"
|
||||
MIMEMSGPACK = "application/x-msgpack"
|
||||
MIMEMSGPACK2 = "application/msgpack"
|
||||
MIMEYAML = "application/x-yaml"
|
||||
MIMEYAML2 = "application/yaml"
|
||||
MIMETOML = "application/toml"
|
||||
)
|
||||
|
||||
type Binding interface {
|
||||
Name() string
|
||||
Bind(*http.Request, any) error
|
||||
}
|
||||
|
||||
type BindingBody interface {
|
||||
Binding
|
||||
BindBody([]byte, any) error
|
||||
}
|
||||
|
||||
type BindingUri interface {
|
||||
Name() string
|
||||
BindUri(map[string][]string, any) error
|
||||
}
|
||||
|
||||
// StructValidator is the minimal interface which needs to be implemented in
|
||||
// order for it to be used as the validator engine for ensuring the correctness
|
||||
// of the request. Gin provides a default implementation for this using
|
||||
// https://github.com/go-playground/validator/tree/v10.6.1.
|
||||
type StructValidator interface {
|
||||
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||
// If the received type is a slice|array, the validation should be performed travel on every element.
|
||||
// If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned.
|
||||
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||
// Otherwise nil must be returned.
|
||||
ValidateStruct(any) error
|
||||
|
||||
// Engine returns the underlying validator engine which powers the
|
||||
// StructValidator implementation.
|
||||
Engine() any
|
||||
}
|
||||
|
||||
// Validator is the default validator which implements the StructValidator
|
||||
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
|
||||
// under the hood.
|
||||
var Validator StructValidator = &defaultValidator{}
|
||||
|
||||
var (
|
||||
JSON BindingBody = jsonBinding{}
|
||||
// XML BindingBody = xmlBinding{}
|
||||
Form Binding = formBinding{}
|
||||
Query Binding = queryBinding{}
|
||||
FormPost Binding = formPostBinding{}
|
||||
FormMultipart Binding = formMultipartBinding{}
|
||||
// ProtoBuf BindingBody = protobufBinding{}
|
||||
// MsgPack BindingBody = msgpackBinding{}
|
||||
// YAML BindingBody = yamlBinding{}
|
||||
// Uri BindingUri = uriBinding{}
|
||||
// Header Binding = headerBinding{}
|
||||
// Plain BindingBody = plainBinding{}
|
||||
// TOML BindingBody = tomlBinding{}
|
||||
)
|
||||
|
||||
// Default returns the appropriate Binding instance based on the HTTP method
|
||||
// and the content type.
|
||||
func Default(method, contentType string) Binding {
|
||||
if method == http.MethodGet {
|
||||
return Form
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case MIMEJSON:
|
||||
return JSON
|
||||
// case MIMEXML, MIMEXML2:
|
||||
// return XML
|
||||
// case MIMEPROTOBUF:
|
||||
// return ProtoBuf
|
||||
// case MIMEMSGPACK, MIMEMSGPACK2:
|
||||
// return MsgPack
|
||||
// case MIMEYAML, MIMEYAML2:
|
||||
// return YAML
|
||||
// case MIMETOML:
|
||||
// return TOML
|
||||
case MIMEMultipartPOSTForm:
|
||||
return FormMultipart
|
||||
default: // case MIMEPOSTForm:
|
||||
return Form
|
||||
}
|
||||
}
|
||||
|
||||
func validate(obj any) error {
|
||||
if Validator == nil {
|
||||
return nil
|
||||
}
|
||||
return Validator.ValidateStruct(obj)
|
||||
}
|
||||
|
||||
// ShouldBind checks the Method and Content-Type to select a binding engine automatically,
|
||||
// Depending on the "Content-Type" header different bindings are used, for example:
|
||||
//
|
||||
// "application/json" --> JSON binding
|
||||
// "application/xml" --> XML binding
|
||||
//
|
||||
// It parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input.
|
||||
// It decodes the json payload into the struct specified as a pointer.
|
||||
// Like c.Bind() but this method does not set the response status code to 400 or abort if input is not valid.
|
||||
func ShouldBind(r *http.Request, obj any) error {
|
||||
b := Default(r.Method, r.Header.Get("Content-Type"))
|
||||
return ShouldBindWith(r, obj, b)
|
||||
}
|
||||
|
||||
// ShouldBindJSON is a shortcut for c.ShouldBindWith(obj, JSON).
|
||||
func ShouldBindJSON(r *http.Request, obj any) error {
|
||||
return ShouldBindWith(r, obj, JSON)
|
||||
}
|
||||
|
||||
// ShouldBindWith binds the passed struct pointer using the specified binding engine.
|
||||
// See the binding package.
|
||||
func ShouldBindWith(r *http.Request, obj any, b Binding) error {
|
||||
return b.Bind(r, obj)
|
||||
}
|
||||
|
||||
// ShouldBindQuery is a shortcut for c.ShouldBindWith(obj, Query).
|
||||
func ShouldBindQuery(r *http.Request, obj any) error {
|
||||
return ShouldBindWith(r, obj, Query)
|
||||
}
|
||||
15
internal/pkg/binding/byteconv.go
Normal file
15
internal/pkg/binding/byteconv.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package binding
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// stringToBytes converts string to byte slice without a memory allocation.
|
||||
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
|
||||
func stringToBytes(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
|
||||
// bytesToString converts byte slice to string without a memory allocation.
|
||||
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
|
||||
func bytesToString(b []byte) string {
|
||||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
91
internal/pkg/binding/default_validator.go
Normal file
91
internal/pkg/binding/default_validator.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
type defaultValidator struct {
|
||||
once sync.Once
|
||||
validate *validator.Validate
|
||||
}
|
||||
|
||||
type SliceValidationError []error
|
||||
|
||||
// Error concatenates all error elements in SliceValidationError into a single string separated by \n.
|
||||
func (err SliceValidationError) Error() string {
|
||||
if len(err) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(err); i++ {
|
||||
if err[i] != nil {
|
||||
if b.Len() > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("[" + strconv.Itoa(i) + "]: " + err[i].Error())
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
var _ StructValidator = (*defaultValidator)(nil)
|
||||
|
||||
// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
|
||||
func (v *defaultValidator) ValidateStruct(obj any) error {
|
||||
if obj == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(obj)
|
||||
switch value.Kind() {
|
||||
case reflect.Ptr:
|
||||
if value.Elem().Kind() != reflect.Struct {
|
||||
return v.ValidateStruct(value.Elem().Interface())
|
||||
}
|
||||
return v.validateStruct(obj)
|
||||
case reflect.Struct:
|
||||
return v.validateStruct(obj)
|
||||
case reflect.Slice, reflect.Array:
|
||||
count := value.Len()
|
||||
validateRet := make(SliceValidationError, 0)
|
||||
for i := 0; i < count; i++ {
|
||||
if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
|
||||
validateRet = append(validateRet, err)
|
||||
}
|
||||
}
|
||||
if len(validateRet) == 0 {
|
||||
return nil
|
||||
}
|
||||
return validateRet
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// validateStruct receives struct type
|
||||
func (v *defaultValidator) validateStruct(obj any) error {
|
||||
v.lazyinit()
|
||||
return v.validate.Struct(obj)
|
||||
}
|
||||
|
||||
// Engine returns the underlying validator engine which powers the default
|
||||
// Validator instance. This is useful if you want to register custom validations
|
||||
// or struct level validations. See validator GoDoc for more info -
|
||||
// https://pkg.go.dev/github.com/go-playground/validator/v10
|
||||
func (v *defaultValidator) Engine() any {
|
||||
v.lazyinit()
|
||||
return v.validate
|
||||
}
|
||||
|
||||
func (v *defaultValidator) lazyinit() {
|
||||
v.once.Do(func() {
|
||||
v.validate = validator.New()
|
||||
v.validate.SetTagName("binding")
|
||||
})
|
||||
}
|
||||
60
internal/pkg/binding/form.go
Normal file
60
internal/pkg/binding/form.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const defaultMemory = 32 << 20
|
||||
|
||||
type (
|
||||
formBinding struct{}
|
||||
formPostBinding struct{}
|
||||
formMultipartBinding struct{}
|
||||
)
|
||||
|
||||
func (formBinding) Name() string {
|
||||
return "form"
|
||||
}
|
||||
|
||||
func (formBinding) Bind(req *http.Request, obj any) error {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := req.ParseMultipartForm(defaultMemory); err != nil && !errors.Is(err, http.ErrNotMultipart) {
|
||||
return err
|
||||
}
|
||||
if err := mapForm(obj, req.Form); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
|
||||
func (formPostBinding) Name() string {
|
||||
return "form-urlencoded"
|
||||
}
|
||||
|
||||
func (formPostBinding) Bind(req *http.Request, obj any) error {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mapForm(obj, req.PostForm); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
|
||||
func (formMultipartBinding) Name() string {
|
||||
return "multipart/form-data"
|
||||
}
|
||||
|
||||
func (formMultipartBinding) Bind(req *http.Request, obj any) error {
|
||||
if err := req.ParseMultipartForm(defaultMemory); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validate(obj)
|
||||
}
|
||||
473
internal/pkg/binding/form_mapping.go
Normal file
473
internal/pkg/binding/form_mapping.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnknownType = errors.New("unknown type")
|
||||
|
||||
// ErrConvertMapStringSlice can not convert to map[string][]string
|
||||
ErrConvertMapStringSlice = errors.New("can not convert to map slices of strings")
|
||||
|
||||
// ErrConvertToMapString can not convert to map[string]string
|
||||
ErrConvertToMapString = errors.New("can not convert to map of strings")
|
||||
)
|
||||
|
||||
func mapURI(ptr any, m map[string][]string) error {
|
||||
return mapFormByTag(ptr, m, "uri")
|
||||
}
|
||||
|
||||
func mapForm(ptr any, form map[string][]string) error {
|
||||
return mapFormByTag(ptr, form, "form")
|
||||
}
|
||||
|
||||
func MapFormWithTag(ptr any, form map[string][]string, tag string) error {
|
||||
return mapFormByTag(ptr, form, tag)
|
||||
}
|
||||
|
||||
var emptyField = reflect.StructField{}
|
||||
|
||||
func mapFormByTag(ptr any, form map[string][]string, tag string) error {
|
||||
// Check if ptr is a map
|
||||
ptrVal := reflect.ValueOf(ptr)
|
||||
var pointed any
|
||||
if ptrVal.Kind() == reflect.Ptr {
|
||||
ptrVal = ptrVal.Elem()
|
||||
pointed = ptrVal.Interface()
|
||||
}
|
||||
if ptrVal.Kind() == reflect.Map &&
|
||||
ptrVal.Type().Key().Kind() == reflect.String {
|
||||
if pointed != nil {
|
||||
ptr = pointed
|
||||
}
|
||||
return setFormMap(ptr, form)
|
||||
}
|
||||
|
||||
return mappingByPtr(ptr, formSource(form), tag)
|
||||
}
|
||||
|
||||
// setter tries to set value on a walking by fields of a struct
|
||||
type setter interface {
|
||||
TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSet bool, err error)
|
||||
}
|
||||
|
||||
type formSource map[string][]string
|
||||
|
||||
var _ setter = formSource(nil)
|
||||
|
||||
// TrySet tries to set a value by request's form source (like map[string][]string)
|
||||
func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSet bool, err error) {
|
||||
return setByForm(value, field, form, tagValue, opt)
|
||||
}
|
||||
|
||||
func mappingByPtr(ptr any, setter setter, tag string) error {
|
||||
_, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag)
|
||||
return err
|
||||
}
|
||||
|
||||
func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
|
||||
if field.Tag.Get(tag) == "-" { // just ignoring this field
|
||||
return false, nil
|
||||
}
|
||||
|
||||
vKind := value.Kind()
|
||||
|
||||
if vKind == reflect.Ptr {
|
||||
var isNew bool
|
||||
vPtr := value
|
||||
if value.IsNil() {
|
||||
isNew = true
|
||||
vPtr = reflect.New(value.Type().Elem())
|
||||
}
|
||||
isSet, err := mapping(vPtr.Elem(), field, setter, tag)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if isNew && isSet {
|
||||
value.Set(vPtr)
|
||||
}
|
||||
return isSet, nil
|
||||
}
|
||||
|
||||
if vKind != reflect.Struct || !field.Anonymous {
|
||||
ok, err := tryToSetValue(value, field, setter, tag)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if vKind == reflect.Struct {
|
||||
tValue := value.Type()
|
||||
|
||||
var isSet bool
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
sf := tValue.Field(i)
|
||||
if sf.PkgPath != "" && !sf.Anonymous { // unexported
|
||||
continue
|
||||
}
|
||||
ok, err := mapping(value.Field(i), sf, setter, tag)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
isSet = isSet || ok
|
||||
}
|
||||
return isSet, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type setOptions struct {
|
||||
isDefaultExists bool
|
||||
defaultValue string
|
||||
}
|
||||
|
||||
func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
|
||||
var tagValue string
|
||||
var setOpt setOptions
|
||||
|
||||
tagValue = field.Tag.Get(tag)
|
||||
tagValue, opts := head(tagValue, ",")
|
||||
|
||||
if tagValue == "" { // default value is FieldName
|
||||
tagValue = field.Name
|
||||
}
|
||||
if tagValue == "" { // when field is "emptyField" variable
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var opt string
|
||||
for len(opts) > 0 {
|
||||
opt, opts = head(opts, ",")
|
||||
|
||||
if k, v := head(opt, "="); k == "default" {
|
||||
setOpt.isDefaultExists = true
|
||||
setOpt.defaultValue = v
|
||||
}
|
||||
}
|
||||
|
||||
return setter.TrySet(value, field, tagValue, setOpt)
|
||||
}
|
||||
|
||||
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
|
||||
type BindUnmarshaler interface {
|
||||
// UnmarshalParam decodes and assigns a value from an form or query param.
|
||||
UnmarshalParam(param string) error
|
||||
}
|
||||
|
||||
// trySetCustom tries to set a custom type value
|
||||
// If the value implements the BindUnmarshaler interface, it will be used to set the value, we will return `true`
|
||||
// to skip the default value setting.
|
||||
func trySetCustom(val string, value reflect.Value) (isSet bool, err error) {
|
||||
switch v := value.Addr().Interface().(type) {
|
||||
case BindUnmarshaler:
|
||||
return true, v.UnmarshalParam(val)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func trySplit(vs []string, field reflect.StructField) (newVs []string, err error) {
|
||||
cfTag := field.Tag.Get("collection_format")
|
||||
if cfTag == "" || cfTag == "multi" {
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
var sep string
|
||||
switch cfTag {
|
||||
case "csv":
|
||||
sep = ","
|
||||
case "ssv":
|
||||
sep = " "
|
||||
case "tsv":
|
||||
sep = "\t"
|
||||
case "pipes":
|
||||
sep = "|"
|
||||
default:
|
||||
return vs, fmt.Errorf("%s is not supported in the collection_format. (csv, ssv, pipes)", cfTag)
|
||||
}
|
||||
|
||||
totalLength := 0
|
||||
for _, v := range vs {
|
||||
totalLength += strings.Count(v, sep) + 1
|
||||
}
|
||||
newVs = make([]string, 0, totalLength)
|
||||
for _, v := range vs {
|
||||
newVs = append(newVs, strings.Split(v, sep)...)
|
||||
}
|
||||
|
||||
return newVs, nil
|
||||
}
|
||||
|
||||
func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSet bool, err error) {
|
||||
vs, ok := form[tagValue]
|
||||
if !ok && !opt.isDefaultExists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
if !ok {
|
||||
vs = []string{opt.defaultValue}
|
||||
}
|
||||
|
||||
if ok, err = trySetCustom(vs[0], value); ok {
|
||||
return ok, err
|
||||
}
|
||||
|
||||
if vs, err = trySplit(vs, field); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, setSlice(vs, value, field)
|
||||
case reflect.Array:
|
||||
if !ok {
|
||||
vs = []string{opt.defaultValue}
|
||||
}
|
||||
|
||||
if ok, err = trySetCustom(vs[0], value); ok {
|
||||
return ok, err
|
||||
}
|
||||
|
||||
if vs, err = trySplit(vs, field); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(vs) != value.Len() {
|
||||
return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String())
|
||||
}
|
||||
|
||||
return true, setArray(vs, value, field)
|
||||
default:
|
||||
var val string
|
||||
if !ok {
|
||||
val = opt.defaultValue
|
||||
}
|
||||
|
||||
if len(vs) > 0 {
|
||||
val = vs[0]
|
||||
}
|
||||
if ok, err := trySetCustom(val, value); ok {
|
||||
return ok, err
|
||||
}
|
||||
return true, setWithProperType(val, value, field)
|
||||
}
|
||||
}
|
||||
|
||||
func setWithProperType(val string, value reflect.Value, field reflect.StructField) error {
|
||||
switch value.Kind() {
|
||||
case reflect.Int:
|
||||
return setIntField(val, 0, value)
|
||||
case reflect.Int8:
|
||||
return setIntField(val, 8, value)
|
||||
case reflect.Int16:
|
||||
return setIntField(val, 16, value)
|
||||
case reflect.Int32:
|
||||
return setIntField(val, 32, value)
|
||||
case reflect.Int64:
|
||||
switch value.Interface().(type) {
|
||||
case time.Duration:
|
||||
return setTimeDuration(val, value)
|
||||
}
|
||||
return setIntField(val, 64, value)
|
||||
case reflect.Uint:
|
||||
return setUintField(val, 0, value)
|
||||
case reflect.Uint8:
|
||||
return setUintField(val, 8, value)
|
||||
case reflect.Uint16:
|
||||
return setUintField(val, 16, value)
|
||||
case reflect.Uint32:
|
||||
return setUintField(val, 32, value)
|
||||
case reflect.Uint64:
|
||||
return setUintField(val, 64, value)
|
||||
case reflect.Bool:
|
||||
return setBoolField(val, value)
|
||||
case reflect.Float32:
|
||||
return setFloatField(val, 32, value)
|
||||
case reflect.Float64:
|
||||
return setFloatField(val, 64, value)
|
||||
case reflect.String:
|
||||
value.SetString(val)
|
||||
case reflect.Struct:
|
||||
switch value.Interface().(type) {
|
||||
case time.Time:
|
||||
return setTimeField(val, field, value)
|
||||
case multipart.FileHeader:
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(stringToBytes(val), value.Addr().Interface())
|
||||
case reflect.Map:
|
||||
return json.Unmarshal(stringToBytes(val), value.Addr().Interface())
|
||||
case reflect.Ptr:
|
||||
if !value.Elem().IsValid() {
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
}
|
||||
return setWithProperType(val, value.Elem(), field)
|
||||
default:
|
||||
return errUnknownType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIntField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0"
|
||||
}
|
||||
intVal, err := strconv.ParseInt(val, 10, bitSize)
|
||||
if err == nil {
|
||||
field.SetInt(intVal)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func setUintField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0"
|
||||
}
|
||||
uintVal, err := strconv.ParseUint(val, 10, bitSize)
|
||||
if err == nil {
|
||||
field.SetUint(uintVal)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func setBoolField(val string, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "false"
|
||||
}
|
||||
boolVal, err := strconv.ParseBool(val)
|
||||
if err == nil {
|
||||
field.SetBool(boolVal)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func setFloatField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0.0"
|
||||
}
|
||||
floatVal, err := strconv.ParseFloat(val, bitSize)
|
||||
if err == nil {
|
||||
field.SetFloat(floatVal)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
|
||||
timeFormat := structField.Tag.Get("time_format")
|
||||
if timeFormat == "" {
|
||||
timeFormat = time.RFC3339
|
||||
}
|
||||
|
||||
switch tf := strings.ToLower(timeFormat); tf {
|
||||
case "unix", "unixnano":
|
||||
tv, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d := time.Duration(1)
|
||||
if tf == "unixnano" {
|
||||
d = time.Second
|
||||
}
|
||||
|
||||
t := time.Unix(tv/int64(d), tv%int64(d))
|
||||
value.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
value.Set(reflect.ValueOf(time.Time{}))
|
||||
return nil
|
||||
}
|
||||
|
||||
l := time.Local
|
||||
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
|
||||
l = time.UTC
|
||||
}
|
||||
|
||||
if locTag := structField.Tag.Get("time_location"); locTag != "" {
|
||||
loc, err := time.LoadLocation(locTag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l = loc
|
||||
}
|
||||
|
||||
t, err := time.ParseInLocation(timeFormat, val, l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
}
|
||||
|
||||
func setArray(vals []string, value reflect.Value, field reflect.StructField) error {
|
||||
for i, s := range vals {
|
||||
err := setWithProperType(s, value.Index(i), field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setSlice(vals []string, value reflect.Value, field reflect.StructField) error {
|
||||
slice := reflect.MakeSlice(value.Type(), len(vals), len(vals))
|
||||
err := setArray(vals, slice, field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.Set(slice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setTimeDuration(val string, value reflect.Value) error {
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.Set(reflect.ValueOf(d))
|
||||
return nil
|
||||
}
|
||||
|
||||
func head(str, sep string) (head string, tail string) {
|
||||
head, tail, _ = strings.Cut(str, sep)
|
||||
return head, tail
|
||||
}
|
||||
|
||||
func setFormMap(ptr any, form map[string][]string) error {
|
||||
el := reflect.TypeOf(ptr).Elem()
|
||||
|
||||
if el.Kind() == reflect.Slice {
|
||||
ptrMap, ok := ptr.(map[string][]string)
|
||||
if !ok {
|
||||
return ErrConvertMapStringSlice
|
||||
}
|
||||
for k, v := range form {
|
||||
ptrMap[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ptrMap, ok := ptr.(map[string]string)
|
||||
if !ok {
|
||||
return ErrConvertToMapString
|
||||
}
|
||||
for k, v := range form {
|
||||
ptrMap[k] = v[len(v)-1] // pick last
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
61
internal/pkg/binding/init.go
Normal file
61
internal/pkg/binding/init.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/locales/en"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
enTranslations "github.com/go-playground/validator/v10/translations/en"
|
||||
chTranslations "github.com/go-playground/validator/v10/translations/zh"
|
||||
)
|
||||
|
||||
var trans ut.Translator
|
||||
|
||||
// loca 通常取决于 http 请求头的 'Accept-Language'
|
||||
func SetValidatorTrans(local string) (err error) {
|
||||
if v, ok := Validator.Engine().(*validator.Validate); ok {
|
||||
zhT := zh.New() // chinese
|
||||
enT := en.New() // english
|
||||
uni := ut.New(enT, zhT, enT)
|
||||
|
||||
var o bool
|
||||
trans, o = uni.GetTranslator(local)
|
||||
if !o {
|
||||
return fmt.Errorf("uni.GetTranslator(%s) failed", local)
|
||||
}
|
||||
// register translate
|
||||
// 注册翻译器
|
||||
switch local {
|
||||
case "en":
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
case "zh":
|
||||
err = chTranslations.RegisterDefaultTranslations(v, trans)
|
||||
default:
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ValidatorErrors(err error) string {
|
||||
if errors, ok := err.(validator.ValidationErrors); ok {
|
||||
errs := make(map[string]any)
|
||||
for _, e := range errors {
|
||||
errs[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1)
|
||||
}
|
||||
|
||||
// 将 map 转换为 JSON 格式的字节切片
|
||||
jsonData, err := json.MarshalIndent(errs, "", " ")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return string(jsonData)
|
||||
} else {
|
||||
return err.Error()
|
||||
}
|
||||
}
|
||||
51
internal/pkg/binding/json.go
Normal file
51
internal/pkg/binding/json.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// EnableDecoderUseNumber is used to call the UseNumber method on the JSON
|
||||
// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an
|
||||
// any as a Number instead of as a float64.
|
||||
var EnableDecoderUseNumber = false
|
||||
|
||||
// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method
|
||||
// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to
|
||||
// return an error when the destination is a struct and the input contains object
|
||||
// keys which do not match any non-ignored, exported fields in the destination.
|
||||
var EnableDecoderDisallowUnknownFields = false
|
||||
|
||||
type jsonBinding struct{}
|
||||
|
||||
func (jsonBinding) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
func (jsonBinding) Bind(req *http.Request, obj any) error {
|
||||
if req == nil || req.Body == nil {
|
||||
return errors.New("invalid request")
|
||||
}
|
||||
return decodeJSON(req.Body, obj)
|
||||
}
|
||||
|
||||
func (jsonBinding) BindBody(body []byte, obj any) error {
|
||||
return decodeJSON(bytes.NewReader(body), obj)
|
||||
}
|
||||
|
||||
func decodeJSON(r io.Reader, obj any) error {
|
||||
decoder := json.NewDecoder(r)
|
||||
if EnableDecoderUseNumber {
|
||||
decoder.UseNumber()
|
||||
}
|
||||
if EnableDecoderDisallowUnknownFields {
|
||||
decoder.DisallowUnknownFields()
|
||||
}
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
70
internal/pkg/binding/multipart_form_mapping.go
Normal file
70
internal/pkg/binding/multipart_form_mapping.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type multipartRequest http.Request
|
||||
|
||||
var _ setter = (*multipartRequest)(nil)
|
||||
|
||||
var (
|
||||
// ErrMultiFileHeader multipart.FileHeader invalid
|
||||
ErrMultiFileHeader = errors.New("unsupported field type for multipart.FileHeader")
|
||||
|
||||
// ErrMultiFileHeaderLenInvalid array for []*multipart.FileHeader len invalid
|
||||
ErrMultiFileHeaderLenInvalid = errors.New("unsupported len of array for []*multipart.FileHeader")
|
||||
)
|
||||
|
||||
// TrySet tries to set a value by the multipart request with the binding a form file
|
||||
func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (bool, error) {
|
||||
if files := r.MultipartForm.File[key]; len(files) != 0 {
|
||||
return setByMultipartFormFile(value, field, files)
|
||||
}
|
||||
|
||||
return setByForm(value, field, r.MultipartForm.Value, key, opt)
|
||||
}
|
||||
|
||||
func setByMultipartFormFile(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) {
|
||||
switch value.Kind() {
|
||||
case reflect.Ptr:
|
||||
switch value.Interface().(type) {
|
||||
case *multipart.FileHeader:
|
||||
value.Set(reflect.ValueOf(files[0]))
|
||||
return true, nil
|
||||
}
|
||||
case reflect.Struct:
|
||||
switch value.Interface().(type) {
|
||||
case multipart.FileHeader:
|
||||
value.Set(reflect.ValueOf(*files[0]))
|
||||
return true, nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
slice := reflect.MakeSlice(value.Type(), len(files), len(files))
|
||||
isSet, err = setArrayOfMultipartFormFiles(slice, field, files)
|
||||
if err != nil || !isSet {
|
||||
return isSet, err
|
||||
}
|
||||
value.Set(slice)
|
||||
return true, nil
|
||||
case reflect.Array:
|
||||
return setArrayOfMultipartFormFiles(value, field, files)
|
||||
}
|
||||
return false, ErrMultiFileHeader
|
||||
}
|
||||
|
||||
func setArrayOfMultipartFormFiles(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) {
|
||||
if value.Len() != len(files) {
|
||||
return false, ErrMultiFileHeaderLenInvalid
|
||||
}
|
||||
for i := range files {
|
||||
set, err := setByMultipartFormFile(value.Index(i), field, files[i:i+1])
|
||||
if err != nil || !set {
|
||||
return set, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
17
internal/pkg/binding/query.go
Normal file
17
internal/pkg/binding/query.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package binding
|
||||
|
||||
import "net/http"
|
||||
|
||||
type queryBinding struct{}
|
||||
|
||||
func (queryBinding) Name() string {
|
||||
return "query"
|
||||
}
|
||||
|
||||
func (queryBinding) Bind(req *http.Request, obj any) error {
|
||||
values := req.URL.Query()
|
||||
if err := mapForm(obj, values); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
16
internal/pkg/convertor/path.go
Normal file
16
internal/pkg/convertor/path.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package convertor
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"management/internal/pkg/sliceutil"
|
||||
)
|
||||
|
||||
func HandleParentPath(parentPath string) string {
|
||||
parentPath = strings.ReplaceAll(parentPath, ",,", ",")
|
||||
paths := sliceutil.RemoveDuplicatesWithMap(strings.Split(parentPath, ","))
|
||||
sort.Strings(paths)
|
||||
parentPath = strings.Join(paths, ",") + ","
|
||||
return parentPath
|
||||
}
|
||||
@@ -26,7 +26,7 @@ var PearJson = &dto.PearConfig{
|
||||
Max: "30",
|
||||
Index: dto.Index{
|
||||
Id: "10",
|
||||
Href: "",
|
||||
Href: "/dashboard",
|
||||
Title: "首页",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
var defaultMenus = map[string]bool{
|
||||
"/home.html": true,
|
||||
"/dashboard": true,
|
||||
"/system/menus": true,
|
||||
"/upload/img": true,
|
||||
"/upload/file": true,
|
||||
|
||||
31
internal/pkg/sliceutil/sliceutil.go
Normal file
31
internal/pkg/sliceutil/sliceutil.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package sliceutil
|
||||
|
||||
import "sort"
|
||||
|
||||
// 使用 map 去除重复元素
|
||||
func RemoveDuplicatesWithMap[T comparable](slice []T) []T {
|
||||
result := make([]T, 0, len(slice))
|
||||
seen := make(map[T]bool)
|
||||
for _, v := range slice {
|
||||
if _, ok := seen[v]; !ok {
|
||||
seen[v] = true
|
||||
result = append(result, v)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// 先排序再去重
|
||||
func RemoveDuplicatesWithSort(slice []int) []int {
|
||||
if len(slice) == 0 {
|
||||
return slice
|
||||
}
|
||||
sort.Ints(slice)
|
||||
result := []int{slice[0]}
|
||||
for i := 1; i < len(slice); i++ {
|
||||
if slice[i] != slice[i-1] {
|
||||
result = append(result, slice[i])
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func splitWordsToLower(s string) []string {
|
||||
// upperIndex 获得一个int slice,其元素是一个字符串的所有大写字母索引
|
||||
func upperIndex(s string) []int {
|
||||
var res []int
|
||||
for i := 0; i < len(s); i++ {
|
||||
for i := range len(s) {
|
||||
if 64 < s[i] && s[i] < 91 {
|
||||
res = append(res, i)
|
||||
}
|
||||
|
||||
@@ -53,5 +53,9 @@ func (r *render) Methods() map[string]any {
|
||||
return template.HTML(strings.Join(s, ","))
|
||||
}
|
||||
|
||||
res["toString"] = func(b []byte) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user