diff --git a/internal/handlers/base.go b/internal/handlers/base.go index 1035537..2332375 100644 --- a/internal/handlers/base.go +++ b/internal/handlers/base.go @@ -4,15 +4,14 @@ import ( "net/http" "github.com/rs/xid" - "github.com/zhang2092/mediahls/internal/pkg/cookie" ) const ( - AuthorizeCookie = "authorize" - ContextUser CtxTypeUser = "context_user" + AuthorizeCookie = "authorize" + ContextUser ctxKey = "context_user" ) -type CtxTypeUser string +type ctxKey string type Authorize struct { ID string `json:"id"` @@ -25,13 +24,10 @@ func genId() string { } func (server *Server) isRedirect(w http.ResponseWriter, r *http.Request) { - _, err := server.withCookie(r) - if err != nil { - // 1. 删除cookie - cookie.DeleteCookie(w, cookie.AuthorizeName) + u := withUser(r.Context()) + if u != nil { + // 已经登录, 直接到首页 + http.Redirect(w, r, "/", http.StatusFound) return } - - // cookie 校验成功 - http.Redirect(w, r, "/", http.StatusFound) } diff --git a/internal/handlers/home.go b/internal/handlers/home.go index 4529b5d..088fcba 100644 --- a/internal/handlers/home.go +++ b/internal/handlers/home.go @@ -8,25 +8,12 @@ import ( "github.com/zhang2092/mediahls/internal/db" ) -// obj - -// homePageData 首页数据 -type homePageData struct { - Authorize - Videos []db.Video -} - // view // home 首页 func (server *Server) homeView(w http.ResponseWriter, r *http.Request) { - data := homePageData{} - auth, err := server.withCookie(r) - if err == nil { - data.Authorize = *auth - } - ctx := r.Context() + var result []db.Video videos, err := server.store.ListVideos(ctx, db.ListVideosParams{ Limit: 100, Offset: 0, @@ -38,9 +25,9 @@ func (server *Server) homeView(w http.ResponseWriter, r *http.Request) { item.Description = temp log.Println(item.Description) } - data.Videos = append(data.Videos, item) + result = append(result, item) } } - server.renderLayout(w, r, data, "home.html.tmpl") + server.renderLayout(w, r, result, "home.html.tmpl") } diff --git a/internal/handlers/middleware.go b/internal/handlers/middleware.go index 781a2f4..bcb4efd 100644 --- a/internal/handlers/middleware.go +++ b/internal/handlers/middleware.go @@ -2,57 +2,47 @@ package handlers import ( "context" - "encoding/json" - "log" "net/http" - - "github.com/zhang2092/mediahls/internal/pkg/convert" ) -func (server *Server) authorizeMiddleware(next http.Handler) http.Handler { +func (server *Server) authorize(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - u, err := server.withCookie(r) - if err != nil { + u := withUser(r.Context()) + if u == nil { http.Redirect(w, r, "/login", http.StatusFound) return } - b, err := json.Marshal(u) + next.ServeHTTP(w, r) + }) +} + +func (server *Server) setUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(AuthorizeCookie) if err != nil { - log.Printf("json marshal authorize user: %v", err) - http.Redirect(w, r, "/login", http.StatusFound) + next.ServeHTTP(w, r) + return + } + + u := Authorize{} + err = server.secureCookie.Decode(AuthorizeCookie, cookie.Value, &u) + if err != nil { + next.ServeHTTP(w, r) return } ctx := r.Context() - ctx = context.WithValue(ctx, ContextUser, b) + ctx = context.WithValue(ctx, ContextUser, u) next.ServeHTTP(w, r.WithContext(ctx)) }) } -func (server *Server) withCookie(r *http.Request) (*Authorize, error) { - cookie, err := r.Cookie(AuthorizeCookie) - if err != nil { - return nil, err +func withUser(ctx context.Context) *Authorize { + val := ctx.Value(ContextUser) + if u, ok := val.(Authorize); ok { + return &u } - u := &Authorize{} - err = server.secureCookie.Decode(AuthorizeCookie, cookie.Value, u) - if err != nil { - // log.Printf("secure decode cookie: %v", err) - return nil, err - } - - return u, nil -} - -func withUser(ctx context.Context) Authorize { - var result Authorize - ctxValue, err := convert.ToByteE(ctx.Value(ContextUser)) - if err != nil { - return result - } - - json.Unmarshal(ctxValue, &result) - return result + return nil } diff --git a/internal/handlers/render.go b/internal/handlers/render.go index 1053e6b..45e9fe1 100644 --- a/internal/handlers/render.go +++ b/internal/handlers/render.go @@ -32,6 +32,9 @@ func (server *Server) renderLayout(w http.ResponseWriter, r *http.Request, data "csrfField": func() template.HTML { return csrf.TemplateField(r) }, + "currentUser": func() *Authorize { + return withUser(r.Context()) + }, }) tpl := template.Must(t.Clone()) diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 741f9cf..7131e00 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -75,6 +75,7 @@ func (server *Server) setupRouter() { csrf.CookieName("authorize_csrf"), ) router.Use(csrfMiddleware) + router.Use(server.setUser) router.Handle("/register", hds.MethodHandler{ http.MethodGet: http.HandlerFunc(server.registerView), @@ -93,7 +94,7 @@ func (server *Server) setupRouter() { router.HandleFunc("/media/{xid}/stream/{segName:index[0-9]+.ts}", server.stream).Methods(http.MethodGet) subRouter := router.PathPrefix("/").Subrouter() - subRouter.Use(server.authorizeMiddleware) + subRouter.Use(server.authorize) subRouter.HandleFunc("/me/videos", server.videosView).Methods(http.MethodGet) subRouter.HandleFunc("/me/videos/p{page}", server.videosView).Methods(http.MethodGet) diff --git a/internal/handlers/user.go b/internal/handlers/user.go index 3284501..432142a 100644 --- a/internal/handlers/user.go +++ b/internal/handlers/user.go @@ -14,7 +14,6 @@ import ( // registerPageData 注册页面数据 type registerPageData struct { - Authorize Summary string Email string EmailMsg string @@ -26,7 +25,6 @@ type registerPageData struct { // loginPageData 登录页面数据 type loginPageData struct { - Authorize Summary string Email string EmailMsg string diff --git a/internal/handlers/video.go b/internal/handlers/video.go index be41b5c..d020c31 100644 --- a/internal/handlers/video.go +++ b/internal/handlers/video.go @@ -18,21 +18,8 @@ import ( // obj -// videoPageData 播放页面数据 -type videoPageData struct { - Authorize - Video db.Video -} - -// videosPageData 视频列表数据 -type videosPageData struct { - Authorize - Videos []db.Video -} - // videoEditPageData 视频编辑数据 type videoEditPageData struct { - Authorize Summary string ID string IDMsg string @@ -59,32 +46,23 @@ type videoDeleteRequest struct { func (server *Server) videoView(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) xid := vars["xid"] - data := videoPageData{} - video, err := server.store.GetVideo(r.Context(), xid) - if err == nil { - data.Video = video - } - auth, err := server.withCookie(r) - if err == nil { - data.Authorize = *auth - } - server.renderLayout(w, r, data, "video/play.html.tmpl") + result, _ := server.store.GetVideo(r.Context(), xid) + server.renderLayout(w, r, result, "video/play.html.tmpl") } // videosView 视频列表页面 func (server *Server) videosView(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - data := videosPageData{ - Authorize: withUser(ctx), - } - vars := mux.Vars(r) page, err := strconv.Atoi(vars["page"]) if err != nil { page = 1 } + + ctx := r.Context() + u := withUser(ctx) + var result []db.Video videos, err := server.store.ListVideosByUser(ctx, db.ListVideosByUserParams{ - UserID: data.Authorize.ID, + UserID: u.ID, Limit: 16, Offset: int32((page - 1) * 16), }) @@ -94,28 +72,30 @@ func (server *Server) videosView(w http.ResponseWriter, r *http.Request) { temp := strings.TrimSpace(item.Description[0:65]) + "..." item.Description = temp } - data.Videos = append(data.Videos, item) + result = append(result, item) } } - server.renderLayout(w, r, data, "video/videos.html.tmpl") + server.renderLayout(w, r, result, "video/videos.html.tmpl") } // editVideoView 视频编辑页面 func (server *Server) editVideoView(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) xid := vars["xid"] - vm := videoEditPageData{ - Authorize: withUser(r.Context()), - } + ctx := r.Context() + u := withUser(ctx) + vm := videoEditPageData{} if len(xid) > 0 { - if v, err := server.store.GetVideo(r.Context(), xid); err == nil { - vm.ID = v.ID - vm.Title = v.Title - vm.Images = v.Images - vm.Description = v.Description - vm.OriginLink = v.OriginLink - vm.Status = int(v.Status) + if v, err := server.store.GetVideo(ctx, xid); err == nil { + if u.ID == v.UserID { + vm.ID = v.ID + vm.Title = v.Title + vm.Images = v.Images + vm.Description = v.Description + vm.OriginLink = v.OriginLink + vm.Status = int(v.Status) + } } } server.renderEditVideo(w, r, vm) @@ -291,7 +271,6 @@ func viladatorEditVedio(r *http.Request) (videoEditPageData, bool) { ok := true status, _ := strconv.Atoi(r.PostFormValue("status")) resp := videoEditPageData{ - Authorize: withUser(r.Context()), ID: r.PostFormValue("id"), Title: r.PostFormValue("title"), Images: r.PostFormValue("images"), diff --git a/web/templates/base/header.html.tmpl b/web/templates/base/header.html.tmpl index b110741..d7c125b 100644 --- a/web/templates/base/header.html.tmpl +++ b/web/templates/base/header.html.tmpl @@ -21,9 +21,9 @@ HLS流媒体