diff --git a/go.mod b/go.mod index 79df7b8..e07f92f 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gorilla/csrf v1.7.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect diff --git a/go.sum b/go.sum index 7b0388f..71a74e3 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= +github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 4211621..7dad967 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/gorilla/csrf" "github.com/gorilla/mux" "github.com/gorilla/securecookie" "github.com/zhang2092/mediahls/internal/db" @@ -57,6 +58,13 @@ func (server *Server) setupRouter() { router.PathPrefix("/statics/").Handler(http.StripPrefix("/statics/", http.FileServer(http.Dir("web/statics")))) router.PathPrefix("/upload/imgs").Handler(http.StripPrefix("/upload/imgs/", http.FileServer(http.Dir("upload/imgs")))) + csrfMiddleware := csrf.Protect( + []byte(securecookie.GenerateRandomKey(32)), + csrf.Secure(false), + csrf.HttpOnly(true), + ) + router.Use(csrfMiddleware) + router.HandleFunc("/register", server.registerView).Methods(http.MethodGet) router.HandleFunc("/register", server.register).Methods(http.MethodPost) router.HandleFunc("/login", server.loginView).Methods(http.MethodGet) diff --git a/internal/handlers/user.go b/internal/handlers/user.go index 9c61324..2c38618 100644 --- a/internal/handlers/user.go +++ b/internal/handlers/user.go @@ -2,9 +2,11 @@ package handlers import ( "database/sql" + "html/template" "net/http" "time" + "github.com/gorilla/csrf" "github.com/zhang2092/mediahls/internal/db" "github.com/zhang2092/mediahls/internal/pkg/cookie" pwd "github.com/zhang2092/mediahls/internal/pkg/password" @@ -15,6 +17,7 @@ import ( // registerPageData 注册页面数据 type registerPageData struct { Authorize + CSRFField template.HTML Summary string Email string EmailMsg string @@ -27,6 +30,7 @@ type registerPageData struct { // loginPageData 登录页面数据 type loginPageData struct { Authorize + CSRFField template.HTML Summary string Email string EmailMsg string @@ -40,14 +44,14 @@ type loginPageData struct { func (server *Server) registerView(w http.ResponseWriter, r *http.Request) { // 是否已经登录 server.isRedirect(w, r) - renderRegister(w, nil) + renderRegister(w, r, nil) } // loginView 登录页面 func (server *Server) loginView(w http.ResponseWriter, r *http.Request) { // 是否已经登录 server.isRedirect(w, r) - renderLogin(w, nil) + renderLogin(w, r, nil) } // data @@ -66,7 +70,7 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { password := r.PostFormValue("password") resp, ok := viladatorRegister(email, username, password) if !ok { - renderRegister(w, resp) + renderRegister(w, r, resp) return } @@ -87,12 +91,12 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { if err != nil { if server.store.IsUniqueViolation(err) { resp.Summary = "邮箱或名称已经存在" - renderRegister(w, resp) + renderRegister(w, r, resp) return } resp.Summary = "请求网络错误,请刷新重试" - renderRegister(w, resp) + renderRegister(w, r, resp) return } @@ -104,7 +108,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := r.ParseForm(); err != nil { - renderLogin(w, registerPageData{Summary: "请求网络错误,请刷新重试"}) + renderLogin(w, r, registerPageData{Summary: "请求网络错误,请刷新重试"}) return } @@ -112,7 +116,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { password := r.PostFormValue("password") resp, ok := viladatorLogin(email, password) if !ok { - renderLogin(w, resp) + renderLogin(w, r, resp) return } @@ -121,26 +125,26 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { if err != nil { if server.store.IsNoRows(sql.ErrNoRows) { resp.Summary = "邮箱或密码错误" - renderLogin(w, resp) + renderLogin(w, r, resp) return } resp.Summary = "请求网络错误,请刷新重试" - renderLogin(w, resp) + renderLogin(w, r, resp) return } err = pwd.BcryptComparePassword(user.HashedPassword, password) if err != nil { resp.Summary = "邮箱或密码错误" - renderLogin(w, resp) + renderLogin(w, r, resp) return } encoded, err := server.secureCookie.Encode(AuthorizeCookie, &Authorize{ID: user.ID, Name: user.Username}) if err != nil { resp.Summary = "请求网络错误,请刷新重试(cookie)" - renderLogin(w, resp) + renderLogin(w, r, resp) return } @@ -158,13 +162,29 @@ func (server *Server) logout(w http.ResponseWriter, r *http.Request) { // method // renderRegister 渲染注册页面 -func renderRegister(w http.ResponseWriter, data any) { - renderLayout(w, data, "web/templates/user/register.html.tmpl") +func renderRegister(w http.ResponseWriter, r *http.Request, data any) { + if data != nil { + res := data.(registerPageData) + res.CSRFField = csrf.TemplateField(r) + renderLayout(w, res, "web/templates/user/register.html.tmpl") + } else { + renderLayout(w, registerPageData{ + CSRFField: csrf.TemplateField(r), + }, "web/templates/user/register.html.tmpl") + } } // renderLogin 渲染登录页面 -func renderLogin(w http.ResponseWriter, data any) { - renderLayout(w, data, "web/templates/user/login.html.tmpl") +func renderLogin(w http.ResponseWriter, r *http.Request, data any) { + if data != nil { + res := data.(loginPageData) + res.CSRFField = csrf.TemplateField(r) + renderLayout(w, res, "web/templates/user/login.html.tmpl") + } else { + renderLayout(w, loginPageData{ + CSRFField: csrf.TemplateField(r), + }, "web/templates/user/login.html.tmpl") + } } // viladatorRegister 校验注册数据 diff --git a/internal/handlers/video.go b/internal/handlers/video.go index fdbc9d8..54d966d 100644 --- a/internal/handlers/video.go +++ b/internal/handlers/video.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "fmt" - "log" + "html/template" "net/http" "strconv" "strings" "time" + "github.com/gorilla/csrf" "github.com/gorilla/mux" "github.com/zhang2092/mediahls/internal/db" "github.com/zhang2092/mediahls/internal/pkg/convert" @@ -28,12 +29,14 @@ type videoPageData struct { // videosPageData 视频列表数据 type videosPageData struct { Authorize - Videos []db.Video + CSRFField template.HTML + Videos []db.Video } // videoEditPageData 视频编辑数据 type videoEditPageData struct { Authorize + CSRFField template.HTML Summary string ID string IDMsg string @@ -77,6 +80,7 @@ func (server *Server) videosView(w http.ResponseWriter, r *http.Request) { ctx := r.Context() data := videosPageData{ Authorize: withUser(ctx), + CSRFField: csrf.TemplateField(r), } vars := mux.Vars(r) @@ -119,7 +123,7 @@ func (server *Server) editVideoView(w http.ResponseWriter, r *http.Request) { vm.Status = int(v.Status) } } - renderEditVideo(w, vm) + renderEditVideo(w, r, vm) } // data @@ -144,13 +148,13 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := r.ParseForm(); err != nil { - renderEditVideo(w, videoEditPageData{Summary: "请求网络错误, 请刷新重试"}) + renderEditVideo(w, r, videoEditPageData{Summary: "请求网络错误, 请刷新重试"}) return } vm, ok := viladatorEditVedio(r) if !ok { - renderEditVideo(w, vm) + renderEditVideo(w, r, vm) return } @@ -170,14 +174,14 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { }) if err != nil { vm.Summary = "添加视频失败" - renderEditVideo(w, vm) + renderEditVideo(w, r, vm) return } } else { v, err := server.store.GetVideo(ctx, vm.ID) if err != nil { vm.Summary = "视频数据错误" - renderEditVideo(w, vm) + renderEditVideo(w, r, vm) return } @@ -197,7 +201,7 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { }) if err != nil { vm.Summary = "更新视频失败" - renderEditVideo(w, vm) + renderEditVideo(w, r, vm) return } } @@ -215,7 +219,6 @@ func (server *Server) deleteVideo(w http.ResponseWriter, r *http.Request) { return } - log.Println(req.ID) err := server.store.DeleteVideo(r.Context(), req.ID) if err != nil { RespondErr(w, "删除失败", nil) @@ -284,8 +287,16 @@ func (server *Server) transfer(w http.ResponseWriter, r *http.Request) { // method // renderEditVideo 渲染视频编辑页面 -func renderEditVideo(w http.ResponseWriter, data any) { - renderLayout(w, data, "web/templates/video/edit.html.tmpl") +func renderEditVideo(w http.ResponseWriter, r *http.Request, data any) { + if data != nil { + res := data.(videoEditPageData) + res.CSRFField = csrf.TemplateField(r) + renderLayout(w, res, "web/templates/video/edit.html.tmpl") + } + + renderLayout(w, videoEditPageData{ + CSRFField: csrf.TemplateField(r), + }, "web/templates/video/edit.html.tmpl") } // viladatorEditVedio 检验视频编辑数据 diff --git a/web/templates/user/login.html.tmpl b/web/templates/user/login.html.tmpl index 6cc2225..74da05b 100644 --- a/web/templates/user/login.html.tmpl +++ b/web/templates/user/login.html.tmpl @@ -4,6 +4,7 @@