diff --git a/config/confpar/confpar.go b/config/confpar/confpar.go index f4bbc367..ee0f18c8 100644 --- a/config/confpar/confpar.go +++ b/config/confpar/confpar.go @@ -1,6 +1,8 @@ // Package confpar provide the core parameters of the config package confpar +import "time" + // Access provides rules around any access type Access struct { User string `json:"user"` // User authenticating @@ -13,6 +15,13 @@ type Access struct { SyncAndDelete *SyncAndDelete `json:"sync_and_delete"` // Local empty directory and synchronization } +// AccessesWebhook defines an optional webhook to get user's access +type AccessesWebhook struct { + URL string `json:"url"` // URL to call + Headers map[string]string `json:"headers"` // Token to use in the + Timeout time.Duration `json:"timeout"` // Max time request can take +} + // SyncAndDelete provides type SyncAndDelete struct { Enable bool `json:"enable"` // Instant write @@ -46,13 +55,14 @@ type ServerCert struct { // Content defines the content of the config file type Content struct { - Version int `json:"version"` // File format version - ListenAddress string `json:"listen_address"` // Address to listen on - PublicHost string `json:"public_host"` // Public host to listen on - MaxClients int `json:"max_clients"` // Maximum clients who can connect - HashPlaintextPasswords bool `json:"hash_plaintext_passwords"` // Overwrite plain-text passwords with hashed equivalents - Accesses []*Access `json:"accesses"` // Accesses offered to users - PassiveTransferPortRange *PortRange `json:"passive_transfer_port_range"` // Listen port range - Logging Logging `json:"logging"` // Logging parameters - TLS *TLS `json:"tls"` // TLS Config + Version int `json:"version"` // File format version + ListenAddress string `json:"listen_address"` // Address to listen on + PublicHost string `json:"public_host"` // Public host to listen on + MaxClients int `json:"max_clients"` // Maximum clients who can connect + HashPlaintextPasswords bool `json:"hash_plaintext_passwords"` // Overwrite plain-text passwords with hashed equivalents + Accesses []*Access `json:"accesses"` // Accesses offered to users + PassiveTransferPortRange *PortRange `json:"passive_transfer_port_range"` // Listen port range + Logging Logging `json:"logging"` // Logging parameters + TLS *TLS `json:"tls"` // TLS Config + AccessesWebhook *AccessesWebhook `json:"accesses_webhook"` // Webhook to call when accesses are updated } diff --git a/server/server.go b/server/server.go index 36bc33e8..3d9ac132 100644 --- a/server/server.go +++ b/server/server.go @@ -2,10 +2,15 @@ package server import ( + "bytes" + "context" "crypto/tls" + "encoding/json" "errors" "fmt" + "io" "io/ioutil" + "net/http" "sync" "time" @@ -169,9 +174,73 @@ func (s *Server) loadFs(access *confpar.Access) (afero.Fs, error) { return newFs, err } +func (s *Server) getAccessFromWebhook(user, pass string) (*confpar.Access, error) { + // Convert payload to JSON + jsonData, err := json.Marshal(map[string]string{ + "user": user, + "pass": pass, + }) + if err != nil { + return nil, err + } + + // Timeout is implemented with context termination + ctx, cancel := context.WithTimeout(context.Background(), s.config.Content.AccessesWebhook.Timeout) + defer cancel() + + // Create a new HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", s.config.Content.AccessesWebhook.URL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + for key, value := range s.config.Content.AccessesWebhook.Headers { + req.Header.Set(key, value) + } + + // Execute the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // Check the response status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // Return the response + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + access := new(confpar.Access) + err = json.Unmarshal(body, &access) + if err != nil { + return nil, err + } + + return access, nil +} + // AuthUser authenticates the user and selects an handling driver func (s *Server) AuthUser(cc serverlib.ClientContext, user, pass string) (serverlib.ClientDriver, error) { - access, errAccess := s.config.GetAccess(user, pass) + var ( + access *confpar.Access + errAccess error + ) + + if s.config.Content.AccessesWebhook == nil { + // Get the access from the configuration + access, errAccess = s.config.GetAccess(user, pass) + } else { + // Get the access from the webhook, not the configuration + access, errAccess = s.getAccessFromWebhook(user, pass) + } if errAccess != nil { return nil, errAccess }