diff --git a/updates.go b/updates.go index 92a1ea4ab..37c797b8c 100644 --- a/updates.go +++ b/updates.go @@ -94,7 +94,11 @@ func (f *DBUpdates) Seq() string { return f.curVal.(*driver.DBUpdate).Seq } -// DBUpdates begins polling for database updates. +// DBUpdates begins polling for database updates. Canceling the context will +// close the iterator. The iterator will also close automatically if there are +// no more updates, when an error occurs, or when the [DBUpdates.Close] method +// is called. The [DBUpdates.Err] method should be consulted to determine if +// there was an error during iteration. func (c *Client) DBUpdates(ctx context.Context, options ...Option) *DBUpdates { updater, ok := c.driverClient.(driver.DBUpdater) if !ok { diff --git a/x/server/bind.go b/x/server/bind.go index 5f3c57dac..9697d1970 100644 --- a/x/server/bind.go +++ b/x/server/bind.go @@ -19,21 +19,54 @@ import ( "encoding/json" "mime" "net/http" + + "github.com/go-kivik/kivik/v4/internal" ) +// bind binds the request to v if it is of type application/json or +// application/x-www-form-urlencoded. func (s *Server) bind(r *http.Request, v interface{}) error { + defer r.Body.Close() + switch r.Method { + case http.MethodPatch, http.MethodPost, http.MethodPut: + // continue + default: + // simple query parsing + return s.bindForm(r, v) + } ct, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch ct { case "application/json": - defer r.Body.Close() - return json.NewDecoder(r.Body).Decode(v) - case "application/x-www-form-urlencoded": - defer r.Body.Close() - if err := r.ParseForm(); err != nil { - return err + if err := json.NewDecoder(r.Body).Decode(v); err != nil { + return &internal.Error{Status: http.StatusBadRequest, Err: err} } - return s.formDecoder.Decode(r.Form, v) + return nil + case "application/x-www-form-urlencoded": + return s.bindForm(r, v) default: return &couchError{status: http.StatusUnsupportedMediaType, Err: "bad_content_type", Reason: "Content-Type must be 'application/x-www-form-urlencoded' or 'application/json'"} } } + +func (s *Server) bindForm(r *http.Request, v interface{}) error { + defer r.Body.Close() + if err := r.ParseForm(); err != nil { + return &internal.Error{Status: http.StatusBadRequest, Err: err} + } + if err := s.formDecoder.Decode(r.Form, v); err != nil { + return &internal.Error{Status: http.StatusBadRequest, Err: err} + } + return nil +} + +// bindJSON works like bind, but for endpoints that require application/json. +func (s *Server) bindJSON(r *http.Request, v interface{}) error { + defer r.Body.Close() + ct, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch ct { + case "application/json": + return json.NewDecoder(r.Body).Decode(v) + default: + return &couchError{status: http.StatusUnsupportedMediaType, Err: "bad_content_type", Reason: "Content-Type must be 'application/json'"} + } +} diff --git a/x/server/cluster.go b/x/server/cluster.go new file mode 100644 index 000000000..1b04e6581 --- /dev/null +++ b/x/server/cluster.go @@ -0,0 +1,51 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//go:build !js +// +build !js + +package server + +import ( + "net/http" + + "gitlab.com/flimzy/httpe" +) + +func (s *Server) clusterStatus() httpe.HandlerWithError { + return httpe.HandlerWithErrorFunc(func(w http.ResponseWriter, r *http.Request) error { + status, err := s.client.ClusterStatus(r.Context(), options(r)) + if err != nil { + return err + } + return serveJSON(w, http.StatusOK, map[string]string{ + "state": status, + }) + }) +} + +func (s *Server) clusterSetup() httpe.HandlerWithError { + return httpe.HandlerWithErrorFunc(func(w http.ResponseWriter, r *http.Request) error { + var req struct { + Action string `json:"action"` + } + if err := s.bindJSON(r, &req); err != nil { + return err + } + if err := s.client.ClusterSetup(r.Context(), req.Action); err != nil { + return err + } + return serveJSON(w, http.StatusOK, map[string]bool{ + "ok": true, + }) + }) +} diff --git a/x/server/cluster_test.go b/x/server/cluster_test.go new file mode 100644 index 000000000..bf4de630f --- /dev/null +++ b/x/server/cluster_test.go @@ -0,0 +1,100 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//go:build !js +// +build !js + +package server + +import ( + "net/http" + "strings" + "testing" + + "github.com/go-kivik/kivik/v4" + "github.com/go-kivik/kivik/v4/mockdb" +) + +func Test_clusterStatus(t *testing.T) { + tests := serverTests{ + { + name: "cluster status, unauthorized", + method: http.MethodGet, + path: "/_cluster_setup", + wantStatus: http.StatusUnauthorized, + wantJSON: map[string]interface{}{ + "error": "unauthorized", + "reason": "User not authenticated", + }, + }, + { + name: "cluster status, success", + client: func() *kivik.Client { + client, mock, err := mockdb.New() + if err != nil { + t.Fatal(err) + } + mock.ExpectClusterStatus(). + WillReturn("chicken") + return client + }(), + method: http.MethodGet, + path: "/_cluster_setup", + authUser: userAdmin, + wantStatus: http.StatusOK, + wantJSON: map[string]string{ + "state": "chicken", + }, + }, + } + + tests.Run(t) +} + +func TestClusterSetup(t *testing.T) { + tests := serverTests{ + { + name: "cluster status, unauthorized", + method: http.MethodPost, + path: "/_cluster_setup", + wantStatus: http.StatusUnauthorized, + wantJSON: map[string]string{ + "error": "unauthorized", + "reason": "User not authenticated", + }, + }, + { + name: "cluster status, success", + client: func() *kivik.Client { + client, mock, err := mockdb.New() + if err != nil { + t.Fatal(err) + } + mock.ExpectClusterSetup(). + WithAction("chicken"). + WillReturnError(nil) + return client + }(), + method: http.MethodPost, + authUser: userAdmin, + path: "/_cluster_setup", + headers: map[string]string{"Content-Type": "application/json"}, + body: strings.NewReader(`{"action":"chicken"}`), + wantStatus: http.StatusOK, + wantJSON: map[string]bool{ + "ok": true, + }, + }, + } + + tests.Run(t) +} diff --git a/x/server/db.go b/x/server/db.go index b66688610..f92f3a4e9 100644 --- a/x/server/db.go +++ b/x/server/db.go @@ -16,10 +16,15 @@ package server import ( + "encoding/json" "net/http" + "strconv" + "time" "github.com/go-chi/chi/v5" "gitlab.com/flimzy/httpe" + + "github.com/go-kivik/kivik/v4/driver" ) func (s *Server) db() httpe.HandlerWithError { @@ -72,3 +77,93 @@ func (s *Server) deleteDB() httpe.HandlerWithError { }) }) } + +const defaultHeartbeat = heartbeat(60 * time.Second) + +type heartbeat time.Duration + +func (h *heartbeat) UnmarshalText(text []byte) error { + var value heartbeat + if string(text) == "true" { + value = defaultHeartbeat + } else { + ms, err := strconv.Atoi(string(text)) + if err != nil { + return err + } + value = heartbeat(ms) * heartbeat(time.Millisecond) + } + *h = value + return nil +} + +func (s *Server) dbUpdates() httpe.HandlerWithError { + return httpe.HandlerWithErrorFunc(func(w http.ResponseWriter, r *http.Request) error { + req := struct { + Heartbeat heartbeat `form:"heartbeat"` + }{ + Heartbeat: defaultHeartbeat, + } + if err := s.bind(r, &req); err != nil { + return err + } + ticker := time.NewTicker(time.Duration(req.Heartbeat)) + updates := s.client.DBUpdates(r.Context(), options(r)) + + if err := updates.Err(); err != nil { + return err + } + + defer updates.Close() + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + + if _, err := w.Write([]byte(`{"results":[`)); err != nil { + return err + } + + nextUpdate := make(chan *driver.DBUpdate) + go func() { + for updates.Next() { + nextUpdate <- &driver.DBUpdate{ + DBName: updates.DBName(), + Type: updates.Type(), + Seq: updates.Seq(), + } + } + close(nextUpdate) + }() + + var lastSeq string + loop: + for { + select { + case <-ticker.C: + if _, err := w.Write([]byte("\n")); err != nil { + return err + } + case update, ok := <-nextUpdate: + if !ok { + break loop + } + ticker.Reset(time.Duration(req.Heartbeat)) + if lastSeq != "" { + if _, err := w.Write([]byte(",")); err != nil { + return err + } + } + lastSeq = update.Seq + if err := json.NewEncoder(w).Encode(update); err != nil { + return err + } + } + } + + if _, err := w.Write([]byte(`],"last_seq":"` + lastSeq + "\"}")); err != nil { + return err + } + + return updates.Err() + }) +} diff --git a/x/server/db_test.go b/x/server/db_test.go new file mode 100644 index 000000000..f01f7ae07 --- /dev/null +++ b/x/server/db_test.go @@ -0,0 +1,121 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//go:build !js +// +build !js + +package server + +import ( + "net/http" + "testing" + "time" + + "github.com/go-kivik/kivik/v4" + "github.com/go-kivik/kivik/v4/driver" + "github.com/go-kivik/kivik/v4/mockdb" +) + +func Test_dbUpdates(t *testing.T) { + tests := serverTests{ + { + name: "db updates, unauthorized", + method: http.MethodGet, + path: "/_db_updates", + wantStatus: http.StatusUnauthorized, + wantJSON: map[string]interface{}{ + "error": "unauthorized", + "reason": "User not authenticated", + }, + }, + { + name: "db updates, two updates", + client: func() *kivik.Client { + client, mock, err := mockdb.New() + if err != nil { + t.Fatal(err) + } + mock.ExpectDBUpdates().WillReturn(mockdb.NewDBUpdates(). + AddUpdate(&driver.DBUpdate{ + DBName: "foo", + Type: "created", + Seq: "1-aaa", + }). + AddUpdate(&driver.DBUpdate{ + DBName: "foo", + Type: "deleted", + Seq: "2-aaa", + })) + return client + }(), + authUser: userAdmin, + method: http.MethodGet, + path: "/_db_updates", + wantStatus: http.StatusOK, + wantJSON: map[string]interface{}{ + "results": []interface{}{ + map[string]interface{}{ + "db_name": "foo", + "type": "created", + "seq": "1-aaa", + }, + map[string]interface{}{ + "db_name": "foo", + "type": "deleted", + "seq": "2-aaa", + }, + }, + "last_seq": "2-aaa", + }, + }, + { + name: "db updates, invalid heartbeat", + method: http.MethodGet, + authUser: userAdmin, + path: "/_db_updates?heartbeat=chicken", + wantStatus: http.StatusBadRequest, + wantJSON: map[string]interface{}{ + "error": "bad_request", + "reason": "strconv.Atoi: parsing \"chicken\": invalid syntax", + }, + }, + { + name: "db updates, with heartbeat", + client: func() *kivik.Client { + client, mock, err := mockdb.New() + if err != nil { + t.Fatal(err) + } + mock.ExpectDBUpdates().WillReturn(mockdb.NewDBUpdates(). + AddUpdate(&driver.DBUpdate{ + DBName: "foo", + Type: "created", + Seq: "1-aaa", + }). + AddDelay(500 * time.Millisecond). + AddUpdate(&driver.DBUpdate{ + DBName: "foo", + Type: "deleted", + Seq: "2-aaa", + })) + return client + }(), + authUser: userAdmin, + method: http.MethodGet, + path: "/_db_updates?heartbeat=100", + wantStatus: http.StatusOK, + wantBodyRE: "\n\n\n", + }, + } + + tests.Run(t) +} diff --git a/x/server/server.go b/x/server/server.go index f0c68ede5..2594aa789 100644 --- a/x/server/server.go +++ b/x/server/server.go @@ -55,6 +55,10 @@ type Server struct { sequentialUUIDMonotonicID int32 } +func e(h httpe.HandlerWithError) httpe.HandlerFunc { + return httpe.ToHandler(h).ServeHTTP +} + // New instantiates a new server instance. func New(client *kivik.Client, options ...Option) *Server { s := &Server{ @@ -85,49 +89,49 @@ func (s *Server) routes(mux *chi.Mux) { admin := auth.With( httpe.ToMiddleware(adminRequired), ) - auth.Get("/", httpe.ToHandler(s.root()).ServeHTTP) - admin.Get("/_active_tasks", httpe.ToHandler(s.activeTasks()).ServeHTTP) - admin.Get("/_all_dbs", httpe.ToHandler(s.allDBs()).ServeHTTP) - auth.Get("/_dbs_info", httpe.ToHandler(s.allDBsStats()).ServeHTTP) - auth.Post("/_dbs_info", httpe.ToHandler(s.dbsStats()).ServeHTTP) - auth.Get("/_cluster_setup", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Post("/_cluster_setup", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Post("/_db_updates", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_membership", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Post("/_replicate", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_scheduler/jobs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_scheduler/docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_scheduler/docs/{replicator_db}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_scheduler/docs/{replicator_db}/{doc_id}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_node/{node-name}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_node/{node-name}/_stats", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_node/{node-name}/_prometheus", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_node/{node-name}/_system", httpe.ToHandler(s.notImplemented()).ServeHTTP) - admin.Post("/_node/{node-name}/_restart", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_node/{node-name}/_versions", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Post("/_search_analyze", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_utils", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_utils/", httpe.ToHandler(s.notImplemented()).ServeHTTP) - mux.Get("/_up", httpe.ToHandler(s.up()).ServeHTTP) - mux.Get("/_uuids", httpe.ToHandler(s.uuids()).ServeHTTP) - mux.Get("/favicon.ico", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_reshard", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_reshard/state", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Put("/_reshard/state", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_reshard/jobs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_reshard/jobs/{jobid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Post("/_reshard/jobs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Delete("/_reshard/jobs/{jobid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Get("/_reshard/jobs/{jobid}/state", httpe.ToHandler(s.notImplemented()).ServeHTTP) - auth.Put("/_reshard/jobs/{jobid}/state", httpe.ToHandler(s.notImplemented()).ServeHTTP) + auth.Get("/", e(s.root())) + admin.Get("/_active_tasks", e(s.activeTasks())) + admin.Get("/_all_dbs", e(s.allDBs())) + auth.Get("/_dbs_info", e(s.allDBsStats())) + auth.Post("/_dbs_info", e(s.dbsStats())) + admin.Get("/_cluster_setup", e(s.clusterStatus())) + admin.Post("/_cluster_setup", e(s.clusterSetup())) + admin.Get("/_db_updates", e(s.dbUpdates())) + auth.Get("/_membership", e(s.notImplemented())) + auth.Post("/_replicate", e(s.notImplemented())) + auth.Get("/_scheduler/jobs", e(s.notImplemented())) + auth.Get("/_scheduler/docs", e(s.notImplemented())) + auth.Get("/_scheduler/docs/{replicator_db}", e(s.notImplemented())) + auth.Get("/_scheduler/docs/{replicator_db}/{doc_id}", e(s.notImplemented())) + auth.Get("/_node/{node-name}", e(s.notImplemented())) + auth.Get("/_node/{node-name}/_stats", e(s.notImplemented())) + auth.Get("/_node/{node-name}/_prometheus", e(s.notImplemented())) + auth.Get("/_node/{node-name}/_system", e(s.notImplemented())) + admin.Post("/_node/{node-name}/_restart", e(s.notImplemented())) + auth.Get("/_node/{node-name}/_versions", e(s.notImplemented())) + auth.Post("/_search_analyze", e(s.notImplemented())) + auth.Get("/_utils", e(s.notImplemented())) + auth.Get("/_utils/", e(s.notImplemented())) + mux.Get("/_up", e(s.up())) + mux.Get("/_uuids", e(s.uuids())) + mux.Get("/favicon.ico", e(s.notImplemented())) + auth.Get("/_reshard", e(s.notImplemented())) + auth.Get("/_reshard/state", e(s.notImplemented())) + auth.Put("/_reshard/state", e(s.notImplemented())) + auth.Get("/_reshard/jobs", e(s.notImplemented())) + auth.Get("/_reshard/jobs/{jobid}", e(s.notImplemented())) + auth.Post("/_reshard/jobs", e(s.notImplemented())) + auth.Delete("/_reshard/jobs/{jobid}", e(s.notImplemented())) + auth.Get("/_reshard/jobs/{jobid}/state", e(s.notImplemented())) + auth.Put("/_reshard/jobs/{jobid}/state", e(s.notImplemented())) // Config - admin.Get("/_node/{node-name}/_config", httpe.ToHandler(s.allConfig()).ServeHTTP) - admin.Get("/_node/{node-name}/_config/{section}", httpe.ToHandler(s.configSection()).ServeHTTP) - admin.Get("/_node/{node-name}/_config/{section}/{key}", httpe.ToHandler(s.configKey()).ServeHTTP) - admin.Put("/_node/{node-name}/_config/{section}/{key}", httpe.ToHandler(s.setConfigKey()).ServeHTTP) - admin.Delete("/_node/{node-name}/_config/{section}/{key}", httpe.ToHandler(s.deleteConfigKey()).ServeHTTP) - admin.Post("/_node/{node-name}/_config/_reload", httpe.ToHandler(s.reloadConfig()).ServeHTTP) + admin.Get("/_node/{node-name}/_config", e(s.allConfig())) + admin.Get("/_node/{node-name}/_config/{section}", e(s.configSection())) + admin.Get("/_node/{node-name}/_config/{section}/{key}", e(s.configKey())) + admin.Put("/_node/{node-name}/_config/{section}/{key}", e(s.setConfigKey())) + admin.Delete("/_node/{node-name}/_config/{section}/{key}", e(s.deleteConfigKey())) + admin.Post("/_node/{node-name}/_config/_reload", e(s.reloadConfig())) // Databases auth.Route("/{db}", func(db chi.Router) { @@ -141,82 +145,82 @@ func (s *Server) routes(mux *chi.Mux) { httpe.ToMiddleware(s.dbAdminRequired), ) - member.Head("/", httpe.ToHandler(s.dbExists()).ServeHTTP) - member.Get("/", httpe.ToHandler(s.db()).ServeHTTP) - admin.Put("/", httpe.ToHandler(s.createDB()).ServeHTTP) - admin.Delete("/", httpe.ToHandler(s.deleteDB()).ServeHTTP) - member.Get("/_all_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_all_docs/queries", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_all_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design_docs/queries", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_local_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_local_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_local_docs/queries", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_bulk_get", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_bulk_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_find", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_index", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_index", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Delete("/_index/{designdoc}/json/{name}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_explain", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_shards", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_shards/{docid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_sync_shards", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_changes", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_changes", httpe.ToHandler(s.notImplemented()).ServeHTTP) - admin.Post("/_compact", httpe.ToHandler(s.notImplemented()).ServeHTTP) - admin.Post("/_compact/{ddoc}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_ensure_full_commit", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_view_cleanup", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_security", httpe.ToHandler(s.getSecurity()).ServeHTTP) - dbAdmin.Put("/_security", httpe.ToHandler(s.putSecurity()).ServeHTTP) - member.Post("/_purge", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_purged_infos_limit", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Put("/_purged_infos_limit", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_missing_revs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_revs_diff", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_revs_limit", httpe.ToHandler(s.notImplemented()).ServeHTTP) - dbAdmin.Put("/_revs_limit", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Head("/", e(s.dbExists())) + member.Get("/", e(s.db())) + admin.Put("/", e(s.createDB())) + admin.Delete("/", e(s.deleteDB())) + member.Get("/_all_docs", e(s.notImplemented())) + member.Post("/_all_docs/queries", e(s.notImplemented())) + member.Post("/_all_docs", e(s.notImplemented())) + member.Get("/_design_docs", e(s.notImplemented())) + member.Post("/_design_docs", e(s.notImplemented())) + member.Post("/_design_docs/queries", e(s.notImplemented())) + member.Get("/_local_docs", e(s.notImplemented())) + member.Post("/_local_docs", e(s.notImplemented())) + member.Post("/_local_docs/queries", e(s.notImplemented())) + member.Post("/_bulk_get", e(s.notImplemented())) + member.Post("/_bulk_docs", e(s.notImplemented())) + member.Post("/_find", e(s.notImplemented())) + member.Post("/_index", e(s.notImplemented())) + member.Get("/_index", e(s.notImplemented())) + member.Delete("/_index/{designdoc}/json/{name}", e(s.notImplemented())) + member.Post("/_explain", e(s.notImplemented())) + member.Get("/_shards", e(s.notImplemented())) + member.Get("/_shards/{docid}", e(s.notImplemented())) + member.Get("/_sync_shards", e(s.notImplemented())) + member.Get("/_changes", e(s.notImplemented())) + member.Post("/_changes", e(s.notImplemented())) + admin.Post("/_compact", e(s.notImplemented())) + admin.Post("/_compact/{ddoc}", e(s.notImplemented())) + member.Post("/_ensure_full_commit", e(s.notImplemented())) + member.Post("/_view_cleanup", e(s.notImplemented())) + member.Get("/_security", e(s.getSecurity())) + dbAdmin.Put("/_security", e(s.putSecurity())) + member.Post("/_purge", e(s.notImplemented())) + member.Get("/_purged_infos_limit", e(s.notImplemented())) + member.Put("/_purged_infos_limit", e(s.notImplemented())) + member.Post("/_missing_revs", e(s.notImplemented())) + member.Post("/_revs_diff", e(s.notImplemented())) + member.Get("/_revs_limit", e(s.notImplemented())) + dbAdmin.Put("/_revs_limit", e(s.notImplemented())) // Documents - member.Post("/", httpe.ToHandler(s.postDoc()).ServeHTTP) - member.Get("/{docid}", httpe.ToHandler(s.doc()).ServeHTTP) - member.Put("/{docid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Delete("/{docid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Post("/", e(s.postDoc())) + member.Get("/{docid}", e(s.doc())) + member.Put("/{docid}", e(s.notImplemented())) + member.Delete("/{docid}", e(s.notImplemented())) member.Method("COPY", "/{db}/{docid}", httpe.ToHandler(s.notImplemented())) - member.Delete("/{docid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/{docid}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/{docid}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Delete("/{docid}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Delete("/{docid}", e(s.notImplemented())) + member.Get("/{docid}/{attname}", e(s.notImplemented())) + member.Get("/{docid}/{attname}", e(s.notImplemented())) + member.Delete("/{docid}/{attname}", e(s.notImplemented())) // Design docs - member.Get("/_design/{ddoc}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - dbAdmin.Put("/_design/{ddoc}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - dbAdmin.Delete("/_design/{ddoc}", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Get("/_design/{ddoc}", e(s.notImplemented())) + dbAdmin.Put("/_design/{ddoc}", e(s.notImplemented())) + dbAdmin.Delete("/_design/{ddoc}", e(s.notImplemented())) dbAdmin.Method("COPY", "/{db}/_design/{ddoc}", httpe.ToHandler(s.notImplemented())) - member.Get("/_design/{ddoc}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - dbAdmin.Put("/_design/{ddoc}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - dbAdmin.Delete("/_design/{ddoc}/{attname}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design/{ddoc}/_view/{view}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design/{ddoc}/_info", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design/{ddoc}/_view/{view}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design/{ddoc}/_view/{view}/queries", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design/{ddoc}/_search/{index}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design/{ddoc}/_search_info/{index}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design/{ddoc}/_update/{func}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design/{ddoc}/_update/{func}/{docid}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_design/{ddoc}/_rewrite/{path}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Put("/_design/{ddoc}/_rewrite/{path}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_design/{ddoc}/_rewrite/{path}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Delete("/_design/{ddoc}/_rewrite/{path}", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Get("/_design/{ddoc}/{attname}", e(s.notImplemented())) + dbAdmin.Put("/_design/{ddoc}/{attname}", e(s.notImplemented())) + dbAdmin.Delete("/_design/{ddoc}/{attname}", e(s.notImplemented())) + member.Get("/_design/{ddoc}/_view/{view}", e(s.notImplemented())) + member.Get("/_design/{ddoc}/_info", e(s.notImplemented())) + member.Post("/_design/{ddoc}/_view/{view}", e(s.notImplemented())) + member.Post("/_design/{ddoc}/_view/{view}/queries", e(s.notImplemented())) + member.Get("/_design/{ddoc}/_search/{index}", e(s.notImplemented())) + member.Get("/_design/{ddoc}/_search_info/{index}", e(s.notImplemented())) + member.Post("/_design/{ddoc}/_update/{func}", e(s.notImplemented())) + member.Post("/_design/{ddoc}/_update/{func}/{docid}", e(s.notImplemented())) + member.Get("/_design/{ddoc}/_rewrite/{path}", e(s.notImplemented())) + member.Put("/_design/{ddoc}/_rewrite/{path}", e(s.notImplemented())) + member.Post("/_design/{ddoc}/_rewrite/{path}", e(s.notImplemented())) + member.Delete("/_design/{ddoc}/_rewrite/{path}", e(s.notImplemented())) - member.Get("/_partition/{partition}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_partition/{partition}/_all_docs", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_partition/{partition}/_design/{ddoc}/_view/{view}", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Post("/_partition/{partition_id}/_find", httpe.ToHandler(s.notImplemented()).ServeHTTP) - member.Get("/_partition/{partition_id}/_explain", httpe.ToHandler(s.notImplemented()).ServeHTTP) + member.Get("/_partition/{partition}", e(s.notImplemented())) + member.Get("/_partition/{partition}/_all_docs", e(s.notImplemented())) + member.Get("/_partition/{partition}/_design/{ddoc}/_view/{view}", e(s.notImplemented())) + member.Post("/_partition/{partition_id}/_find", e(s.notImplemented())) + member.Get("/_partition/{partition_id}/_explain", e(s.notImplemented())) }) } diff --git a/x/server/server_test.go b/x/server/server_test.go index b24d2fa25..cb972a038 100644 --- a/x/server/server_test.go +++ b/x/server/server_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "os" + "regexp" "strings" "testing" "time" @@ -96,28 +97,130 @@ func basicAuth(user string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+testPassword)) } -func TestServer(t *testing.T) { - t.Parallel() - type test struct { - name string - driver, dsn string - init func(t *testing.T, client *kivik.Client) - extraOptions []Option - method string - path string - headers map[string]string - authUser string - body io.Reader - wantStatus int - wantJSON interface{} - check func(t *testing.T, client *kivik.Client) +type serverTest struct { + name string + client *kivik.Client + driver, dsn string + init func(t *testing.T, client *kivik.Client) + extraOptions []Option + method string + path string + headers map[string]string + authUser string + body io.Reader + wantStatus int + wantBodyRE string + wantJSON interface{} + check func(t *testing.T, client *kivik.Client) + + // if target is specified, it is expected to be a struct into which the + // response body will be unmarshaled, then validated. + target interface{} +} + +type serverTests []serverTest + +func (s serverTests) Run(t *testing.T) { + t.Helper() + for _, tt := range s { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + driver, dsn := "fs", "testdata/fsdb" + if tt.dsn != "" { + dsn = tt.dsn + } + client := tt.client + if client == nil { + if tt.driver != "" { + driver = tt.driver + } + if driver == "fs" { + dsn = testy.CopyTempDir(t, dsn, 0) + t.Cleanup(func() { + _ = os.RemoveAll(dsn) + }) + } + var err error + client, err = kivik.New(driver, dsn) + if err != nil { + t.Fatal(err) + } + } + if tt.init != nil { + tt.init(t, client) + } + us := testUserStore(t) + const secret = "foo" + opts := append([]Option{ + WithUserStores(us), + WithAuthHandlers(auth.BasicAuth()), + WithAuthHandlers(auth.CookieAuth(secret, time.Hour)), + }, tt.extraOptions...) + + s := New(client, opts...) + body := tt.body + if body == nil { + body = strings.NewReader("") + } + req, err := http.NewRequest(tt.method, tt.path, body) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.headers { + req.Header.Set(k, v) + } + if tt.authUser != "" { + user, err := us.UserCtx(context.Background(), tt.authUser) + if err != nil { + t.Fatal(err) + } + req.AddCookie(&http.Cookie{ + Name: kivik.SessionCookieName, + Value: auth.CreateAuthToken(user.Name, user.Salt, secret, time.Now().Unix()), + }) + } + + rec := httptest.NewRecorder() + s.ServeHTTP(rec, req) - // if target is specified, it is expected to be a struct into which the - // response body will be unmarshaled, then validated. - target interface{} + res := rec.Result() + if res.StatusCode != tt.wantStatus { + t.Errorf("Unexpected response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) + } + switch { + case tt.target != nil: + if err := json.NewDecoder(res.Body).Decode(tt.target); err != nil { + t.Fatal(err) + } + if err := v.Struct(tt.target); err != nil { + t.Fatalf("response does not match expectations: %s\n%v", err, tt.target) + } + case tt.wantBodyRE != "": + re := regexp.MustCompile(tt.wantBodyRE) + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !re.Match(body) { + t.Errorf("Unexpected response body:\n%s", body) + } + default: + if d := testy.DiffAsJSON(tt.wantJSON, res.Body); d != nil { + t.Error(d) + } + } + if tt.check != nil { + tt.check(t, client) + } + }) } +} + +func TestServer(t *testing.T) { + t.Parallel() - tests := []test{ + tests := serverTests{ { name: "root", method: http.MethodGet, @@ -703,9 +806,9 @@ func TestServer(t *testing.T) { }, }, }, - func() test { + func() serverTest { const want = `{"admins":{"names":["superuser"],"roles":["admins"]},"members":{"names":["user1","user2"],"roles":["developers"]}}` - return test{ + return serverTest{ name: "put security", method: http.MethodPut, path: "/db2/_security", @@ -838,82 +941,7 @@ func TestServer(t *testing.T) { }, } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - driver, dsn := "fs", "testdata/fsdb" - if tt.dsn != "" { - dsn = tt.dsn - } - if tt.driver != "" { - driver = tt.driver - } - if driver == "fs" { - dsn = testy.CopyTempDir(t, dsn, 0) - t.Cleanup(func() { - _ = os.RemoveAll(dsn) - }) - } - client, err := kivik.New(driver, dsn) - if err != nil { - t.Fatal(err) - } - if tt.init != nil { - tt.init(t, client) - } - us := testUserStore(t) - const secret = "foo" - opts := append([]Option{ - WithUserStores(us), - WithAuthHandlers(auth.BasicAuth()), - WithAuthHandlers(auth.CookieAuth(secret, time.Hour)), - }, tt.extraOptions...) - - s := New(client, opts...) - req, err := http.NewRequest(tt.method, tt.path, tt.body) - if err != nil { - t.Fatal(err) - } - for k, v := range tt.headers { - req.Header.Set(k, v) - } - if tt.authUser != "" { - user, err := us.UserCtx(context.Background(), tt.authUser) - if err != nil { - t.Fatal(err) - } - req.AddCookie(&http.Cookie{ - Name: kivik.SessionCookieName, - Value: auth.CreateAuthToken(user.Name, user.Salt, secret, time.Now().Unix()), - }) - } - - rec := httptest.NewRecorder() - s.ServeHTTP(rec, req) - - res := rec.Result() - if res.StatusCode != tt.wantStatus { - t.Errorf("Unexpected response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) - } - switch { - case tt.target != nil: - if err := json.NewDecoder(res.Body).Decode(tt.target); err != nil { - t.Fatal(err) - } - if err := v.Struct(tt.target); err != nil { - t.Fatalf("response does not match expectations: %s\n%v", err, tt.target) - } - default: - if d := testy.DiffAsJSON(tt.wantJSON, res.Body); d != nil { - t.Error(d) - } - } - if tt.check != nil { - tt.check(t, client) - } - }) - } + tests.Run(t) } type readOnlyConfig struct {