From 30c5ed2452aacdf400c10af551e35cde932dd224 Mon Sep 17 00:00:00 2001 From: kenneth Date: Mon, 4 Dec 2023 08:23:00 +0000 Subject: [PATCH] use embed render html/css/js ... --- internal/handlers/home.go | 2 +- internal/handlers/render.go | 4 ++-- internal/handlers/server.go | 14 +++++++++++--- internal/handlers/user.go | 30 +++++++++++++++--------------- internal/handlers/video.go | 20 ++++++++++---------- main.go | 31 ++++++++++++++++++++++++++++++- 6 files changed, 69 insertions(+), 32 deletions(-) diff --git a/internal/handlers/home.go b/internal/handlers/home.go index 74963e1..4529b5d 100644 --- a/internal/handlers/home.go +++ b/internal/handlers/home.go @@ -42,5 +42,5 @@ func (server *Server) homeView(w http.ResponseWriter, r *http.Request) { } } - renderLayout(w, r, data, "web/templates/home.html.tmpl") + server.renderLayout(w, r, data, "home.html.tmpl") } diff --git a/internal/handlers/render.go b/internal/handlers/render.go index 05c81fd..4e84556 100644 --- a/internal/handlers/render.go +++ b/internal/handlers/render.go @@ -26,7 +26,7 @@ import ( // } // renderLayout 渲染方法 带框架 -func renderLayout(w http.ResponseWriter, r *http.Request, data any, tmpl string) { +func (server *Server) renderLayout(w http.ResponseWriter, r *http.Request, data any, tmpl string) { t := template.New(filepath.Base(tmpl)) t = t.Funcs(template.FuncMap{ "csrfField": func() template.HTML { @@ -35,7 +35,7 @@ func renderLayout(w http.ResponseWriter, r *http.Request, data any, tmpl string) }) tpl := template.Must(t.Clone()) - tpl, err := tpl.ParseFiles(tmpl, "web/templates/base/header.html.tmpl", "web/templates/base/footer.html.tmpl") + tpl, err := tpl.ParseFS(server.templateFS, tmpl, "base/header.html.tmpl", "base/footer.html.tmpl") if err != nil { logger.Logger.Errorf("template parse: %s, %v", tmpl, err) w.WriteHeader(http.StatusInternalServerError) diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 7dad967..a87e08e 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "io/fs" "log" "net/http" "os" @@ -22,6 +23,10 @@ import ( ) type Server struct { + templateFS fs.FS + staticFS fs.FS + imgFS fs.FS + conf *config.Config router *mux.Router secureCookie *securecookie.SecureCookie @@ -30,7 +35,7 @@ type Server struct { tokenMaker token.Maker } -func NewServer(conf *config.Config, store db.Store) (*Server, error) { +func NewServer(templateFS fs.FS, staticFS fs.FS, imgFS fs.FS, conf *config.Config, store db.Store) (*Server, error) { tokenMaker, err := token.NewPasetoMaker(conf.TokenSymmetricKey) if err != nil { return nil, fmt.Errorf("cannot create token maker: %w", err) @@ -42,6 +47,9 @@ func NewServer(conf *config.Config, store db.Store) (*Server, error) { secureCookie.MaxAge(7200) server := &Server{ + templateFS: templateFS, + staticFS: staticFS, + imgFS: imgFS, conf: conf, secureCookie: secureCookie, store: store, @@ -55,8 +63,8 @@ func NewServer(conf *config.Config, store db.Store) (*Server, error) { func (server *Server) setupRouter() { router := mux.NewRouter() router.Use(mux.CORSMethodMiddleware(router)) - router.PathPrefix("/statics/").Handler(http.StripPrefix("/statics/", http.FileServer(http.Dir("web/statics")))) - router.PathPrefix("/upload/imgs").Handler(http.StripPrefix("/upload/imgs/", http.FileServer(http.Dir("upload/imgs")))) + router.PathPrefix("/statics/").Handler(http.StripPrefix("/statics/", http.FileServer(http.FS(server.staticFS)))) + router.PathPrefix("/upload/imgs").Handler(http.StripPrefix("/upload/imgs/", http.FileServer(http.FS(server.imgFS)))) csrfMiddleware := csrf.Protect( []byte(securecookie.GenerateRandomKey(32)), diff --git a/internal/handlers/user.go b/internal/handlers/user.go index 884c25b..3284501 100644 --- a/internal/handlers/user.go +++ b/internal/handlers/user.go @@ -40,14 +40,14 @@ type loginPageData struct { func (server *Server) registerView(w http.ResponseWriter, r *http.Request) { // 是否已经登录 server.isRedirect(w, r) - renderRegister(w, r, nil) + server.renderRegister(w, r, nil) } // loginView 登录页面 func (server *Server) loginView(w http.ResponseWriter, r *http.Request) { // 是否已经登录 server.isRedirect(w, r) - renderLogin(w, r, nil) + server.renderLogin(w, r, nil) } // data @@ -66,7 +66,7 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { password := r.PostFormValue("password") resp, ok := viladatorRegister(email, username, password) if !ok { - renderRegister(w, r, resp) + server.renderRegister(w, r, resp) return } @@ -87,12 +87,12 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { if err != nil { if server.store.IsUniqueViolation(err) { resp.Summary = "邮箱或名称已经存在" - renderRegister(w, r, resp) + server.renderRegister(w, r, resp) return } resp.Summary = "请求网络错误,请刷新重试" - renderRegister(w, r, resp) + server.renderRegister(w, r, resp) return } @@ -104,7 +104,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := r.ParseForm(); err != nil { - renderLogin(w, r, registerPageData{Summary: "请求网络错误,请刷新重试"}) + server.renderLogin(w, r, registerPageData{Summary: "请求网络错误,请刷新重试"}) return } @@ -112,7 +112,7 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { password := r.PostFormValue("password") resp, ok := viladatorLogin(email, password) if !ok { - renderLogin(w, r, resp) + server.renderLogin(w, r, resp) return } @@ -121,26 +121,26 @@ func (server *Server) login(w http.ResponseWriter, r *http.Request) { if err != nil { if server.store.IsNoRows(sql.ErrNoRows) { resp.Summary = "邮箱或密码错误" - renderLogin(w, r, resp) + server.renderLogin(w, r, resp) return } resp.Summary = "请求网络错误,请刷新重试" - renderLogin(w, r, resp) + server.renderLogin(w, r, resp) return } err = pwd.BcryptComparePassword(user.HashedPassword, password) if err != nil { resp.Summary = "邮箱或密码错误" - renderLogin(w, r, resp) + server.renderLogin(w, r, resp) return } encoded, err := server.secureCookie.Encode(AuthorizeCookie, &Authorize{ID: user.ID, Name: user.Username}) if err != nil { resp.Summary = "请求网络错误,请刷新重试(cookie)" - renderLogin(w, r, resp) + server.renderLogin(w, r, resp) return } @@ -158,13 +158,13 @@ func (server *Server) logout(w http.ResponseWriter, r *http.Request) { // method // renderRegister 渲染注册页面 -func renderRegister(w http.ResponseWriter, r *http.Request, data any) { - renderLayout(w, r, data, "web/templates/user/register.html.tmpl") +func (server *Server) renderRegister(w http.ResponseWriter, r *http.Request, data any) { + server.renderLayout(w, r, data, "user/register.html.tmpl") } // renderLogin 渲染登录页面 -func renderLogin(w http.ResponseWriter, r *http.Request, data any) { - renderLayout(w, r, data, "web/templates/user/login.html.tmpl") +func (server *Server) renderLogin(w http.ResponseWriter, r *http.Request, data any) { + server.renderLayout(w, r, data, "user/login.html.tmpl") } // viladatorRegister 校验注册数据 diff --git a/internal/handlers/video.go b/internal/handlers/video.go index 0deba56..be41b5c 100644 --- a/internal/handlers/video.go +++ b/internal/handlers/video.go @@ -68,7 +68,7 @@ func (server *Server) videoView(w http.ResponseWriter, r *http.Request) { if err == nil { data.Authorize = *auth } - renderLayout(w, r, data, "web/templates/video/play.html.tmpl") + server.renderLayout(w, r, data, "video/play.html.tmpl") } // videosView 视频列表页面 @@ -98,7 +98,7 @@ func (server *Server) videosView(w http.ResponseWriter, r *http.Request) { } } - renderLayout(w, r, data, "web/templates/video/videos.html.tmpl") + server.renderLayout(w, r, data, "video/videos.html.tmpl") } // editVideoView 视频编辑页面 @@ -118,7 +118,7 @@ func (server *Server) editVideoView(w http.ResponseWriter, r *http.Request) { vm.Status = int(v.Status) } } - renderEditVideo(w, r, vm) + server.renderEditVideo(w, r, vm) } // data @@ -143,13 +143,13 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := r.ParseForm(); err != nil { - renderEditVideo(w, r, videoEditPageData{Summary: "请求网络错误, 请刷新重试"}) + server.renderEditVideo(w, r, videoEditPageData{Summary: "请求网络错误, 请刷新重试"}) return } vm, ok := viladatorEditVedio(r) if !ok { - renderEditVideo(w, r, vm) + server.renderEditVideo(w, r, vm) return } @@ -169,14 +169,14 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { }) if err != nil { vm.Summary = "添加视频失败" - renderEditVideo(w, r, vm) + server.renderEditVideo(w, r, vm) return } } else { v, err := server.store.GetVideo(ctx, vm.ID) if err != nil { vm.Summary = "视频数据错误" - renderEditVideo(w, r, vm) + server.renderEditVideo(w, r, vm) return } @@ -196,7 +196,7 @@ func (server *Server) editVideo(w http.ResponseWriter, r *http.Request) { }) if err != nil { vm.Summary = "更新视频失败" - renderEditVideo(w, r, vm) + server.renderEditVideo(w, r, vm) return } } @@ -282,8 +282,8 @@ func (server *Server) transfer(w http.ResponseWriter, r *http.Request) { // method // renderEditVideo 渲染视频编辑页面 -func renderEditVideo(w http.ResponseWriter, r *http.Request, data any) { - renderLayout(w, r, data, "web/templates/video/edit.html.tmpl") +func (server *Server) renderEditVideo(w http.ResponseWriter, r *http.Request, data any) { + server.renderLayout(w, r, data, "video/edit.html.tmpl") } // viladatorEditVedio 检验视频编辑数据 diff --git a/main.go b/main.go index 6fa88a3..8a606c5 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,8 @@ package main import ( "database/sql" + "embed" + "io/fs" "log" "github.com/zhang2092/mediahls/internal/db" @@ -10,11 +12,38 @@ import ( "github.com/zhang2092/mediahls/internal/pkg/logger" ) +//go:embed web/templates +var templateFS embed.FS + +//go:embed web/statics +var staticFS embed.FS + +//go:embed upload/imgs +var imgFS embed.FS + func main() { // filename, _ := nanoId.Nanoid() // log.Println(filename) // return + // Set up templates + templates, err := fs.Sub(templateFS, "web/templates") + if err != nil { + log.Fatal(err) + } + + // Set up statics + statics, err := fs.Sub(staticFS, "web/statics") + if err != nil { + log.Fatal(err) + } + + // Set up imgs + imgs, err := fs.Sub(imgFS, "upload/imgs") + if err != nil { + log.Fatal(err) + } + config, err := config.LoadConfig(".") if err != nil { log.Fatal("cannot load config: ", err) @@ -28,7 +57,7 @@ func main() { } store := db.NewStore(conn) - server, err := handlers.NewServer(config, store) + server, err := handlers.NewServer(templates, statics, imgs, config, store) if err != nil { log.Fatal("cannot create server: ", err) }