diff --git a/api/subscription.go b/api/subscription.go new file mode 100644 index 00000000..da5034c4 --- /dev/null +++ b/api/subscription.go @@ -0,0 +1,298 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/LiterMC/go-openbmclapi/utils" +) + +type SubscriptionManager interface { + GetWebPushKey() string + + GetSubscribe(user string, client string) (*SubscribeRecord, error) + SetSubscribe(SubscribeRecord) error + RemoveSubscribe(user string, client string) error + ForEachSubscribe(cb func(*SubscribeRecord) error) error + + GetEmailSubscription(user string, addr string) (*EmailSubscriptionRecord, error) + AddEmailSubscription(EmailSubscriptionRecord) error + UpdateEmailSubscription(EmailSubscriptionRecord) error + RemoveEmailSubscription(user string, addr string) error + ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) error) error + ForEachUsersEmailSubscription(user string, cb func(*EmailSubscriptionRecord) error) error + ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecord) error) error + + GetWebhook(user string, id uuid.UUID) (*WebhookRecord, error) + AddWebhook(WebhookRecord) error + UpdateWebhook(WebhookRecord) error + UpdateEnableWebhook(user string, id uuid.UUID, enabled bool) error + RemoveWebhook(user string, id uuid.UUID) error + ForEachWebhook(cb func(*WebhookRecord) error) error + ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) error + ForEachEnabledWebhook(cb func(*WebhookRecord) error) error +} + +type SubscribeRecord struct { + User string `json:"user"` + Client string `json:"client"` + EndPoint string `json:"endpoint"` + Keys SubscribeRecordKeys `json:"keys"` + Scopes NotificationScopes `json:"scopes"` + ReportAt Schedule `json:"report_at"` + LastReport sql.NullTime `json:"-"` +} + +type SubscribeRecordKeys struct { + Auth string `json:"auth"` + P256dh string `json:"p256dh"` +} + +var ( + _ sql.Scanner = (*SubscribeRecordKeys)(nil) + _ driver.Valuer = (*SubscribeRecordKeys)(nil) +) + +func (sk *SubscribeRecordKeys) Scan(src any) error { + var data []byte + switch v := src.(type) { + case []byte: + data = v + case string: + data = ([]byte)(v) + default: + return errors.New("Source is not a string") + } + return json.Unmarshal(data, sk) +} + +func (sk SubscribeRecordKeys) Value() (driver.Value, error) { + return json.Marshal(sk) +} + +type NotificationScopes struct { + Disabled bool `json:"disabled"` + Enabled bool `json:"enabled"` + SyncBegin bool `json:"syncbegin"` + SyncDone bool `json:"syncdone"` + Updates bool `json:"updates"` + DailyReport bool `json:"dailyreport"` +} + +var ( + _ sql.Scanner = (*NotificationScopes)(nil) + _ driver.Valuer = (*NotificationScopes)(nil) +) + +//// !!WARN: Do not edit nsFlag's order //// + +const ( + nsFlagDisabled = 1 << iota + nsFlagEnabled + nsFlagSyncDone + nsFlagUpdates + nsFlagDailyReport + nsFlagSyncBegin +) + +func (ns NotificationScopes) ToInt64() (v int64) { + if ns.Disabled { + v |= nsFlagDisabled + } + if ns.Enabled { + v |= nsFlagEnabled + } + if ns.SyncBegin { + v |= nsFlagSyncBegin + } + if ns.SyncDone { + v |= nsFlagSyncDone + } + if ns.Updates { + v |= nsFlagUpdates + } + if ns.DailyReport { + v |= nsFlagDailyReport + } + return +} + +func (ns *NotificationScopes) FromInt64(v int64) { + ns.Disabled = v&nsFlagDisabled != 0 + ns.Enabled = v&nsFlagEnabled != 0 + ns.SyncBegin = v&nsFlagSyncBegin != 0 + ns.SyncDone = v&nsFlagSyncDone != 0 + ns.Updates = v&nsFlagUpdates != 0 + ns.DailyReport = v&nsFlagDailyReport != 0 +} + +func (ns *NotificationScopes) Scan(src any) error { + v, ok := src.(int64) + if !ok { + return errors.New("Source is not a integer") + } + ns.FromInt64(v) + return nil +} + +func (ns NotificationScopes) Value() (driver.Value, error) { + return ns.ToInt64(), nil +} + +func (ns *NotificationScopes) FromStrings(scopes []string) { + for _, s := range scopes { + switch s { + case "disabled": + ns.Disabled = true + case "enabled": + ns.Enabled = true + case "syncbegin": + ns.SyncBegin = true + case "syncdone": + ns.SyncDone = true + case "updates": + ns.Updates = true + case "dailyreport": + ns.DailyReport = true + } + } +} + +func (ns *NotificationScopes) UnmarshalJSON(data []byte) (err error) { + { + type T NotificationScopes + if err = json.Unmarshal(data, (*T)(ns)); err == nil { + return + } + } + var v []string + if err = json.Unmarshal(data, &v); err != nil { + return + } + ns.FromStrings(v) + return +} + +type Schedule struct { + Hour int + Minute int +} + +var ( + _ sql.Scanner = (*Schedule)(nil) + _ driver.Valuer = (*Schedule)(nil) +) + +func (s Schedule) String() string { + return fmt.Sprintf("%02d:%02d", s.Hour, s.Minute) +} + +func (s *Schedule) UnmarshalText(buf []byte) (err error) { + if _, err = fmt.Sscanf((string)(buf), "%02d:%02d", &s.Hour, &s.Minute); err != nil { + return + } + if s.Hour < 0 || s.Hour >= 24 { + return fmt.Errorf("Hour %d out of range [0, 24)", s.Hour) + } + if s.Minute < 0 || s.Minute >= 60 { + return fmt.Errorf("Minute %d out of range [0, 60)", s.Minute) + } + return +} + +func (s *Schedule) UnmarshalJSON(buf []byte) (err error) { + var v string + if err = json.Unmarshal(buf, &v); err != nil { + return + } + return s.UnmarshalText(([]byte)(v)) +} + +func (s *Schedule) MarshalJSON() (buf []byte, err error) { + return json.Marshal(s.String()) +} + +func (s *Schedule) Scan(src any) error { + var v []byte + switch w := src.(type) { + case []byte: + v = w + case string: + v = ([]byte)(w) + default: + return fmt.Errorf("Unexpected type %T", src) + } + return s.UnmarshalText(v) +} + +func (s Schedule) Value() (driver.Value, error) { + return s.String(), nil +} + +func (s Schedule) ReadySince(last, now time.Time) bool { + if last.IsZero() { + last = now.Add(-time.Hour*24 + 1) + } + mustAfter := last.Add(time.Hour * 12) + if now.Before(mustAfter) { + return false + } + if !now.Before(last.Add(time.Hour * 24)) { + return true + } + hour, min := now.Hour(), now.Minute() + if s.Hour < hour && s.Hour+3 > hour || s.Hour == hour && s.Minute <= min { + return true + } + return false +} + +type EmailSubscriptionRecord struct { + User string `json:"user"` + Addr string `json:"addr"` + Scopes NotificationScopes `json:"scopes"` + Enabled bool `json:"enabled"` +} + +type WebhookRecord struct { + User string `json:"user"` + Id uuid.UUID `json:"id"` + Name string `json:"name"` + EndPoint string `json:"endpoint"` + Auth *string `json:"auth,omitempty"` + AuthHash string `json:"authHash,omitempty"` + Scopes NotificationScopes `json:"scopes"` + Enabled bool `json:"enabled"` +} + +func (rec *WebhookRecord) CovertAuthHash() { + if rec.Auth == nil || *rec.Auth == "" { + rec.AuthHash = "" + } else { + rec.AuthHash = "sha256:" + utils.AsSha256Hex(*rec.Auth) + } + rec.Auth = nil +} diff --git a/api/token.go b/api/token.go new file mode 100644 index 00000000..3f2081f3 --- /dev/null +++ b/api/token.go @@ -0,0 +1,35 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "net/url" +) + +type TokenVerifier interface { + VerifyAuthToken(clientId string, token string) (tokenId string, userId string, err error) + VerifyAPIToken(clientId string, token string, path string, query url.Values) (userId string, err error) +} + +type TokenManager interface { + TokenVerifier + GenerateAuthToken(clientId string, userId string) (token string, err error) + GenerateAPIToken(clientId string, userId string, path string, query map[string]string) (token string, err error) +} diff --git a/api/user.go b/api/user.go new file mode 100644 index 00000000..85ec6fd6 --- /dev/null +++ b/api/user.go @@ -0,0 +1,62 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +type UserManager interface { + GetUsers() []*User + GetUser(id string) *User + AddUser(*User) error + RemoveUser(id string) error + UpdateUserPassword(username string, password string) error + UpdateUserPermissions(username string, permissions PermissionFlag) error + + VerifyUserPassword(userId string, comparator func(password string) bool) error +} + +type PermissionFlag uint32 + +const ( + // BasicPerm includes majority client side actions, such as login, which do not have a significant impact on the server + BasicPerm PermissionFlag = 1 << iota + // SubscribePerm allows the user to subscribe server status & other posts + SubscribePerm + // LogPerm allows the user to view non-debug logs & download access logs + LogPerm + // DebugPerm allows the user to access debug settings and download debug logs + DebugPerm + // FullConfigPerm allows the user to access all config values + FullConfigPerm + // ClusterPerm allows the user to configure clusters' settings & stop/start clusters + ClusterPerm + // StoragePerm allows the user to configure storages' settings & decides to manually start storages' sync process + StoragePerm + // BypassLimitPerm allows the user to ignore API access limit + BypassLimitPerm + // RootPerm user can add/remove users, reset their password, and change their permission flags + RootPerm PermissionFlag = 1 << 31 + + AllPerm = ^(PermissionFlag)(0) +) + +type User struct { + Username string + Password string // as sha256 + Permissions PermissionFlag +} diff --git a/api/v0/api.go b/api/v0/api.go index 7d52ddf9..1cc4ee42 100644 --- a/api/v0/api.go +++ b/api/v0/api.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package v0 import ( "compress/gzip" @@ -36,10 +36,11 @@ import ( "sync/atomic" "time" - "runtime/pprof" - // "github.com/gorilla/websocket" "github.com/google/uuid" + "github.com/gorilla/schema" + "runtime/pprof" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/limited" @@ -58,32 +59,108 @@ func apiGetClientId(req *http.Request) (id string) { return req.Context().Value(clientIdKey).(string) } -func (cr *Cluster) cliIdHandle(next http.Handler) http.Handler { - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - var id string - if cid, _ := req.Cookie(clientIdCookieName); cid != nil { - id = cid.Value - } else { - var err error - id, err = utils.GenRandB64(16) - if err != nil { - http.Error(rw, "cannot generate random number", http.StatusInternalServerError) - return - } - http.SetCookie(rw, &http.Cookie{ - Name: clientIdCookieName, - Value: id, - Expires: time.Now().Add(time.Hour * 24 * 365 * 16), - Secure: true, - HttpOnly: true, - }) - } - req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) - next.ServeHTTP(rw, req) +type Handler struct { + handler *utils.HttpMiddleWareHandler + router *http.ServeMux + userManager api.UserManager + tokenManager api.TokenManager + subManager api.SubscriptionManager +} + +var _ http.Handler = (*Handler)(nil) + +func NewHandler(verifier TokenVerifier, subManager api.SubscriptionManager) *Handler { + mux := http.NewServeMux() + h := &Handler{ + router: mux, + handler: utils.NewHttpMiddleWareHandler(mux), + verifier: verifier, + subManager: subManager, + } + h.buildRoute() + h.handler.Use(cliIdMiddleWare) + h.handler.Use(h.authMiddleWare) + return h +} + +func (h *Handler) Handler() *utils.HttpMiddleWareHandler { + return h.handler +} + +func (h *Handler) buildRoute() { + mux := h.router + + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { + writeJson(rw, http.StatusNotFound, Map{ + "error": "404 not found", + "path": req.URL.Path, + }) }) + + mux.HandleFunc("/ping", h.routePing) + mux.HandleFunc("/status", h.routeStatus) + mux.Handle("/stat/", http.StripPrefix("/stat/", (http.HandlerFunc)(h.routeStat))) + + mux.HandleFunc("/challenge", h.routeChallenge) + mux.HandleFunc("/login", h.routeLogin) + mux.Handle("/requestToken", authHandleFunc(h.routeRequestToken)) + mux.Handle("/logout", authHandleFunc(h.routeLogout)) + + mux.HandleFunc("/log.io", h.routeLogIO) + mux.Handle("/pprof", authHandleFunc(h.routePprof)) + mux.HandleFunc("/subscribeKey", h.routeSubscribeKey) + mux.Handle("/subscribe", authHandle(&utils.HttpMethodHandler{ + Get: h.routeSubscribeGET, + Post: h.routeSubscribePOST, + Delete: h.routeSubscribeDELETE, + })) + mux.Handle("/subscribe_email", authHandle(&utils.HttpMethodHandler{ + Get: h.routeSubscribeEmailGET, + Post: h.routeSubscribeEmailPOST, + Patch: h.routeSubscribeEmailPATCH, + Delete: h.routeSubscribeEmailDELETE, + })) + mux.Handle("/webhook", authHandle(&utils.HttpMethodHandler{ + Get: h.routeWebhookGET, + Post: h.routeWebhookPOST, + Patch: h.routeWebhookPATCH, + Delete: h.routeWebhookDELETE, + })) + + mux.Handle("/log_files", authHandleFunc(h.routeLogFiles)) + mux.Handle("/log_file/", authHandle(http.StripPrefix("/log_file/", (http.HandlerFunc)(h.routeLogFile)))) + + mux.Handle("/configure/cluster", authHandleFunc(h.routeConfigureCluster)) +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + h.handler.ServeHTTP(rw, req) +} + +func cliIdMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + var id string + if cid, _ := req.Cookie(clientIdCookieName); cid != nil { + id = cid.Value + } else { + var err error + id, err = utils.GenRandB64(16) + if err != nil { + http.Error(rw, "cannot generate random number", http.StatusInternalServerError) + return + } + http.SetCookie(rw, &http.Cookie{ + Name: clientIdCookieName, + Value: id, + Expires: time.Now().Add(time.Hour * 24 * 365 * 16), + Secure: true, + HttpOnly: true, + }) + } + req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) + next.ServeHTTP(rw, req) } -func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, next http.Handler) { +func (h *Handler) authMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { cli := apiGetClientId(req) ctx := req.Context() @@ -96,7 +173,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex if req.Method == http.MethodGet { if tk := req.URL.Query().Get("_t"); tk != "" { path := GetRequestRealPath(req) - if id, uid, err = cr.verifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { + if id, uid, err = h.verifier.verifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAPI) } } @@ -108,7 +185,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex if err == nil { err = ErrUnsupportAuthType } - } else if id, uid, err = cr.verifyAuthToken(cli, tk); err != nil { + } else if id, uid, err = h.verifier.VerifyAuthToken(cli, tk); err != nil { id = "" } else { ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAuth) @@ -122,7 +199,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex next.ServeHTTP(rw, req) } -func (cr *Cluster) apiAuthHandle(next http.Handler) http.Handler { +func authHandle(next http.Handler) http.Handler { return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { if req.Context().Value(tokenTypeKey) == nil { writeJson(rw, http.StatusUnauthorized, Map{ @@ -134,69 +211,13 @@ func (cr *Cluster) apiAuthHandle(next http.Handler) http.Handler { }) } -func (cr *Cluster) apiAuthHandleFunc(next http.HandlerFunc) http.Handler { - return cr.apiAuthHandle(next) -} - -func (cr *Cluster) initAPIv0() http.Handler { - mux := http.NewServeMux() - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - writeJson(rw, http.StatusNotFound, Map{ - "error": "404 not found", - "path": req.URL.Path, - }) - }) - - mux.HandleFunc("/ping", cr.apiV1Ping) - mux.HandleFunc("/status", cr.apiV0Status) - mux.Handle("/stat/", http.StripPrefix("/stat/", (http.HandlerFunc)(cr.apiV0Stat))) - - mux.HandleFunc("/challenge", cr.apiV1Challenge) - mux.HandleFunc("/login", cr.apiV0Login) - mux.Handle("/requestToken", cr.apiAuthHandleFunc(cr.apiV0RequestToken)) - mux.Handle("/logout", cr.apiAuthHandleFunc(cr.apiV1Logout)) - - mux.HandleFunc("/log.io", cr.apiV1LogIO) - mux.Handle("/pprof", cr.apiAuthHandleFunc(cr.apiV1Pprof)) - mux.HandleFunc("/subscribeKey", cr.apiV0SubscribeKey) - mux.Handle("/subscribe", cr.apiAuthHandleFunc(cr.apiV0Subscribe)) - mux.Handle("/subscribe_email", cr.apiAuthHandleFunc(cr.apiV0SubscribeEmail)) - mux.Handle("/webhook", cr.apiAuthHandleFunc(cr.apiV0Webhook)) - - mux.Handle("/log_files", cr.apiAuthHandleFunc(cr.apiV0LogFiles)) - mux.Handle("/log_file/", cr.apiAuthHandle(http.StripPrefix("/log_file/", (http.HandlerFunc)(cr.apiV0LogFile)))) - - next := cr.apiRateLimiter.WrapHandler(mux) - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - cr.authMiddleware(rw, req, next) - }) -} - -func (cr *Cluster) initAPIv1() http.Handler { - mux := http.NewServeMux() - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - writeJson(rw, http.StatusNotFound, Map{ - "error": "404 not found", - "path": req.URL.Path, - }) - }) - - mux.HandleFunc("/ping", cr.apiV1Ping) - - mux.HandleFunc("/challenge", cr.apiV1Challenge) - mux.Handle("/logout", cr.apiAuthHandleFunc(cr.apiV1Logout)) - - mux.HandleFunc("/log.io", cr.apiV1LogIO) - mux.Handle("/pprof", cr.apiAuthHandleFunc(cr.apiV1Pprof)) - - next := cr.apiRateLimiter.WrapHandler(mux) - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - cr.authMiddleware(rw, req, next) - }) +func authHandleFunc(next http.HandlerFunc) http.Handler { + return authHandle(next) } -func (cr *Cluster) apiV1Ping(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routePing(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -208,8 +229,9 @@ func (cr *Cluster) apiV1Ping(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0Status(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeStatus(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -245,8 +267,9 @@ func (cr *Cluster) apiV0Status(rw http.ResponseWriter, req *http.Request) { writeJson(rw, http.StatusOK, &status) } -func (cr *Cluster) apiV0Stat(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeStat(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -265,14 +288,15 @@ func (cr *Cluster) apiV0Stat(rw http.ResponseWriter, req *http.Request) { writeJson(rw, http.StatusOK, (json.RawMessage)(data)) } -func (cr *Cluster) apiV1Challenge(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (h *Handler) routeChallenge(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } cli := apiGetClientId(req) query := req.URL.Query() action := query.Get("action") - token, err := cr.generateChallengeToken(cli, action) + token, err := h.generateChallengeToken(cli, action) if err != nil { writeJson(rw, http.StatusInternalServerError, Map{ "error": "Cannot generate token", @@ -285,8 +309,9 @@ func (cr *Cluster) apiV1Challenge(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (h *Handler) routeLogin(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } if !config.Dashboard.Enable { @@ -297,44 +322,25 @@ func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { } cli := apiGetClientId(req) - type T = struct { - User string `json:"username"` - Challenge string `json:"challenge"` - Signature string `json:"signature"` - } - data, ok := parseRequestBody(rw, req, func(rw http.ResponseWriter, req *http.Request, ct string, data *T) error { - switch ct { - case "application/x-www-form-urlencoded": - data.User = req.PostFormValue("username") - data.Challenge = req.PostFormValue("challenge") - data.Signature = req.PostFormValue("signature") - return nil - default: - return errUnknownContent - } - }) - if !ok { - return + var data struct { + User string `json:"username" schema:"username"` + Challenge string `json:"challenge" schema:"challenge"` + Signature string `json:"signature" schema:"signature"` } - - expectUsername, expectPassword := config.Dashboard.Username, config.Dashboard.Password - if expectUsername == "" || expectPassword == "" { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "The username or password was not set on the server", - }) + if !parseRequestBody(rw, req, &data) { return } - if err := cr.verifyChallengeToken(cli, "login", data.Challenge); err != nil { + if err := h.verifier.VerifyChallengeToken(cli, "login", data.Challenge); err != nil { writeJson(rw, http.StatusUnauthorized, Map{ "error": "Invalid challenge", }) return } - expectPassword = utils.AsSha256Hex(expectPassword) - expectSignature := utils.HMACSha256Hex(expectPassword, data.Challenge) - if subtle.ConstantTimeCompare(([]byte)(expectUsername), ([]byte)(data.User)) == 0 || - subtle.ConstantTimeCompare(([]byte)(expectSignature), ([]byte)(data.Signature)) == 0 { + if err := h.verifier.VerifyUserPassword(data.User, func(password string) bool { + expectSignature := utils.HMACSha256HexBytes(password, data.Challenge) + return subtle.ConstantTimeCompare(expectSignature, ([]byte)(data.Signature)) == 0 + }); err != nil { writeJson(rw, http.StatusUnauthorized, Map{ "error": "The username or password is incorrect", }) @@ -353,8 +359,9 @@ func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (cr *Cluster) routeRequestToken(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } defer req.Body.Close() @@ -369,11 +376,8 @@ func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) Path string `json:"path"` Query map[string]string `json:"query,omitempty"` } - if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "cannot decode payload in json format", - "message": err.Error(), - }) + if !parseRequestBody(rw, req, &payload) { + return } log.Debugf("payload: %#v", payload) if payload.Path == "" || payload.Path[0] != '/' { @@ -398,8 +402,9 @@ func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) }) } -func (cr *Cluster) apiV1Logout(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (cr *Cluster) routeLogout(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } limited.SetSkipRateLimit(req) @@ -408,7 +413,7 @@ func (cr *Cluster) apiV1Logout(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV1LogIO(rw http.ResponseWriter, req *http.Request) { +func (cr *Cluster) routeLogIO(rw http.ResponseWriter, req *http.Request) { addr, _ := req.Context().Value(RealAddrCtxKey).(string) conn, err := cr.wsUpgrader.Upgrade(rw, req, nil) @@ -624,8 +629,9 @@ func (cr *Cluster) apiV1LogIO(rw http.ResponseWriter, req *http.Request) { } } -func (cr *Cluster) apiV1Pprof(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routePprof(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } query := req.URL.Query() @@ -661,244 +667,8 @@ func (cr *Cluster) apiV1Pprof(rw http.ResponseWriter, req *http.Request) { p.WriteTo(rw, debug) } -func (cr *Cluster) apiV0SubscribeKey(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { - return - } - key := cr.webpushKeyB64 - etag := `"` + utils.AsSha256(key) + `"` - rw.Header().Set("ETag", etag) - if cachedTag := req.Header.Get("If-None-Match"); cachedTag == etag { - rw.WriteHeader(http.StatusNotModified) - return - } - writeJson(rw, http.StatusOK, Map{ - "publicKey": key, - }) -} - -func (cr *Cluster) apiV0Subscribe(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodDelete) { - return - } - cliId := apiGetClientId(req) - user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0SubscribeGET(rw, req, user, cliId) - case http.MethodPost: - cr.apiV0SubscribePOST(rw, req, user, cliId) - case http.MethodDelete: - cr.apiV0SubscribeDELETE(rw, req, user, cliId) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0SubscribeGET(rw http.ResponseWriter, req *http.Request, user string, client string) { - record, err := cr.database.GetSubscribe(user, client) - if err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, Map{ - "scopes": record.Scopes, - "reportAt": record.ReportAt, - }) -} - -func (cr *Cluster) apiV0SubscribePOST(rw http.ResponseWriter, req *http.Request, user string, client string) { - data, ok := parseRequestBody[database.SubscribeRecord](rw, req, nil) - if !ok { - return - } - data.User = user - data.Client = client - if err := cr.database.SetSubscribe(data); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Database update failed", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeDELETE(rw http.ResponseWriter, req *http.Request, user string, client string) { - if err := cr.database.RemoveSubscribe(user, client); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeEmail(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete) { - return - } - user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0SubscribeEmailGET(rw, req, user) - case http.MethodPost: - cr.apiV0SubscribeEmailPOST(rw, req, user) - case http.MethodPatch: - cr.apiV0SubscribeEmailPATCH(rw, req, user) - case http.MethodDelete: - cr.apiV0SubscribeEmailDELETE(rw, req, user) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0SubscribeEmailGET(rw http.ResponseWriter, req *http.Request, user string) { - if addr := req.URL.Query().Get("addr"); addr != "" { - record, err := cr.database.GetEmailSubscription(user, addr) - if err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, record) - return - } - records := make([]database.EmailSubscriptionRecord, 0, 4) - if err := cr.database.ForEachUsersEmailSubscription(user, func(rec *database.EmailSubscriptionRecord) error { - records = append(records, *rec) - return nil - }); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, records) -} - -func (cr *Cluster) apiV0SubscribeEmailPOST(rw http.ResponseWriter, req *http.Request, user string) { - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { - return - } - - data.User = user - if err := cr.database.AddEmailSubscription(data); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Database update failed", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusCreated) -} - -func (cr *Cluster) apiV0SubscribeEmailPATCH(rw http.ResponseWriter, req *http.Request, user string) { - addr := req.URL.Query().Get("addr") - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { - return - } - data.User = user - data.Addr = addr - if err := cr.database.UpdateEmailSubscription(data); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeEmailDELETE(rw http.ResponseWriter, req *http.Request, user string) { - addr := req.URL.Query().Get("addr") - if err := cr.database.RemoveEmailSubscription(user, addr); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0Webhook(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete) { - return - } +func (cr *Cluster) routeWebhookGET(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0WebhookGET(rw, req, user) - case http.MethodPost: - cr.apiV0WebhookPOST(rw, req, user) - case http.MethodPatch: - cr.apiV0WebhookPATCH(rw, req, user) - case http.MethodDelete: - cr.apiV0WebhookDELETE(rw, req, user) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0WebhookGET(rw http.ResponseWriter, req *http.Request, user string) { if sid := req.URL.Query().Get("id"); sid != "" { id, err := uuid.Parse(sid) if err != nil { @@ -939,9 +709,10 @@ func (cr *Cluster) apiV0WebhookGET(rw http.ResponseWriter, req *http.Request, us writeJson(rw, http.StatusOK, records) } -func (cr *Cluster) apiV0WebhookPOST(rw http.ResponseWriter, req *http.Request, user string) { - data, ok := parseRequestBody[database.WebhookRecord](rw, req, nil) - if !ok { +func (cr *Cluster) routeWebhookPOST(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + var data database.WebhookRecord + if !parseRequestBody(rw, req, &data) { return } @@ -956,10 +727,11 @@ func (cr *Cluster) apiV0WebhookPOST(rw http.ResponseWriter, req *http.Request, u rw.WriteHeader(http.StatusCreated) } -func (cr *Cluster) apiV0WebhookPATCH(rw http.ResponseWriter, req *http.Request, user string) { +func (cr *Cluster) routeWebhookPATCH(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) id := req.URL.Query().Get("id") - data, ok := parseRequestBody[database.WebhookRecord](rw, req, nil) - if !ok { + var data database.WebhookRecord + if !parseRequestBody(rw, req, &data) { return } data.User = user @@ -987,7 +759,8 @@ func (cr *Cluster) apiV0WebhookPATCH(rw http.ResponseWriter, req *http.Request, rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV0WebhookDELETE(rw http.ResponseWriter, req *http.Request, user string) { +func (cr *Cluster) routeWebhookDELETE(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) id, err := uuid.Parse(req.URL.Query().Get("id")) if err != nil { writeJson(rw, http.StatusBadRequest, Map{ @@ -1012,8 +785,9 @@ func (cr *Cluster) apiV0WebhookDELETE(rw http.ResponseWriter, req *http.Request, rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV0LogFiles(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeLogFiles(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } files := log.ListLogs() @@ -1035,8 +809,9 @@ func (cr *Cluster) apiV0LogFiles(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0LogFile(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodHead) { +func (cr *Cluster) routeLogFile(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet && req.Method != http.MethodHead { + errorMethodNotAllowed(rw, req, http.MethodGet+", "+http.MethodHead) return } query := req.URL.Query() @@ -1078,11 +853,11 @@ func (cr *Cluster) apiV0LogFile(rw http.ResponseWriter, req *http.Request) { } rw.Header().Set("Content-Type", "application/octet-stream") rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".encrypted")) - cr.apiV0LogFileEncrypted(rw, req, fd, !isGzip) + cr.routeLogFileEncrypted(rw, req, fd, !isGzip) } } -func (cr *Cluster) apiV0LogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { +func (cr *Cluster) routeLogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { rw.WriteHeader(http.StatusOK) if req.Method == http.MethodHead { return @@ -1112,10 +887,9 @@ func (cr *Cluster) apiV0LogFileEncrypted(rw http.ResponseWriter, req *http.Reque type Map = map[string]any var errUnknownContent = errors.New("unknown content-type") +var formDecoder = schema.NewDecoder() -type requestBodyParser[T any] func(rw http.ResponseWriter, req *http.Request, contentType string, data *T) error - -func parseRequestBody[T any](rw http.ResponseWriter, req *http.Request, fallback requestBodyParser[T]) (data T, parsed bool) { +func parseRequestBody(rw http.ResponseWriter, req *http.Request, ptr any) (parsed bool) { contentType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) if err != nil { writeJson(rw, http.StatusBadRequest, Map{ @@ -1127,26 +901,31 @@ func parseRequestBody[T any](rw http.ResponseWriter, req *http.Request, fallback } switch contentType { case "application/json": - if err := json.NewDecoder(req.Body).Decode(&data); err != nil { + if err := json.NewDecoder(req.Body).Decode(ptr); err != nil { writeJson(rw, http.StatusBadRequest, Map{ "error": "Cannot decode request body", "message": err.Error(), }) return } - return data, true - default: - if fallback != nil { - if err := fallback(rw, req, contentType, &data); err == nil { - return data, true - } else if err != errUnknownContent { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "Cannot decode request body", - "message": err.Error(), - }) - return - } + return true + case "application/x-www-form-urlencoded": + if err := req.ParseForm(); err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Cannot decode request body", + "message": err.Error(), + }) + return } + if err := formDecoder.Decode(ptr, req.PostForm); err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Cannot decode request body", + "message": err.Error(), + }) + return + } + return true + default: writeJson(rw, http.StatusBadRequest, Map{ "error": "Unexpected Content-Type", "content-type": contentType, @@ -1168,18 +947,8 @@ func writeJson(rw http.ResponseWriter, code int, data any) (err error) { return } -func checkRequestMethodOrRejectWithJson(rw http.ResponseWriter, req *http.Request, allows ...string) (rejected bool) { - m := req.Method - for _, a := range allows { - if m == a { - return false - } - } - rw.Header().Set("Allow", strings.Join(allows, ", ")) - writeJson(rw, http.StatusMethodNotAllowed, Map{ - "error": "405 method not allowed", - "method": m, - "allow": allows, - }) +func errorMethodNotAllowed(rw http.ResponseWriter, req *http.Request, allow string) { + rw.Header().Set("Allow", allow) + rw.WriteHeader(http.StatusMethodNotAllowed) return true } diff --git a/api/v0/api_token.go b/api/v0/api_token.go index 0568e452..bd4ba177 100644 --- a/api/v0/api_token.go +++ b/api/v0/api_token.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package v0 import ( "errors" @@ -232,7 +232,7 @@ func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, qu return tokenStr, nil } -func (cr *Cluster) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { +func (h *Handler) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { var claims apiTokenClaims _, err = jwt.ParseWithClaims( token, diff --git a/api/v0/configure_cluster.go b/api/v0/configure_cluster.go new file mode 100644 index 00000000..7bc143be --- /dev/null +++ b/api/v0/configure_cluster.go @@ -0,0 +1,24 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +func (h *Handler) apiConfigureCluster() { + // +} diff --git a/api/v0/subscription.go b/api/v0/subscription.go new file mode 100644 index 00000000..0a5f4709 --- /dev/null +++ b/api/v0/subscription.go @@ -0,0 +1,198 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +import ( + "net/http" +) + +func (h *Handler) routeSubscribeKey(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) + return + } + key := h.subManager.GetWebPushKey() + etag := `"` + utils.AsSha256(key) + `"` + rw.Header().Set("ETag", etag) + if cachedTag := req.Header.Get("If-None-Match"); cachedTag == etag { + rw.WriteHeader(http.StatusNotModified) + return + } + writeJson(rw, http.StatusOK, Map{ + "publicKey": key, + }) +} + +func (h *Handler) routeSubscribeGET(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + record, err := h.subManager.GetSubscribe(user, client) + if err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, Map{ + "scopes": record.Scopes, + "reportAt": record.ReportAt, + }) +} + +func (h *Handler) routeSubscribePOST(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + data, ok := parseRequestBody[database.SubscribeRecord](rw, req, nil) + if !ok { + return + } + data.User = user + data.Client = client + if err := h.subManager.SetSubscribe(data); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Database update failed", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeDELETE(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + if err := h.subManager.RemoveSubscribe(user, client); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeEmailGET(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + if addr := req.URL.Query().Get("addr"); addr != "" { + record, err := h.subManager.GetEmailSubscription(user, addr) + if err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, record) + return + } + records := make([]database.EmailSubscriptionRecord, 0, 4) + if err := h.subManager.ForEachUsersEmailSubscription(user, func(rec *database.EmailSubscriptionRecord) error { + records = append(records, *rec) + return nil + }); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, records) +} + +func (h *Handler) routeSubscribeEmailPOST(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) + if !ok { + return + } + + data.User = user + if err := h.subManager.AddEmailSubscription(data); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Database update failed", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusCreated) +} + +func (h *Handler) routeSubscribeEmailPATCH(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + addr := req.URL.Query().Get("addr") + data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) + if !ok { + return + } + data.User = user + data.Addr = addr + if err := h.subManager.UpdateEmailSubscription(data); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeEmailDELETE(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + addr := req.URL.Query().Get("addr") + if err := h.subManager.RemoveEmailSubscription(user, addr); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} diff --git a/cluster/config.go b/cluster/config.go index 2eebc617..c8fac95d 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -29,8 +29,12 @@ import ( "fmt" "net/http" "net/url" + "strconv" "time" + "github.com/hamba/avro/v2" + "github.com/klauspost/compress/zstd" + "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -259,3 +263,63 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } return } + +type FileInfo struct { + Path string `json:"path" avro:"path"` + Hash string `json:"hash" avro:"hash"` + Size int64 `json:"size" avro:"size"` + Mtime int64 `json:"mtime" avro:"mtime"` +} + +// from +var fileListSchema = avro.MustParse(`{ + "type": "array", + "items": { + "type": "record", + "name": "fileinfo", + "fields": [ + {"name": "path", "type": "string"}, + {"name": "hash", "type": "string"}, + {"name": "size", "type": "long"}, + {"name": "mtime", "type": "long"} + ] + } +}`) + +func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { + var query url.Values + if lastMod > 0 { + query = url.Values{ + "lastModified": {strconv.FormatInt(lastMod, 10)}, + } + } + req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) + if err != nil { + return + } + res, err := cr.cachedCli.Do(req) + if err != nil { + return + } + defer res.Body.Close() + switch res.StatusCode { + case http.StatusOK: + // + case http.StatusNoContent, http.StatusNotModified: + return + default: + err = utils.NewHTTPStatusErrorFromResponse(res) + return + } + log.Debug("Parsing filelist body ...") + zr, err := zstd.NewReader(res.Body) + if err != nil { + return + } + defer zr.Close() + if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { + return + } + log.Debugf("Filelist parsed, length = %d", len(files)) + return +} diff --git a/cluster/http.go b/cluster/http.go index d83245fa..20045c71 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -20,6 +20,51 @@ package cluster import ( + "context" + "io" "net/http" "net/url" + "path" + + "github.com/LiterMC/go-openbmclapi/internal/build" ) + +func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { + return cr.makeReqWithBody(ctx, method, relpath, query, nil) +} + +func (cr *Cluster) makeReqWithBody( + ctx context.Context, + method string, relpath string, + query url.Values, body io.Reader, +) (req *http.Request, err error) { + var u *url.URL + if u, err = url.Parse(cr.opts.Prefix); err != nil { + return + } + u.Path = path.Join(u.Path, relpath) + if query != nil { + u.RawQuery = query.Encode() + } + target := u.String() + + req, err = http.NewRequestWithContext(ctx, method, target, body) + if err != nil { + return + } + req.Header.Set("User-Agent", build.ClusterUserAgent) + return +} + +func (cr *Cluster) makeReqWithAuth(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { + req, err = cr.makeReq(ctx, method, relpath, query) + if err != nil { + return + } + token, err := cr.GetAuthToken(ctx) + if err != nil { + return + } + req.Header.Set("Authorization", "Bearer "+token) + return +} diff --git a/cluster/keepalive.go b/cluster/keepalive.go index 478c1763..2dee82a9 100644 --- a/cluster/keepalive.go +++ b/cluster/keepalive.go @@ -22,6 +22,9 @@ package cluster import ( "context" "time" + + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/utils" ) type KeepAliveRes int @@ -49,10 +52,10 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { }) if e := cr.stats.Save(cr.dataDir); e != nil { - log.Errorf(Tr("error.cluster.stat.save.failed"), e) + log.TrErrorf("error.cluster.stat.save.failed", e) } if err != nil { - log.Errorf(Tr("error.cluster.keepalive.send.failed"), err) + log.TrErrorf("error.cluster.keepalive.send.failed", err) return KeepAliveFailed } var data []any @@ -65,21 +68,19 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { if ero := data[0]; len(data) <= 1 || ero != nil { if ero, ok := ero.(map[string]any); ok { if msg, ok := ero["message"].(string); ok { - log.Errorf(Tr("error.cluster.keepalive.failed"), msg) + log.TrErrorf("error.cluster.keepalive.failed", msg) if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { hash := hashMismatch[1] log.Warnf("Detected hash mismatch error, removing bad file %s", hash) - for _, s := range cr.storages { - go s.Remove(hash) - } + cr.storageManager.RemoveForAll(hash) } return KeepAliveFailed } } - log.Errorf(Tr("error.cluster.keepalive.failed"), ero) + log.TrErrorf("error.cluster.keepalive.failed", ero) return KeepAliveFailed } - log.Infof(Tr("info.cluster.keepalive.success"), ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) + log.TrInfof("info.cluster.keepalive.success", ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) cr.hits.Add(-hits2) cr.hbts.Add(-hbts2) if data[1] == false { diff --git a/database/db.go b/database/db.go index 5429d9d7..2f42c46a 100644 --- a/database/db.go +++ b/database/db.go @@ -83,244 +83,3 @@ type FileRecord struct { Hash string Size int64 } - -type SubscribeRecord struct { - User string `json:"user"` - Client string `json:"client"` - EndPoint string `json:"endpoint"` - Keys SubscribeRecordKeys `json:"keys"` - Scopes NotificationScopes `json:"scopes"` - ReportAt Schedule `json:"report_at"` - LastReport sql.NullTime `json:"-"` -} - -type SubscribeRecordKeys struct { - Auth string `json:"auth"` - P256dh string `json:"p256dh"` -} - -var ( - _ sql.Scanner = (*SubscribeRecordKeys)(nil) - _ driver.Valuer = (*SubscribeRecordKeys)(nil) -) - -func (sk *SubscribeRecordKeys) Scan(src any) error { - var data []byte - switch v := src.(type) { - case []byte: - data = v - case string: - data = ([]byte)(v) - default: - return errors.New("Source is not a string") - } - return json.Unmarshal(data, sk) -} - -func (sk SubscribeRecordKeys) Value() (driver.Value, error) { - return json.Marshal(sk) -} - -type NotificationScopes struct { - Disabled bool `json:"disabled"` - Enabled bool `json:"enabled"` - SyncBegin bool `json:"syncbegin"` - SyncDone bool `json:"syncdone"` - Updates bool `json:"updates"` - DailyReport bool `json:"dailyreport"` -} - -var ( - _ sql.Scanner = (*NotificationScopes)(nil) - _ driver.Valuer = (*NotificationScopes)(nil) -) - -//// !!WARN: Do not edit nsFlag's order //// - -const ( - nsFlagDisabled = 1 << iota - nsFlagEnabled - nsFlagSyncDone - nsFlagUpdates - nsFlagDailyReport - nsFlagSyncBegin -) - -func (ns NotificationScopes) ToInt64() (v int64) { - if ns.Disabled { - v |= nsFlagDisabled - } - if ns.Enabled { - v |= nsFlagEnabled - } - if ns.SyncBegin { - v |= nsFlagSyncBegin - } - if ns.SyncDone { - v |= nsFlagSyncDone - } - if ns.Updates { - v |= nsFlagUpdates - } - if ns.DailyReport { - v |= nsFlagDailyReport - } - return -} - -func (ns *NotificationScopes) FromInt64(v int64) { - ns.Disabled = v&nsFlagDisabled != 0 - ns.Enabled = v&nsFlagEnabled != 0 - ns.SyncBegin = v&nsFlagSyncBegin != 0 - ns.SyncDone = v&nsFlagSyncDone != 0 - ns.Updates = v&nsFlagUpdates != 0 - ns.DailyReport = v&nsFlagDailyReport != 0 -} - -func (ns *NotificationScopes) Scan(src any) error { - v, ok := src.(int64) - if !ok { - return errors.New("Source is not a integer") - } - ns.FromInt64(v) - return nil -} - -func (ns NotificationScopes) Value() (driver.Value, error) { - return ns.ToInt64(), nil -} - -func (ns *NotificationScopes) FromStrings(scopes []string) { - for _, s := range scopes { - switch s { - case "disabled": - ns.Disabled = true - case "enabled": - ns.Enabled = true - case "syncbegin": - ns.SyncBegin = true - case "syncdone": - ns.SyncDone = true - case "updates": - ns.Updates = true - case "dailyreport": - ns.DailyReport = true - } - } -} - -func (ns *NotificationScopes) UnmarshalJSON(data []byte) (err error) { - { - type T NotificationScopes - if err = json.Unmarshal(data, (*T)(ns)); err == nil { - return - } - } - var v []string - if err = json.Unmarshal(data, &v); err != nil { - return - } - ns.FromStrings(v) - return -} - -type Schedule struct { - Hour int - Minute int -} - -var ( - _ sql.Scanner = (*Schedule)(nil) - _ driver.Valuer = (*Schedule)(nil) -) - -func (s Schedule) String() string { - return fmt.Sprintf("%02d:%02d", s.Hour, s.Minute) -} - -func (s *Schedule) UnmarshalText(buf []byte) (err error) { - if _, err = fmt.Sscanf((string)(buf), "%02d:%02d", &s.Hour, &s.Minute); err != nil { - return - } - if s.Hour < 0 || s.Hour >= 24 { - return fmt.Errorf("Hour %d out of range [0, 24)", s.Hour) - } - if s.Minute < 0 || s.Minute >= 60 { - return fmt.Errorf("Minute %d out of range [0, 60)", s.Minute) - } - return -} - -func (s *Schedule) UnmarshalJSON(buf []byte) (err error) { - var v string - if err = json.Unmarshal(buf, &v); err != nil { - return - } - return s.UnmarshalText(([]byte)(v)) -} - -func (s *Schedule) MarshalJSON() (buf []byte, err error) { - return json.Marshal(s.String()) -} - -func (s *Schedule) Scan(src any) error { - var v []byte - switch w := src.(type) { - case []byte: - v = w - case string: - v = ([]byte)(w) - default: - return fmt.Errorf("Unexpected type %T", src) - } - return s.UnmarshalText(v) -} - -func (s Schedule) Value() (driver.Value, error) { - return s.String(), nil -} - -func (s Schedule) ReadySince(last, now time.Time) bool { - if last.IsZero() { - last = now.Add(-time.Hour*24 + 1) - } - mustAfter := last.Add(time.Hour * 12) - if now.Before(mustAfter) { - return false - } - if !now.Before(last.Add(time.Hour * 24)) { - return true - } - hour, min := now.Hour(), now.Minute() - if s.Hour < hour && s.Hour+3 > hour || s.Hour == hour && s.Minute <= min { - return true - } - return false -} - -type EmailSubscriptionRecord struct { - User string `json:"user"` - Addr string `json:"addr"` - Scopes NotificationScopes `json:"scopes"` - Enabled bool `json:"enabled"` -} - -type WebhookRecord struct { - User string `json:"user"` - Id uuid.UUID `json:"id"` - Name string `json:"name"` - EndPoint string `json:"endpoint"` - Auth *string `json:"auth,omitempty"` - AuthHash string `json:"authHash,omitempty"` - Scopes NotificationScopes `json:"scopes"` - Enabled bool `json:"enabled"` -} - -func (rec *WebhookRecord) CovertAuthHash() { - if rec.Auth == nil || *rec.Auth == "" { - rec.AuthHash = "" - } else { - rec.AuthHash = "sha256:" + utils.AsSha256Hex(*rec.Auth) - } - rec.Auth = nil -} diff --git a/limited/api_rate.go b/limited/api_rate.go index 961501c7..33b7ea83 100644 --- a/limited/api_rate.go +++ b/limited/api_rate.go @@ -176,6 +176,8 @@ type APIRateMiddleWare struct { startAt time.Time } +var _ utils.MiddleWare = (*APIRateMiddleWare)(nil) + func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMiddleWare) { a = &APIRateMiddleWare{ loggedContextKey: loggedContextKey, @@ -184,9 +186,9 @@ func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMid cleanTicker: time.NewTicker(time.Minute), startAt: time.Now(), } - go func() { + go func(ticker *time.Ticker) { count := 0 - for range a.cleanTicker.C { + for range ticker.C { count++ ishour := count > 60 if ishour { @@ -195,12 +197,10 @@ func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMid a.clean(ishour) } log.Debugf("cleaner exited") - }() + }(a.cleanTicker) return } -var _ utils.MiddleWare = (*APIRateMiddleWare)(nil) - const ( RateLimitOverrideContextKey = "go-openbmclapi.limited.rate.api.override" RateLimitSkipContextKey = "go-openbmclapi.limited.rate.api.skip" diff --git a/sync.go b/sync.go index ea4ed602..0053653f 100644 --- a/sync.go +++ b/sync.go @@ -75,106 +75,6 @@ func (cr *Cluster) CachedFileSize(hash string) (size int64, ok bool) { return } -func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { - return cr.makeReqWithBody(ctx, method, relpath, query, nil) -} - -func (cr *Cluster) makeReqWithBody( - ctx context.Context, - method string, relpath string, - query url.Values, body io.Reader, -) (req *http.Request, err error) { - var u *url.URL - if u, err = url.Parse(cr.prefix); err != nil { - return - } - u.Path = path.Join(u.Path, relpath) - if query != nil { - u.RawQuery = query.Encode() - } - target := u.String() - - req, err = http.NewRequestWithContext(ctx, method, target, body) - if err != nil { - return - } - req.Header.Set("User-Agent", build.ClusterUserAgent) - return -} - -func (cr *Cluster) makeReqWithAuth(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { - req, err = cr.makeReq(ctx, method, relpath, query) - if err != nil { - return - } - token, err := cr.GetAuthToken(ctx) - if err != nil { - return - } - req.Header.Set("Authorization", "Bearer "+token) - return -} - -type FileInfo struct { - Path string `json:"path" avro:"path"` - Hash string `json:"hash" avro:"hash"` - Size int64 `json:"size" avro:"size"` - Mtime int64 `json:"mtime" avro:"mtime"` -} - -// from -var fileListSchema = avro.MustParse(`{ - "type": "array", - "items": { - "type": "record", - "name": "fileinfo", - "fields": [ - {"name": "path", "type": "string"}, - {"name": "hash", "type": "string"}, - {"name": "size", "type": "long"}, - {"name": "mtime", "type": "long"} - ] - } -}`) - -func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { - var query url.Values - if lastMod > 0 { - query = url.Values{ - "lastModified": {strconv.FormatInt(lastMod, 10)}, - } - } - req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) - if err != nil { - return - } - res, err := cr.cachedCli.Do(req) - if err != nil { - return - } - defer res.Body.Close() - switch res.StatusCode { - case http.StatusOK: - // - case http.StatusNoContent, http.StatusNotModified: - return - default: - err = utils.NewHTTPStatusErrorFromResponse(res) - return - } - log.Debug("Parsing filelist body ...") - zr, err := zstd.NewReader(res.Body) - if err != nil { - return - } - defer zr.Close() - if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { - return - } - log.Debugf("Filelist parsed, length = %d", len(files)) - return -} - type syncStats struct { slots *limited.BufSlots noOpen bool diff --git a/utils/crypto.go b/utils/crypto.go index 55c90c0b..40f70846 100644 --- a/utils/crypto.go +++ b/utils/crypto.go @@ -60,10 +60,16 @@ func AsSha256Hex(s string) string { } func HMACSha256Hex(key, data string) string { + return (string)(HMACSha256HexBytes(key, data)) +} + +func HMACSha256HexBytes(key, data string) []byte { m := hmac.New(sha256.New, ([]byte)(key)) m.Write(([]byte)(data)) buf := m.Sum(nil) - return hex.EncodeToString(buf[:]) + value := make([]byte, hex.EncodedLen(len(buf))) + hex.Encode(value, buf[:]) + return value } func GenRandB64(n int) (s string, err error) { diff --git a/utils/http.go b/utils/http.go index 65809812..f0187e26 100644 --- a/utils/http.go +++ b/utils/http.go @@ -36,8 +36,8 @@ import ( "sync/atomic" "time" - "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" ) type StatusResponseWriter struct { @@ -125,12 +125,22 @@ type HttpMiddleWareHandler struct { middles []MiddleWare } -func NewHttpMiddleWareHandler(final http.Handler) *HttpMiddleWareHandler { +var _ http.Handler = (*HttpMiddleWareHandler)(nil) + +func NewHttpMiddleWareHandler(final http.Handler, middles ...MiddleWare) *HttpMiddleWareHandler { return &HttpMiddleWareHandler{ - final: final, + final: final, + middles: middles, } } +// Handler returns the final http.Handler +func (m *HttpMiddleWareHandler) Handler() http.Handler { + return m.final +} + +// ServeHTTP implements http.Handler +// It will invoke the middlewares in order func (m *HttpMiddleWareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { i := 0 var getNext func() http.Handler @@ -158,16 +168,122 @@ func (m *HttpMiddleWareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Requ getNext().ServeHTTP(rw, req) } +// Use append MiddleWares to the middleware chain func (m *HttpMiddleWareHandler) Use(mids ...MiddleWare) { m.middles = append(m.middles, mids...) } +// UseFunc append MiddleWareFuncs to the middleware chain func (m *HttpMiddleWareHandler) UseFunc(fns ...MiddleWareFunc) { for _, fn := range fns { m.middles = append(m.middles, fn) } } +// HttpMethodHandler pass down http requests to different handler based on the request methods +// The HttpMethodHandler should not be modified after called ServeHTTP +type HttpMethodHandler struct { + Get http.Handler + Head bool + Post http.Handler + Put http.Handler + Patch http.Handler + Delete http.Handler + Connect http.Handler + Options http.Handler + Trace http.Handler + + allows string + allowsOnce sync.Once +} + +var _ http.Handler = (*HttpMethodHandler)(nil) + +// ServeHTTP implements http.Handler +// Once ServeHTTP is called the HttpMethodHandler should not be modified +func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodHead: + if !m.Head { + break + } + fallthrough + case http.MethodGet: + if m.Get != nil { + m.Get.ServeHTTP(rw, req) + return + } + case http.MethodPost: + if m.Post != nil { + m.Post.ServeHTTP(rw, req) + return + } + case http.MethodPut: + if m.Put != nil { + m.Put.ServeHTTP(rw, req) + return + } + case http.MethodPatch: + if m.Patch != nil { + m.Patch.ServeHTTP(rw, req) + return + } + case http.MethodDelete: + if m.Delete != nil { + m.Delete.ServeHTTP(rw, req) + return + } + case http.MethodConnect: + if m.Connect != nil { + m.Connect.ServeHTTP(rw, req) + return + } + case http.MethodOptions: + if m.Options != nil { + m.Options.ServeHTTP(rw, req) + return + } + case http.MethodTrace: + if m.Trace != nil { + m.Trace.ServeHTTP(rw, req) + return + } + } + m.allowsOnce.Do(func() { + allows := make([]string, 0, 5) + if m.Get != nil { + allows = append(allows, http.MethodGet) + if m.Head { + allows = append(allows, http.MethodGet) + } + } + if m.Post != nil { + allows = append(allows, http.MethodPost) + } + if m.Put != nil { + allows = append(allows, http.MethodPut) + } + if m.Patch != nil { + allows = append(allows, http.MethodPatch) + } + if m.Delete != nil { + allows = append(allows, http.MethodDelete) + } + if m.Connect != nil { + allows = append(allows, http.MethodConnect) + } + if m.Options != nil { + allows = append(allows, http.MethodOptions) + } + if m.Trace != nil { + allows = append(allows, http.MethodTrace) + } + m.allows = strings.Join(allows, ", ") + }) + rw.Header().Set("Allow", m.allows) + rw.WriteHeader(http.StatusMethodNotAllowed) +} + // HTTPTLSListener will serve a http or a tls connection // When Accept was called, if a pure http request is received, // it will response and redirect the client to the https protocol.