add csrf check
This commit is contained in:
parent
8f89c19e12
commit
91edab2f9b
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gorilla/csrf v1.7.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@ -21,6 +21,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI=
|
||||
github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/zhang2092/mediahls/internal/db"
|
||||
@ -57,6 +58,13 @@ func (server *Server) setupRouter() {
|
||||
router.PathPrefix("/statics/").Handler(http.StripPrefix("/statics/", http.FileServer(http.Dir("web/statics"))))
|
||||
router.PathPrefix("/upload/imgs").Handler(http.StripPrefix("/upload/imgs/", http.FileServer(http.Dir("upload/imgs"))))
|
||||
|
||||
csrfMiddleware := csrf.Protect(
|
||||
[]byte(securecookie.GenerateRandomKey(32)),
|
||||
csrf.Secure(false),
|
||||
csrf.HttpOnly(true),
|
||||
)
|
||||
router.Use(csrfMiddleware)
|
||||
|
||||
router.HandleFunc("/register", server.registerView).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register", server.register).Methods(http.MethodPost)
|
||||
router.HandleFunc("/login", server.loginView).Methods(http.MethodGet)
|
||||
|
||||
@ -2,9 +2,11 @@ package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/zhang2092/mediahls/internal/db"
|
||||
"github.com/zhang2092/mediahls/internal/pkg/cookie"
|
||||
pwd "github.com/zhang2092/mediahls/internal/pkg/password"
|
||||
@ -15,6 +17,7 @@ import (
|
||||
// registerPageData 注册页面数据
|
||||
type registerPageData struct {
|
||||
Authorize
|
||||
CSRFField template.HTML
|
||||
Summary string
|
||||
Email string
|
||||
EmailMsg string
|
||||
@ -27,6 +30,7 @@ type registerPageData struct {
|
||||
// loginPageData 登录页面数据
|
||||
type loginPageData struct {
|
||||
Authorize
|
||||
CSRFField template.HTML
|
||||
Summary string
|
||||
Email string
|
||||
EmailMsg string
|
||||
@ -40,14 +44,14 @@ type loginPageData struct {
|
||||
func (server *Server) registerView(w http.ResponseWriter, r *http.Request) {
|
||||
// 是否已经登录
|
||||
server.isRedirect(w, r)
|
||||
renderRegister(w, nil)
|
||||
renderRegister(w, r, nil)
|
||||
}
|
||||
|
||||
// loginView 登录页面
|
||||
func (server *Server) loginView(w http.ResponseWriter, r *http.Request) {
|
||||
// 是否已经登录
|
||||
server.isRedirect(w, r)
|
||||
renderLogin(w, nil)
|
||||
renderLogin(w, r, nil)
|
||||
}
|
||||
|
||||
// data
|
||||
@ -66,7 +70,7 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) {
|
||||
password := r.PostFormValue("password")
|
||||
resp, ok := viladatorRegister(email, username, password)
|
||||
if !ok {
|
||||
renderRegister(w, resp)
|
||||
renderRegister(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -87,12 +91,12 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
if server.store.IsUniqueViolation(err) {
|
||||
resp.Summary = "邮箱或名称已经存在"
|
||||
renderRegister(w, resp)
|
||||
renderRegister(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
resp.Summary = "请求网络错误,请刷新重试"
|
||||
renderRegister(w, resp)
|
||||
renderRegister(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -104,7 +108,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
renderLogin(w, registerPageData{Summary: "请求网络错误,请刷新重试"})
|
||||
renderLogin(w, r, registerPageData{Summary: "请求网络错误,请刷新重试"})
|
||||
return
|
||||
}
|
||||
|
||||
@ -112,7 +116,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
|
||||
password := r.PostFormValue("password")
|
||||
resp, ok := viladatorLogin(email, password)
|
||||
if !ok {
|
||||
renderLogin(w, resp)
|
||||
renderLogin(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -121,26 +125,26 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
if server.store.IsNoRows(sql.ErrNoRows) {
|
||||
resp.Summary = "邮箱或密码错误"
|
||||
renderLogin(w, resp)
|
||||
renderLogin(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
resp.Summary = "请求网络错误,请刷新重试"
|
||||
renderLogin(w, resp)
|
||||
renderLogin(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
err = pwd.BcryptComparePassword(user.HashedPassword, password)
|
||||
if err != nil {
|
||||
resp.Summary = "邮箱或密码错误"
|
||||
renderLogin(w, resp)
|
||||
renderLogin(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
encoded, err := server.secureCookie.Encode(AuthorizeCookie, &Authorize{ID: user.ID, Name: user.Username})
|
||||
if err != nil {
|
||||
resp.Summary = "请求网络错误,请刷新重试(cookie)"
|
||||
renderLogin(w, resp)
|
||||
renderLogin(w, r, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -158,13 +162,29 @@ func (server *Server) logout(w http.ResponseWriter, r *http.Request) {
|
||||
// method
|
||||
|
||||
// renderRegister 渲染注册页面
|
||||
func renderRegister(w http.ResponseWriter, data any) {
|
||||
renderLayout(w, data, "web/templates/user/register.html.tmpl")
|
||||
func renderRegister(w http.ResponseWriter, r *http.Request, data any) {
|
||||
if data != nil {
|
||||
res := data.(registerPageData)
|
||||
res.CSRFField = csrf.TemplateField(r)
|
||||
renderLayout(w, res, "web/templates/user/register.html.tmpl")
|
||||
} else {
|
||||
renderLayout(w, registerPageData{
|
||||
CSRFField: csrf.TemplateField(r),
|
||||
}, "web/templates/user/register.html.tmpl")
|
||||
}
|
||||
}
|
||||
|
||||
// renderLogin 渲染登录页面
|
||||
func renderLogin(w http.ResponseWriter, data any) {
|
||||
renderLayout(w, data, "web/templates/user/login.html.tmpl")
|
||||
func renderLogin(w http.ResponseWriter, r *http.Request, data any) {
|
||||
if data != nil {
|
||||
res := data.(loginPageData)
|
||||
res.CSRFField = csrf.TemplateField(r)
|
||||
renderLayout(w, res, "web/templates/user/login.html.tmpl")
|
||||
} else {
|
||||
renderLayout(w, loginPageData{
|
||||
CSRFField: csrf.TemplateField(r),
|
||||
}, "web/templates/user/login.html.tmpl")
|
||||
}
|
||||
}
|
||||
|
||||
// viladatorRegister 校验注册数据
|
||||
|
||||
@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/zhang2092/mediahls/internal/db"
|
||||
"github.com/zhang2092/mediahls/internal/pkg/convert"
|
||||
@ -28,12 +29,14 @@ type videoPageData struct {
|
||||
// videosPageData 视频列表数据
|
||||
type videosPageData struct {
|
||||
Authorize
|
||||
Videos []db.Video
|
||||
CSRFField template.HTML
|
||||
Videos []db.Video
|
||||
}
|
||||
|
||||
// videoEditPageData 视频编辑数据
|
||||
type videoEditPageData struct {
|
||||
Authorize
|
||||
CSRFField template.HTML
|
||||
Summary string
|
||||
ID string
|
||||
IDMsg string
|
||||
@ -77,6 +80,7 @@ func (server *Server) videosView(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data := videosPageData{
|
||||
Authorize: withUser(ctx),
|
||||
CSRFField: csrf.TemplateField(r),
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
@ -119,7 +123,7 @@ func (server *Server) editVideoView(w http.ResponseWriter, r *http.Request) {
|
||||
vm.Status = int(v.Status)
|
||||
}
|
||||
}
|
||||
renderEditVideo(w, vm)
|
||||
renderEditVideo(w, r, vm)
|
||||
}
|
||||
|
||||
// data
|
||||
@ -144,13 +148,13 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
renderEditVideo(w, videoEditPageData{Summary: "请求网络错误, 请刷新重试"})
|
||||
renderEditVideo(w, r, videoEditPageData{Summary: "请求网络错误, 请刷新重试"})
|
||||
return
|
||||
}
|
||||
|
||||
vm, ok := viladatorEditVedio(r)
|
||||
if !ok {
|
||||
renderEditVideo(w, vm)
|
||||
renderEditVideo(w, r, vm)
|
||||
return
|
||||
}
|
||||
|
||||
@ -170,14 +174,14 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
if err != nil {
|
||||
vm.Summary = "添加视频失败"
|
||||
renderEditVideo(w, vm)
|
||||
renderEditVideo(w, r, vm)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
v, err := server.store.GetVideo(ctx, vm.ID)
|
||||
if err != nil {
|
||||
vm.Summary = "视频数据错误"
|
||||
renderEditVideo(w, vm)
|
||||
renderEditVideo(w, r, vm)
|
||||
return
|
||||
}
|
||||
|
||||
@ -197,7 +201,7 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
if err != nil {
|
||||
vm.Summary = "更新视频失败"
|
||||
renderEditVideo(w, vm)
|
||||
renderEditVideo(w, r, vm)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -215,7 +219,6 @@ func (server *Server) deleteVideo(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Println(req.ID)
|
||||
err := server.store.DeleteVideo(r.Context(), req.ID)
|
||||
if err != nil {
|
||||
RespondErr(w, "删除失败", nil)
|
||||
@ -284,8 +287,16 @@ func (server *Server) transfer(w http.ResponseWriter, r *http.Request) {
|
||||
// method
|
||||
|
||||
// renderEditVideo 渲染视频编辑页面
|
||||
func renderEditVideo(w http.ResponseWriter, data any) {
|
||||
renderLayout(w, data, "web/templates/video/edit.html.tmpl")
|
||||
func renderEditVideo(w http.ResponseWriter, r *http.Request, data any) {
|
||||
if data != nil {
|
||||
res := data.(videoEditPageData)
|
||||
res.CSRFField = csrf.TemplateField(r)
|
||||
renderLayout(w, res, "web/templates/video/edit.html.tmpl")
|
||||
}
|
||||
|
||||
renderLayout(w, videoEditPageData{
|
||||
CSRFField: csrf.TemplateField(r),
|
||||
}, "web/templates/video/edit.html.tmpl")
|
||||
}
|
||||
|
||||
// viladatorEditVedio 检验视频编辑数据
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
<h1>登录</h1>
|
||||
<div class="col-sm-4 py-md-5">
|
||||
<form action="/login" method="post">
|
||||
{{ .CSRFField }}
|
||||
<div class="form-group">
|
||||
<div class="input-group">
|
||||
<div class="input-group-prepend">
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
<h1>注册</h1>
|
||||
<div class="col-sm-4 py-md-5">
|
||||
<form action="/register" method="post">
|
||||
{{ .CSRFField }}
|
||||
<div class="form-group">
|
||||
<div class="input-group">
|
||||
<div class="input-group-prepend">
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
</div>
|
||||
<div class="col-sm-6 py-md-5 flex flex-column justify-content">
|
||||
<form action="/me/videos/update" method="post">
|
||||
{{ .CSRFField }}
|
||||
{{if .ID}}
|
||||
<div class="form-group">
|
||||
<div class="input-group">
|
||||
|
||||
@ -35,6 +35,9 @@
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
<div class="hidden">
|
||||
{{ .CSRFField }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@ -44,9 +47,13 @@
|
||||
let that = $(this)
|
||||
that.attr("disable", true).html('转码中...')
|
||||
let id = that.attr("data-id")
|
||||
let csrfToken = $('input[name="gorilla.csrf.Token"]').val()
|
||||
$.ajax({
|
||||
url: '/transfer/' + id,
|
||||
type: 'post',
|
||||
headers: {
|
||||
"X-CSRF-Token": csrfToken
|
||||
},
|
||||
success: function (obj) {
|
||||
$('#msg').html(obj)
|
||||
},
|
||||
@ -59,9 +66,13 @@
|
||||
let that = $(this)
|
||||
that.attr("disable", true).html('删除中...')
|
||||
let id = that.attr("data-id")
|
||||
let csrfToken = $('input[name="gorilla.csrf.Token"]').val()
|
||||
$.ajax({
|
||||
url: '/me/videos/delete',
|
||||
type: 'post',
|
||||
headers: {
|
||||
"X-CSRF-Token": csrfToken
|
||||
},
|
||||
contentType: 'application/json',
|
||||
dataType: 'json',
|
||||
data:JSON.stringify({"id": id}),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user