add csrf check

This commit is contained in:
kenneth
2023-12-04 07:19:53 +00:00
parent 8f89c19e12
commit 91edab2f9b
9 changed files with 82 additions and 26 deletions

View File

@@ -12,6 +12,7 @@ import (
"syscall"
"time"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/zhang2092/mediahls/internal/db"
@@ -57,6 +58,13 @@ func (server *Server) setupRouter() {
router.PathPrefix("/statics/").Handler(http.StripPrefix("/statics/", http.FileServer(http.Dir("web/statics"))))
router.PathPrefix("/upload/imgs").Handler(http.StripPrefix("/upload/imgs/", http.FileServer(http.Dir("upload/imgs"))))
csrfMiddleware := csrf.Protect(
[]byte(securecookie.GenerateRandomKey(32)),
csrf.Secure(false),
csrf.HttpOnly(true),
)
router.Use(csrfMiddleware)
router.HandleFunc("/register", server.registerView).Methods(http.MethodGet)
router.HandleFunc("/register", server.register).Methods(http.MethodPost)
router.HandleFunc("/login", server.loginView).Methods(http.MethodGet)

View File

@@ -2,9 +2,11 @@ package handlers
import (
"database/sql"
"html/template"
"net/http"
"time"
"github.com/gorilla/csrf"
"github.com/zhang2092/mediahls/internal/db"
"github.com/zhang2092/mediahls/internal/pkg/cookie"
pwd "github.com/zhang2092/mediahls/internal/pkg/password"
@@ -15,6 +17,7 @@ import (
// registerPageData 注册页面数据
type registerPageData struct {
Authorize
CSRFField template.HTML
Summary string
Email string
EmailMsg string
@@ -27,6 +30,7 @@ type registerPageData struct {
// loginPageData 登录页面数据
type loginPageData struct {
Authorize
CSRFField template.HTML
Summary string
Email string
EmailMsg string
@@ -40,14 +44,14 @@ type loginPageData struct {
func (server *Server) registerView(w http.ResponseWriter, r *http.Request) {
// 是否已经登录
server.isRedirect(w, r)
renderRegister(w, nil)
renderRegister(w, r, nil)
}
// loginView 登录页面
func (server *Server) loginView(w http.ResponseWriter, r *http.Request) {
// 是否已经登录
server.isRedirect(w, r)
renderLogin(w, nil)
renderLogin(w, r, nil)
}
// data
@@ -66,7 +70,7 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) {
password := r.PostFormValue("password")
resp, ok := viladatorRegister(email, username, password)
if !ok {
renderRegister(w, resp)
renderRegister(w, r, resp)
return
}
@@ -87,12 +91,12 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) {
if err != nil {
if server.store.IsUniqueViolation(err) {
resp.Summary = "邮箱或名称已经存在"
renderRegister(w, resp)
renderRegister(w, r, resp)
return
}
resp.Summary = "请求网络错误,请刷新重试"
renderRegister(w, resp)
renderRegister(w, r, resp)
return
}
@@ -104,7 +108,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := r.ParseForm(); err != nil {
renderLogin(w, registerPageData{Summary: "请求网络错误,请刷新重试"})
renderLogin(w, r, registerPageData{Summary: "请求网络错误,请刷新重试"})
return
}
@@ -112,7 +116,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
password := r.PostFormValue("password")
resp, ok := viladatorLogin(email, password)
if !ok {
renderLogin(w, resp)
renderLogin(w, r, resp)
return
}
@@ -121,26 +125,26 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
if err != nil {
if server.store.IsNoRows(sql.ErrNoRows) {
resp.Summary = "邮箱或密码错误"
renderLogin(w, resp)
renderLogin(w, r, resp)
return
}
resp.Summary = "请求网络错误,请刷新重试"
renderLogin(w, resp)
renderLogin(w, r, resp)
return
}
err = pwd.BcryptComparePassword(user.HashedPassword, password)
if err != nil {
resp.Summary = "邮箱或密码错误"
renderLogin(w, resp)
renderLogin(w, r, resp)
return
}
encoded, err := server.secureCookie.Encode(AuthorizeCookie, &Authorize{ID: user.ID, Name: user.Username})
if err != nil {
resp.Summary = "请求网络错误,请刷新重试(cookie)"
renderLogin(w, resp)
renderLogin(w, r, resp)
return
}
@@ -158,13 +162,29 @@ func (server *Server) logout(w http.ResponseWriter, r *http.Request) {
// method
// renderRegister 渲染注册页面
func renderRegister(w http.ResponseWriter, data any) {
renderLayout(w, data, "web/templates/user/register.html.tmpl")
func renderRegister(w http.ResponseWriter, r *http.Request, data any) {
if data != nil {
res := data.(registerPageData)
res.CSRFField = csrf.TemplateField(r)
renderLayout(w, res, "web/templates/user/register.html.tmpl")
} else {
renderLayout(w, registerPageData{
CSRFField: csrf.TemplateField(r),
}, "web/templates/user/register.html.tmpl")
}
}
// renderLogin 渲染登录页面
func renderLogin(w http.ResponseWriter, data any) {
renderLayout(w, data, "web/templates/user/login.html.tmpl")
func renderLogin(w http.ResponseWriter, r *http.Request, data any) {
if data != nil {
res := data.(loginPageData)
res.CSRFField = csrf.TemplateField(r)
renderLayout(w, res, "web/templates/user/login.html.tmpl")
} else {
renderLayout(w, loginPageData{
CSRFField: csrf.TemplateField(r),
}, "web/templates/user/login.html.tmpl")
}
}
// viladatorRegister 校验注册数据

View File

@@ -4,12 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"html/template"
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/zhang2092/mediahls/internal/db"
"github.com/zhang2092/mediahls/internal/pkg/convert"
@@ -28,12 +29,14 @@ type videoPageData struct {
// videosPageData 视频列表数据
type videosPageData struct {
Authorize
Videos []db.Video
CSRFField template.HTML
Videos []db.Video
}
// videoEditPageData 视频编辑数据
type videoEditPageData struct {
Authorize
CSRFField template.HTML
Summary string
ID string
IDMsg string
@@ -77,6 +80,7 @@ func (server *Server) videosView(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := videosPageData{
Authorize: withUser(ctx),
CSRFField: csrf.TemplateField(r),
}
vars := mux.Vars(r)
@@ -119,7 +123,7 @@ func (server *Server) editVideoView(w http.ResponseWriter, r *http.Request) {
vm.Status = int(v.Status)
}
}
renderEditVideo(w, vm)
renderEditVideo(w, r, vm)
}
// data
@@ -144,13 +148,13 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := r.ParseForm(); err != nil {
renderEditVideo(w, videoEditPageData{Summary: "请求网络错误, 请刷新重试"})
renderEditVideo(w, r, videoEditPageData{Summary: "请求网络错误, 请刷新重试"})
return
}
vm, ok := viladatorEditVedio(r)
if !ok {
renderEditVideo(w, vm)
renderEditVideo(w, r, vm)
return
}
@@ -170,14 +174,14 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
vm.Summary = "添加视频失败"
renderEditVideo(w, vm)
renderEditVideo(w, r, vm)
return
}
} else {
v, err := server.store.GetVideo(ctx, vm.ID)
if err != nil {
vm.Summary = "视频数据错误"
renderEditVideo(w, vm)
renderEditVideo(w, r, vm)
return
}
@@ -197,7 +201,7 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
vm.Summary = "更新视频失败"
renderEditVideo(w, vm)
renderEditVideo(w, r, vm)
return
}
}
@@ -215,7 +219,6 @@ func (server *Server) deleteVideo(w http.ResponseWriter, r *http.Request) {
return
}
log.Println(req.ID)
err := server.store.DeleteVideo(r.Context(), req.ID)
if err != nil {
RespondErr(w, "删除失败", nil)
@@ -284,8 +287,16 @@ func (server *Server) transfer(w http.ResponseWriter, r *http.Request) {
// method
// renderEditVideo 渲染视频编辑页面
func renderEditVideo(w http.ResponseWriter, data any) {
renderLayout(w, data, "web/templates/video/edit.html.tmpl")
func renderEditVideo(w http.ResponseWriter, r *http.Request, data any) {
if data != nil {
res := data.(videoEditPageData)
res.CSRFField = csrf.TemplateField(r)
renderLayout(w, res, "web/templates/video/edit.html.tmpl")
}
renderLayout(w, videoEditPageData{
CSRFField: csrf.TemplateField(r),
}, "web/templates/video/edit.html.tmpl")
}
// viladatorEditVedio 检验视频编辑数据