From 20fbb63ef61c88abf90612517bae9921f09e226a Mon Sep 17 00:00:00 2001 From: Paul Lorenz Date: Tue, 9 Jul 2024 14:16:41 -0400 Subject: [PATCH] Merge fabric and controller model code. Fixes #2205 --- .../datastructures}/cmap_iterator.go | 2 +- controller/api_impl/circuit_api_model.go | 5 +- controller/api_impl/inspections_router.go | 4 +- controller/api_impl/link_api_model.go | 5 +- controller/api_impl/router_api_model.go | 17 +- controller/api_impl/router_router.go | 15 +- controller/api_impl/service_api_model.go | 17 +- controller/api_impl/service_router.go | 15 +- controller/api_impl/terminator_api_model.go | 25 +- controller/api_impl/terminator_router.go | 13 +- controller/config.go | 925 -------------- controller/config/config.go | 1120 +++++++++++------ controller/config/config_edge.go | 564 +++++++++ .../{config_test.go => config_edge_test.go} | 0 .../options.go => config/config_network.go} | 12 +- controller/config/config_raft.go | 31 + controller/controller.go | 44 +- controller/env/appenv.go | 86 +- controller/env/broker.go | 8 +- controller/env/sync.go | 5 +- controller/events/dispatcher_router.go | 7 +- controller/events/dispatcher_terminator.go | 11 +- controller/handler_ctrl/base.go | 3 +- controller/handler_ctrl/bind.go | 5 +- .../handler_ctrl/circuit_confirmation.go | 9 +- controller/handler_ctrl/circuit_request.go | 19 +- controller/handler_ctrl/close.go | 5 +- controller/handler_ctrl/create_terminator.go | 9 +- controller/handler_ctrl/decommission.go | 5 +- controller/handler_ctrl/dequiesce_router.go | 7 +- controller/handler_ctrl/fault.go | 5 +- controller/handler_ctrl/link_connected.go | 7 +- controller/handler_ctrl/quiesce_router.go | 7 +- controller/handler_ctrl/remove_terminator.go | 9 +- controller/handler_ctrl/remove_terminators.go | 5 +- controller/handler_ctrl/route_result.go | 9 +- controller/handler_ctrl/router_link.go | 5 +- controller/handler_ctrl/update_terminator.go | 11 +- controller/handler_ctrl/verify_router.go | 7 +- controller/handler_edge_ctrl/common.go | 32 +- controller/handler_edge_ctrl/common_tunnel.go | 2 +- .../handler_edge_ctrl/create_circuit.go | 8 +- .../handler_edge_ctrl/create_terminator.go | 5 +- .../handler_edge_ctrl/create_terminator_v2.go | 13 +- .../create_tunnel_terminator.go | 13 +- .../handler_edge_ctrl/remove_terminator.go | 2 +- .../remove_tunnel_terminator.go | 2 +- .../handler_edge_ctrl/validate_sessions.go | 2 +- controller/handler_mgmt/inspect.go | 2 +- .../handler_mgmt/stream_toggle_pipe_traces.go | 3 +- .../handler_mgmt/validate_terminators.go | 2 +- .../policy/service_policy_enforcer.go | 4 +- .../internal/routes/authenticate_router.go | 2 +- controller/internal/routes/base_router.go | 23 +- controller/internal/routes/ca_router.go | 2 +- .../routes/current_api_session_router.go | 4 +- .../current_identity_authenticator_router.go | 8 +- controller/internal/routes/database_router.go | 4 +- controller/internal/routes/enroll_router.go | 2 +- controller/internal/routes/protocol_router.go | 2 +- .../internal/routes/router_api_model.go | 4 +- .../internal/routes/service_api_model.go | 12 +- controller/internal/routes/service_router.go | 6 +- controller/internal/routes/summary_router.go | 2 +- .../internal/routes/terminator_api_model.go | 33 +- .../internal/routes/terminator_router.go | 10 +- .../model/api_session_certificate_manager.go | 2 - controller/model/api_session_manager.go | 14 +- controller/model/api_session_model.go | 4 +- controller/model/auth_policy_manager.go | 7 +- controller/model/authenticator_manager.go | 15 +- controller/model/authenticator_mod_ext_jwt.go | 4 +- controller/model/base_manager.go | 8 +- controller/model/ca_manager.go | 11 +- controller/{network => model}/circuit.go | 36 +- controller/{network => model}/command.go | 99 +- controller/{network => model}/command_test.go | 29 +- controller/model/config_manager.go | 7 +- controller/model/config_type_manager.go | 7 +- controller/model/controller_manager.go | 11 +- controller/model/create_terminator_cmd.go | 7 +- controller/model/edge_router_manager.go | 7 +- controller/model/edge_router_manager_test.go | 2 +- .../model/edge_router_policy_manager.go | 7 +- controller/model/edge_service_manager.go | 33 +- controller/model/edge_service_model.go | 12 +- controller/model/enrollment_manager.go | 13 +- controller/model/enrollment_mod_erott.go | 2 +- controller/model/enrollment_mod_trott.go | 2 +- controller/model/enrollment_model.go | 4 +- controller/model/env.go | 27 +- .../model/external_jwt_signer_manager.go | 7 +- controller/model/identity_manager.go | 9 +- .../link_manager.go} | 205 +-- .../link_manager_test.go} | 30 +- .../{network/link.go => model/link_model.go} | 12 +- controller/model/managers.go | 106 +- controller/model/mfa_manager.go | 9 +- controller/model/path.go | 96 ++ controller/model/policy_advisor.go | 6 +- controller/model/posture_check_manager.go | 7 +- controller/model/posture_response_manager.go | 10 +- controller/model/posture_response_model.go | 6 +- controller/model/revocation_manager.go | 5 +- controller/model/revocation_model.go | 2 +- .../router.go => model/router_manager.go} | 164 +-- controller/model/router_model.go | 110 ++ .../service_edge_router_policy_manager.go | 7 +- .../service.go => model/service_manager.go} | 94 +- controller/model/service_model.go | 66 + controller/model/service_policy_manager.go | 7 +- .../terminator_manager.go} | 181 +-- controller/model/terminator_model.go | 116 ++ controller/model/testing.go | 209 ++- controller/model/transit_router_manager.go | 5 +- controller/network/assembly.go | 23 +- controller/network/circuit_lifecycle.go | 9 +- controller/network/db_provider.go | 12 - controller/network/fault.go | 5 +- controller/network/handler.go | 6 +- controller/network/inspect.go | 3 +- controller/network/managers.go | 328 ----- controller/network/network.go | 425 ++++--- .../network/{path.go => network_path.go} | 214 ++-- controller/network/network_test.go | 32 +- controller/network/path_test.go | 305 +++-- controller/network/route_perf_test.go | 62 +- controller/network/router_messaging.go | 43 +- controller/network/routesender.go | 19 +- controller/network/routesender_test.go | 28 +- controller/network/smart.go | 16 +- controller/network/smart_test.go | 32 +- controller/network/util_test.go | 25 +- controller/oidc_auth/storage.go | 4 +- controller/raft/raft.go | 158 +-- controller/server/client-api.go | 9 +- controller/server/controller.go | 19 +- controller/settings.go | 15 +- controller/subcmd/init.go | 9 +- controller/sync_strats/marshal.go | 2 +- controller/sync_strats/rtx.go | 7 +- controller/sync_strats/sync_instant.go | 17 +- tests/ca_traffic_test.go | 4 +- tests/context.go | 49 +- tests/control.go | 4 +- tests/data_flow_test.go | 1 - tests/enrollment_router_test.go | 8 +- tests/fabric_context.go | 5 +- tests/mfa_ziti_test.go | 4 +- tests/transit_router_test.go | 9 +- ziti/cmd/create/create_config.go | 9 +- ziti/cmd/database/add_debug_admin.go | 25 +- ziti/controller/delete_sessions.go | 4 +- ziti/controller/run.go | 9 +- zititest/zitilab/models/db_builder.go | 5 - 155 files changed, 3425 insertions(+), 3639 deletions(-) rename {controller/network => common/datastructures}/cmap_iterator.go (98%) delete mode 100644 controller/config.go create mode 100644 controller/config/config_edge.go rename controller/config/{config_test.go => config_edge_test.go} (100%) rename controller/{network/options.go => config/config_network.go} (97%) create mode 100644 controller/config/config_raft.go rename controller/{network => model}/circuit.go (79%) rename controller/{network => model}/command.go (52%) rename controller/{network => model}/command_test.go (70%) rename controller/{network/link_controller.go => model/link_manager.go} (56%) rename controller/{network/link_controller_test.go => model/link_manager_test.go} (75%) rename controller/{network/link.go => model/link_model.go} (95%) create mode 100644 controller/model/path.go rename controller/{network/router.go => model/router_manager.go} (74%) create mode 100644 controller/model/router_model.go rename controller/{network/service.go => model/service_manager.go} (64%) create mode 100644 controller/model/service_model.go rename controller/{network/terminator.go => model/terminator_manager.go} (75%) create mode 100644 controller/model/terminator_model.go delete mode 100644 controller/network/db_provider.go delete mode 100644 controller/network/managers.go rename controller/network/{path.go => network_path.go} (50%) diff --git a/controller/network/cmap_iterator.go b/common/datastructures/cmap_iterator.go similarity index 98% rename from controller/network/cmap_iterator.go rename to common/datastructures/cmap_iterator.go index 19bba4bc8..4303f2acf 100644 --- a/controller/network/cmap_iterator.go +++ b/common/datastructures/cmap_iterator.go @@ -14,7 +14,7 @@ limitations under the License. */ -package network +package datastructures import ( "github.com/openziti/storage/objectz" diff --git a/controller/api_impl/circuit_api_model.go b/controller/api_impl/circuit_api_model.go index a0400eb0f..fcca12f95 100644 --- a/controller/api_impl/circuit_api_model.go +++ b/controller/api_impl/circuit_api_model.go @@ -18,6 +18,7 @@ package api_impl import ( "github.com/openziti/ziti/controller/api" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_model" @@ -44,7 +45,7 @@ func (factory *CircuitLinkFactoryIml) Links(entity LinkEntity) rest_model.Links return links } -func MapCircuitToRestModel(n *network.Network, _ api.RequestContext, circuit *network.Circuit) (*rest_model.CircuitDetail, error) { +func MapCircuitToRestModel(n *network.Network, _ api.RequestContext, circuit *model.Circuit) (*rest_model.CircuitDetail, error) { path := &rest_model.Path{} for _, node := range circuit.Path.Nodes { path.Nodes = append(path.Nodes, ToEntityRef(node.Name, node, RouterLinkFactory)) @@ -54,7 +55,7 @@ func MapCircuitToRestModel(n *network.Network, _ api.RequestContext, circuit *ne } var svcEntityRef *rest_model.EntityRef - if svc, _ := n.Services.Read(circuit.ServiceId); svc != nil { + if svc, _ := n.Service.Read(circuit.ServiceId); svc != nil { svcEntityRef = ToEntityRef(svc.Name, svc, ServiceLinkFactory) } else { svcEntityRef = ToEntityRef("", deletedEntity(circuit.ServiceId), ServiceLinkFactory) diff --git a/controller/api_impl/inspections_router.go b/controller/api_impl/inspections_router.go index b44225f3f..2677b0089 100644 --- a/controller/api_impl/inspections_router.go +++ b/controller/api_impl/inspections_router.go @@ -18,12 +18,12 @@ package api_impl import ( "github.com/go-openapi/runtime/middleware" + "github.com/openziti/foundation/v2/stringz" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_model" "github.com/openziti/ziti/controller/rest_server/operations" "github.com/openziti/ziti/controller/rest_server/operations/inspect" - "github.com/openziti/foundation/v2/stringz" "net/http" ) @@ -49,7 +49,7 @@ func (r *InspectRouter) Register(fabricApi *operations.ZitiFabricAPI, wrapper Re } func (r *InspectRouter) Inspect(n *network.Network, rc api.RequestContext, request *rest_model.InspectRequest) { - result := n.Managers.Inspections.Inspect(stringz.OrEmpty(request.AppRegex), request.RequestedValues) + result := n.Inspections.Inspect(stringz.OrEmpty(request.AppRegex), request.RequestedValues) resp := MapInspectResultToRestModel(n, result) rc.Respond(resp, http.StatusOK) } diff --git a/controller/api_impl/link_api_model.go b/controller/api_impl/link_api_model.go index d8338813b..9289a15be 100644 --- a/controller/api_impl/link_api_model.go +++ b/controller/api_impl/link_api_model.go @@ -18,6 +18,7 @@ package api_impl import ( "github.com/openziti/ziti/controller/api" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_model" ) @@ -41,7 +42,7 @@ func (factory *LinkLinkFactoryIml) Links(entity LinkEntity) rest_model.Links { return links } -func MapLinkToRestModel(n *network.Network, _ api.RequestContext, link *network.Link) (*rest_model.LinkDetail, error) { +func MapLinkToRestModel(n *network.Network, _ api.RequestContext, link *model.Link) (*rest_model.LinkDetail, error) { iteration := int64(link.Iteration) staticCost := int64(link.StaticCost) linkStateStr := link.CurrentState().Mode.String() @@ -51,7 +52,7 @@ func MapLinkToRestModel(n *network.Network, _ api.RequestContext, link *network. destRouter := link.GetDest() if destRouter == nil { var err error - destRouter, err = n.Routers.Read(link.DstId) + destRouter, err = n.Router.Read(link.DstId) if err != nil { return nil, err } diff --git a/controller/api_impl/router_api_model.go b/controller/api_impl/router_api_model.go index de159f713..3b4ffdbf0 100644 --- a/controller/api_impl/router_api_model.go +++ b/controller/api_impl/router_api_model.go @@ -18,12 +18,13 @@ package api_impl import ( "github.com/openziti/ziti/controller/api" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_model" - "github.com/openziti/ziti/controller/models" "github.com/openziti/foundation/v2/stringz" + "github.com/openziti/ziti/controller/models" ) const EntityNameRouter = "routers" @@ -46,8 +47,8 @@ func (factory *RouterLinkFactoryIml) Links(entity LinkEntity) rest_model.Links { return links } -func MapCreateRouterToModel(router *rest_model.RouterCreate) *network.Router { - ret := &network.Router{ +func MapCreateRouterToModel(router *rest_model.RouterCreate) *model.Router { + ret := &model.Router{ BaseEntity: models.BaseEntity{ Id: stringz.OrEmpty(router.ID), Tags: TagsOrDefault(router.Tags), @@ -62,8 +63,8 @@ func MapCreateRouterToModel(router *rest_model.RouterCreate) *network.Router { return ret } -func MapUpdateRouterToModel(id string, router *rest_model.RouterUpdate) *network.Router { - ret := &network.Router{ +func MapUpdateRouterToModel(id string, router *rest_model.RouterUpdate) *model.Router { + ret := &model.Router{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(router.Tags), Id: id, @@ -78,8 +79,8 @@ func MapUpdateRouterToModel(id string, router *rest_model.RouterUpdate) *network return ret } -func MapPatchRouterToModel(id string, router *rest_model.RouterPatch) *network.Router { - ret := &network.Router{ +func MapPatchRouterToModel(id string, router *rest_model.RouterPatch) *model.Router { + ret := &model.Router{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(router.Tags), Id: id, @@ -96,7 +97,7 @@ func MapPatchRouterToModel(id string, router *rest_model.RouterPatch) *network.R type RouterModelMapper struct{} -func (RouterModelMapper) ToApi(n *network.Network, _ api.RequestContext, router *network.Router) (interface{}, error) { +func (RouterModelMapper) ToApi(n *network.Network, _ api.RequestContext, router *model.Router) (interface{}, error) { connected := n.GetConnectedRouter(router.Id) var restVersionInfo *rest_model.VersionInfo if connected != nil && connected.VersionInfo != nil { diff --git a/controller/api_impl/router_router.go b/controller/api_impl/router_router.go index 2917689a5..7d4bc3f5c 100644 --- a/controller/api_impl/router_router.go +++ b/controller/api_impl/router_router.go @@ -20,6 +20,7 @@ import ( "github.com/go-openapi/runtime/middleware" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_server/operations" "github.com/openziti/ziti/controller/rest_server/operations/router" @@ -71,17 +72,17 @@ func (r *RouterRouter) Register(fabricApi *operations.ZitiFabricAPI, wrapper Req } func (r *RouterRouter) ListRouters(n *network.Network, rc api.RequestContext) { - ListWithHandler[*network.Router](n, rc, n.Managers.Routers, RouterModelMapper{}) + ListWithHandler[*model.Router](n, rc, n.Managers.Router, RouterModelMapper{}) } func (r *RouterRouter) Detail(n *network.Network, rc api.RequestContext) { - DetailWithHandler[*network.Router](n, rc, n.Managers.Routers, RouterModelMapper{}) + DetailWithHandler[*model.Router](n, rc, n.Managers.Router, RouterModelMapper{}) } func (r *RouterRouter) Create(n *network.Network, rc api.RequestContext, params router.CreateRouterParams) { Create(rc, RouterLinkFactory, func() (string, error) { router := MapCreateRouterToModel(params.Router) - err := n.Routers.Create(router, rc.NewChangeContext()) + err := n.Router.Create(router, rc.NewChangeContext()) if err != nil { return "", err } @@ -90,21 +91,21 @@ func (r *RouterRouter) Create(n *network.Network, rc api.RequestContext, params } func (r *RouterRouter) Delete(network *network.Network, rc api.RequestContext) { - DeleteWithHandler(rc, network.Managers.Routers) + DeleteWithHandler(rc, network.Managers.Router) } func (r *RouterRouter) Update(n *network.Network, rc api.RequestContext, params router.UpdateRouterParams) { Update(rc, func(id string) error { - return n.Managers.Routers.Update(MapUpdateRouterToModel(params.ID, params.Router), nil, rc.NewChangeContext()) + return n.Managers.Router.Update(MapUpdateRouterToModel(params.ID, params.Router), nil, rc.NewChangeContext()) }) } func (r *RouterRouter) Patch(n *network.Network, rc api.RequestContext, params router.PatchRouterParams) { Patch(rc, func(id string, fields fields.UpdatedFields) error { - return n.Managers.Routers.Update(MapPatchRouterToModel(params.ID, params.Router), fields.FilterMaps("tags"), rc.NewChangeContext()) + return n.Managers.Router.Update(MapPatchRouterToModel(params.ID, params.Router), fields.FilterMaps("tags"), rc.NewChangeContext()) }) } func (r *RouterRouter) listManagementTerminators(n *network.Network, rc api.RequestContext) { - ListAssociationWithHandler[*network.Router, *network.Terminator](n, rc, n.Managers.Routers, n.Managers.Terminators, TerminatorModelMapper{}) + ListAssociationWithHandler[*model.Router, *model.Terminator](n, rc, n.Managers.Router, n.Managers.Terminator, TerminatorModelMapper{}) } diff --git a/controller/api_impl/service_api_model.go b/controller/api_impl/service_api_model.go index db1932b2b..e031274a2 100644 --- a/controller/api_impl/service_api_model.go +++ b/controller/api_impl/service_api_model.go @@ -19,12 +19,13 @@ package api_impl import ( "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/idgen" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_model" - "github.com/openziti/ziti/controller/models" "github.com/openziti/foundation/v2/stringz" + "github.com/openziti/ziti/controller/models" ) const EntityNameService = "services" @@ -47,8 +48,8 @@ func (factory *ServiceLinkFactoryIml) Links(entity LinkEntity) rest_model.Links return links } -func MapCreateServiceToModel(service *rest_model.ServiceCreate) *network.Service { - ret := &network.Service{ +func MapCreateServiceToModel(service *rest_model.ServiceCreate) *model.Service { + ret := &model.Service{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), }, @@ -67,8 +68,8 @@ func MapCreateServiceToModel(service *rest_model.ServiceCreate) *network.Service return ret } -func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *network.Service { - ret := &network.Service{ +func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *model.Service { + ret := &model.Service{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), Id: id, @@ -80,8 +81,8 @@ func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *netw return ret } -func MapPatchServiceToModel(id string, service *rest_model.ServicePatch) *network.Service { - ret := &network.Service{ +func MapPatchServiceToModel(id string, service *rest_model.ServicePatch) *model.Service { + ret := &model.Service{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), Id: id, @@ -95,7 +96,7 @@ func MapPatchServiceToModel(id string, service *rest_model.ServicePatch) *networ type ServiceModelMapper struct{} -func (ServiceModelMapper) ToApi(_ *network.Network, _ api.RequestContext, service *network.Service) (interface{}, error) { +func (ServiceModelMapper) ToApi(_ *network.Network, _ api.RequestContext, service *model.Service) (interface{}, error) { return &rest_model.ServiceDetail{ BaseEntity: BaseEntityToRestModel(service, ServiceLinkFactory), Name: &service.Name, diff --git a/controller/api_impl/service_router.go b/controller/api_impl/service_router.go index 99a767e0a..28560fb7e 100644 --- a/controller/api_impl/service_router.go +++ b/controller/api_impl/service_router.go @@ -20,6 +20,7 @@ import ( "github.com/go-openapi/runtime/middleware" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_server/operations" "github.com/openziti/ziti/controller/rest_server/operations/service" @@ -71,17 +72,17 @@ func (r *ServiceRouter) Register(fabricApi *operations.ZitiFabricAPI, wrapper Re } func (r *ServiceRouter) ListServices(n *network.Network, rc api.RequestContext) { - ListWithHandler[*network.Service](n, rc, n.Managers.Services, ServiceModelMapper{}) + ListWithHandler[*model.Service](n, rc, n.Managers.Service, ServiceModelMapper{}) } func (r *ServiceRouter) Detail(n *network.Network, rc api.RequestContext) { - DetailWithHandler[*network.Service](n, rc, n.Managers.Services, ServiceModelMapper{}) + DetailWithHandler[*model.Service](n, rc, n.Managers.Service, ServiceModelMapper{}) } func (r *ServiceRouter) Create(n *network.Network, rc api.RequestContext, params service.CreateServiceParams) { Create(rc, ServiceLinkFactory, func() (string, error) { svc := MapCreateServiceToModel(params.Service) - err := n.Services.Create(svc, rc.NewChangeContext()) + err := n.Service.Create(svc, rc.NewChangeContext()) if err != nil { return "", err } @@ -90,21 +91,21 @@ func (r *ServiceRouter) Create(n *network.Network, rc api.RequestContext, params } func (r *ServiceRouter) Delete(network *network.Network, rc api.RequestContext) { - DeleteWithHandler(rc, network.Managers.Services) + DeleteWithHandler(rc, network.Managers.Service) } func (r *ServiceRouter) Update(n *network.Network, rc api.RequestContext, params service.UpdateServiceParams) { Update(rc, func(id string) error { - return n.Managers.Services.Update(MapUpdateServiceToModel(params.ID, params.Service), nil, rc.NewChangeContext()) + return n.Managers.Service.Update(MapUpdateServiceToModel(params.ID, params.Service), nil, rc.NewChangeContext()) }) } func (r *ServiceRouter) Patch(n *network.Network, rc api.RequestContext, params service.PatchServiceParams) { Patch(rc, func(id string, fields fields.UpdatedFields) error { - return n.Managers.Services.Update(MapPatchServiceToModel(params.ID, params.Service), fields.FilterMaps("tags"), rc.NewChangeContext()) + return n.Managers.Service.Update(MapPatchServiceToModel(params.ID, params.Service), fields.FilterMaps("tags"), rc.NewChangeContext()) }) } func (r *ServiceRouter) listManagementTerminators(n *network.Network, rc api.RequestContext) { - ListAssociationWithHandler[*network.Service, *network.Terminator](n, rc, n.Managers.Services, n.Managers.Terminators, TerminatorModelMapper{}) + ListAssociationWithHandler[*model.Service, *model.Terminator](n, rc, n.Managers.Service, n.Managers.Terminator, TerminatorModelMapper{}) } diff --git a/controller/api_impl/terminator_api_model.go b/controller/api_impl/terminator_api_model.go index d318cf82c..e1669eb95 100644 --- a/controller/api_impl/terminator_api_model.go +++ b/controller/api_impl/terminator_api_model.go @@ -19,20 +19,21 @@ package api_impl import ( "fmt" "github.com/michaelquigley/pfxlog" + "github.com/openziti/foundation/v2/stringz" "github.com/openziti/ziti/controller/api" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" "github.com/openziti/ziti/controller/network" - "github.com/openziti/ziti/controller/xt" "github.com/openziti/ziti/controller/rest_model" - "github.com/openziti/foundation/v2/stringz" + "github.com/openziti/ziti/controller/xt" ) const EntityNameTerminator = "terminators" var TerminatorLinkFactory = NewBasicLinkFactory(EntityNameTerminator) -func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *network.Terminator { - ret := &network.Terminator{ +func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), }, @@ -53,8 +54,8 @@ func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *networ return ret } -func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpdate) *network.Terminator { - ret := &network.Terminator{ +func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpdate) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), Id: id, @@ -74,8 +75,8 @@ func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpda return ret } -func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch) *network.Terminator { - ret := &network.Terminator{ +func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), Id: id, @@ -97,7 +98,7 @@ func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch type TerminatorModelMapper struct{} -func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, terminator *network.Terminator) (interface{}, error) { +func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, terminator *model.Terminator) (interface{}, error) { restModel, err := MapTerminatorToRestModel(n, terminator) if err != nil { @@ -109,13 +110,13 @@ func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, ter return restModel, nil } -func MapTerminatorToRestModel(n *network.Network, terminator *network.Terminator) (*rest_model.TerminatorDetail, error) { - service, err := n.Managers.Services.Read(terminator.Service) +func MapTerminatorToRestModel(n *network.Network, terminator *model.Terminator) (*rest_model.TerminatorDetail, error) { + service, err := n.Managers.Service.Read(terminator.Service) if err != nil { return nil, err } - router, err := n.Managers.Routers.Read(terminator.Router) + router, err := n.Managers.Router.Read(terminator.Router) if err != nil { return nil, err } diff --git a/controller/api_impl/terminator_router.go b/controller/api_impl/terminator_router.go index c855b1ba9..785995a30 100644 --- a/controller/api_impl/terminator_router.go +++ b/controller/api_impl/terminator_router.go @@ -20,6 +20,7 @@ import ( "github.com/go-openapi/runtime/middleware" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/rest_server/operations" "github.com/openziti/ziti/controller/rest_server/operations/terminator" @@ -67,17 +68,17 @@ func (r *TerminatorRouter) Register(fabricApi *operations.ZitiFabricAPI, wrapper } func (r *TerminatorRouter) List(n *network.Network, rc api.RequestContext) { - ListWithHandler[*network.Terminator](n, rc, n.Managers.Terminators, TerminatorModelMapper{}) + ListWithHandler[*model.Terminator](n, rc, n.Managers.Terminator, TerminatorModelMapper{}) } func (r *TerminatorRouter) Detail(n *network.Network, rc api.RequestContext) { - DetailWithHandler[*network.Terminator](n, rc, n.Managers.Terminators, TerminatorModelMapper{}) + DetailWithHandler[*model.Terminator](n, rc, n.Managers.Terminator, TerminatorModelMapper{}) } func (r *TerminatorRouter) Create(n *network.Network, rc api.RequestContext, params terminator.CreateTerminatorParams) { Create(rc, TerminatorLinkFactory, func() (string, error) { entity := MapCreateTerminatorToModel(params.Terminator) - err := n.Terminators.Create(entity, rc.NewChangeContext()) + err := n.Terminator.Create(entity, rc.NewChangeContext()) if err != nil { return "", err } @@ -86,17 +87,17 @@ func (r *TerminatorRouter) Create(n *network.Network, rc api.RequestContext, par } func (r *TerminatorRouter) Delete(n *network.Network, rc api.RequestContext) { - DeleteWithHandler(rc, n.Managers.Terminators) + DeleteWithHandler(rc, n.Managers.Terminator) } func (r *TerminatorRouter) Update(n *network.Network, rc api.RequestContext, params terminator.UpdateTerminatorParams) { Update(rc, func(id string) error { - return n.Managers.Terminators.Update(MapUpdateTerminatorToModel(params.ID, params.Terminator), nil, rc.NewChangeContext()) + return n.Managers.Terminator.Update(MapUpdateTerminatorToModel(params.ID, params.Terminator), nil, rc.NewChangeContext()) }) } func (r *TerminatorRouter) Patch(n *network.Network, rc api.RequestContext, params terminator.PatchTerminatorParams) { Patch(rc, func(id string, fields fields.UpdatedFields) error { - return n.Managers.Terminators.Update(MapPatchTerminatorToModel(params.ID, params.Terminator), fields.FilterMaps("tags"), rc.NewChangeContext()) + return n.Managers.Terminator.Update(MapPatchTerminatorToModel(params.ID, params.Terminator), fields.FilterMaps("tags"), rc.NewChangeContext()) }) } diff --git a/controller/config.go b/controller/config.go deleted file mode 100644 index 0379f7bbd..000000000 --- a/controller/config.go +++ /dev/null @@ -1,925 +0,0 @@ -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package controller - -import ( - "bytes" - "crypto/sha1" - "crypto/tls" - "crypto/x509" - "fmt" - "github.com/hashicorp/go-hclog" - "github.com/michaelquigley/pfxlog" - "github.com/openziti/channel/v2" - "github.com/openziti/identity" - "github.com/openziti/storage/boltz" - "github.com/openziti/transport/v2" - transporttls "github.com/openziti/transport/v2/tls" - "github.com/openziti/ziti/common/config" - "github.com/openziti/ziti/common/pb/ctrl_pb" - "github.com/openziti/ziti/common/pb/mgmt_pb" - "github.com/openziti/ziti/controller/command" - "github.com/openziti/ziti/controller/db" - "github.com/openziti/ziti/controller/network" - "github.com/openziti/ziti/controller/raft" - "github.com/openziti/ziti/router/xgress" - "github.com/pkg/errors" - "gopkg.in/yaml.v2" - "math" - "net/url" - "os" - "strings" - "time" -) - -const ( - DefaultProfileMemoryInterval = 15 * time.Second - DefaultHealthChecksBoltCheckInterval = 30 * time.Second - DefaultHealthChecksBoltCheckTimeout = 20 * time.Second - DefaultHealthChecksBoltCheckInitialDelay = 30 * time.Second - - DefaultRaftCommandHandlerMaxQueueSize = 1000 - - // DefaultTlsHandshakeRateLimiterEnabled is whether the tls handshake rate limiter is enabled by default - DefaultTlsHandshakeRateLimiterEnabled = false - - // TlsHandshakeRateLimiterMinSizeValue is the minimum size that can be configured for the tls handshake rate limiter - // window range - TlsHandshakeRateLimiterMinSizeValue = 5 - - // TlsHandshakeRateLimiterMaxSizeValue is the maximum size that can be configured for the tls handshake rate limiter - // window range - TlsHandshakeRateLimiterMaxSizeValue = 10000 - - // TlsHandshakeRateLimiterMetricOutstandingCount is the name of the metric tracking how many tasks are in process - TlsHandshakeRateLimiterMetricOutstandingCount = "tls_handshake_limiter.in_process" - - // TlsHandshakeRateLimiterMetricCurrentWindowSize is the name of the metric tracking the current window size - TlsHandshakeRateLimiterMetricCurrentWindowSize = "tls_handshake_limiter.window_size" - - // TlsHandshakeRateLimiterMetricWorkTimer is the name of the metric tracking how long successful tasks are taking to complete - TlsHandshakeRateLimiterMetricWorkTimer = "tls_handshake_limiter.work_timer" - - // DefaultTlsHandshakeRateLimiterMaxWindow is the default max size for the tls handshake rate limiter - DefaultTlsHandshakeRateLimiterMaxWindow = 1000 -) - -type Config struct { - Id *identity.TokenId - SpiffeIdTrustDomain *url.URL - AdditionalTrustDomains []*url.URL - - Raft *raft.Config - Network *network.Options - Db boltz.Db - Trace struct { - Handler *channel.TraceHandler - } - Profile struct { - Memory struct { - Path string - Interval time.Duration - } - CPU struct { - Path string - } - } - Ctrl struct { - Listener transport.Address - Options *CtrlOptions - } - HealthChecks struct { - BoltCheck struct { - Interval time.Duration - Timeout time.Duration - InitialDelay time.Duration - } - } - CommandRateLimiter command.RateLimiterConfig - TlsHandshakeRateLimiter command.AdaptiveRateLimiterConfig - src map[interface{}]interface{} -} - -// CtrlOptions extends channel.Options to include support for additional, non-channel specific options -// (e.g. NewListener) -type CtrlOptions struct { - *channel.Options - NewListener *transport.Address - AdvertiseAddress *transport.Address - RouterHeartbeatOptions *channel.HeartbeatOptions - PeerHeartbeatOptions *channel.HeartbeatOptions -} - -func (config *Config) Configure(sub config.Subconfig) error { - return sub.LoadConfig(config.src) -} - -func LoadConfig(path string) (*Config, error) { - cfgBytes, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - cfgmap := make(map[interface{}]interface{}) - if err = yaml.NewDecoder(bytes.NewReader(cfgBytes)).Decode(&cfgmap); err != nil { - return nil, err - } - config.InjectEnv(cfgmap) - if value, found := cfgmap["v"]; found { - if value.(int) != 3 { - panic("config version mismatch: see docs for information on config updates") - } - } else { - panic("no config version: see docs for information on config") - } - - var identityConfig *identity.Config - - if value, found := cfgmap["identity"]; found { - subMap := value.(map[interface{}]interface{}) - identityConfig, err = identity.NewConfigFromMapWithPathContext(subMap, "identity") - - if err != nil { - return nil, fmt.Errorf("could not parse root identity: %v", err) - } - - if identityConfig.ServerCert == "" && identityConfig.ServerKey == "" { - identityConfig.ServerCert = identityConfig.Cert - identityConfig.ServerKey = identityConfig.Key - } - } else { - return nil, fmt.Errorf("identity section not found") - } - - controllerConfig := &Config{ - Network: network.DefaultOptions(), - src: cfgmap, - } - - if id, err := identity.LoadIdentity(*identityConfig); err != nil { - return nil, fmt.Errorf("unable to load identity (%s)", err) - } else { - controllerConfig.Id = identity.NewIdentity(id) - - if err := controllerConfig.Id.WatchFiles(); err != nil { - pfxlog.Logger().Warn("could not enable file watching on identity: %w", err) - } - } - - if value, found := cfgmap["network"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if options, err := network.LoadOptions(submap); err == nil { - controllerConfig.Network = options - } else { - return nil, fmt.Errorf("invalid 'network' stanza (%s)", err) - } - } else { - pfxlog.Logger().Warn("invalid or empty 'network' stanza") - } - } - - if value, found := cfgmap["raft"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - controllerConfig.Raft = &raft.Config{} - controllerConfig.Raft.CommandHandlerOptions.MaxQueueSize = DefaultRaftCommandHandlerMaxQueueSize - - if value, found := submap["dataDir"]; found { - controllerConfig.Raft.DataDir = value.(string) - } else { - return nil, errors.Errorf("raft dataDir configuration missing") - } - if value, found := submap["minClusterSize"]; found { - controllerConfig.Raft.MinClusterSize = uint32(value.(int)) - } - if value, found := submap["bootstrapMembers"]; found { - if lst, ok := value.([]interface{}); ok { - for idx, val := range lst { - if member, ok := val.(string); ok { - controllerConfig.Raft.BootstrapMembers = append(controllerConfig.Raft.BootstrapMembers, member) - } else { - return nil, errors.Errorf("invalid bootstrapMembers value '%v'at index %v, should be array", idx, val) - } - } - } else { - return nil, errors.New("invalid bootstrapMembers value, should be array") - } - } - - if value, found := submap["snapshotInterval"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.Raft.SnapshotInterval = &val - } else { - return nil, errors.Wrapf(err, "failed to parse raft.snapshotInterval value '%v", value) - } - } - - if value, found := submap["commitTimeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.Raft.CommitTimeout = &val - } else { - return nil, errors.Wrapf(err, "failed to parse raft.commitTimeout value '%v", value) - } - } - - if value, found := submap["electionTimeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.Raft.ElectionTimeout = &val - } else { - return nil, errors.Wrapf(err, "failed to parse raft.electionTimeout value '%v", value) - } - } - - if value, found := submap["heartbeatTimeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.Raft.HeartbeatTimeout = &val - } else { - return nil, errors.Wrapf(err, "failed to parse raft.heartbeatTimeout value '%v", value) - } - } - - if value, found := submap["leaderLeaseTimeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.Raft.LeaderLeaseTimeout = &val - } else { - return nil, errors.Wrapf(err, "failed to parse raft.leaderLeaseTimeout value '%v", value) - } - } - - if value, found := submap["snapshotThreshold"]; found { - val := uint32(value.(int)) - controllerConfig.Raft.SnapshotThreshold = &val - } - - if value, found := submap["maxAppendEntries"]; found { - val := uint32(value.(int)) - controllerConfig.Raft.MaxAppendEntries = &val - } - - if value, found := submap["trailingLogs"]; found { - val := uint32(value.(int)) - controllerConfig.Raft.TrailingLogs = &val - } - - if value, found := submap["logLevel"]; found { - val := fmt.Sprintf("%v", value) - if hclog.LevelFromString(val) == hclog.NoLevel { - return nil, errors.Errorf("invalid value for raft.logLevel [%v]", val) - } - controllerConfig.Raft.LogLevel = &val - } - - if value, found := submap["logFile"]; found { - val := fmt.Sprintf("%v", value) - options := *hclog.DefaultOptions - f, err := os.Create(val) - if err != nil { - return nil, errors.Wrapf(err, "unable to open raft log file [%v]", val) - } - options.Output = f - if controllerConfig.Raft.LogLevel != nil { - options.Level = hclog.LevelFromString(*controllerConfig.Raft.LogLevel) - } - controllerConfig.Raft.Logger = hclog.New(&options) - } - - if value, found := cfgmap["commandHandler"]; found { - if chSubMap, ok := value.(map[interface{}]interface{}); ok { - if value, found := chSubMap["maxQueueSize"]; found { - controllerConfig.Raft.CommandHandlerOptions.MaxQueueSize = uint16(value.(int)) - } - } else { - return nil, errors.New("invalid commandHandler value, should be map") - } - } - } else { - return nil, errors.Errorf("invalid raft configuration") - } - } else if value, found := cfgmap["db"]; found { - str, err := db.Open(value.(string)) - if err != nil { - return nil, err - } - controllerConfig.Db = str - } else { - panic("controllerConfig must provide [db] or [raft]") - } - - //SPIFFE Trust Domain - var spiffeId *url.URL - if controllerConfig.Raft != nil { - //HA setup, SPIFFE ID must come from certs - var err error - spiffeId, err = GetSpiffeIdFromIdentity(controllerConfig.Id.Identity) - if err != nil { - panic("error determining a trust domain from a SPIFFE id in the root identity for HA configuration, must have a spiffe:// URI SANs in the server certificate or along the signing CAs chain: " + err.Error()) - } - - if spiffeId == nil { - panic("unable to determine a trust domain from a SPIFFE id in the root identity for HA configuration, must have a spiffe:// URI SANs in the server certificate or along the signing CAs chain") - } - } else { - // Non-HA/legacy system, prefer SPIFFE id from certs, but fall back to configuration if necessary - spiffeId, _ = GetSpiffeIdFromIdentity(controllerConfig.Id.Identity) - - if spiffeId == nil { - //for non HA setups allow the trust domain to come from the configuration root value `trustDomain` - if value, found := cfgmap["trustDomain"]; found { - trustDomain, ok := value.(string) - - if !ok { - panic(fmt.Sprintf("could not parse [trustDomain], expected a string got [%T]", value)) - } - - if trustDomain != "" { - if !strings.HasPrefix("spiffe://", trustDomain) { - trustDomain = "spiffe://" + trustDomain - } - - spiffeId, err = url.Parse(trustDomain) - - if err != nil { - panic("could not parse [trustDomain] when used in a SPIFFE id URI [" + trustDomain + "], please make sure it is a valid URI hostname: " + err.Error()) - } - - if spiffeId == nil { - panic("could not parse [trustDomain] when used in a SPIFFE id URI [" + trustDomain + "]: spiffeId is nil and no error returned") - } - - if spiffeId.Scheme != "spiffe" { - panic("[trustDomain] does not have a spiffe scheme (spiffe://) has: " + spiffeId.Scheme) - } - } - } - } - - //default a generated trust domain and spiffe id from the sha1 of the root ca - if spiffeId == nil { - spiffeId, err = generateDefaultSpiffeId(controllerConfig.Id.Identity) - - if err != nil { - panic("could not generate default trust domain: " + err.Error()) - } - - pfxlog.Logger().Warnf("this environment is using a default generated trust domain [%s], it is recommended that a trust domain is specified in configuration via URI SANs or the 'trustDomain' field", spiffeId.String()) - pfxlog.Logger().Warnf("this environment is using a default generated trust domain [%s], it is recommended that if network components have enrolled that the generated trust domain be added to the configuration field 'additionalTrustDomains' array when configuring a explicit trust domain", spiffeId.String()) - } - } - - if spiffeId == nil { - panic("unable to determine trust domain from SPIFFE id (spiffe:// URI SANs in server cert or signing CAs) or from configuration [trustDomain], controllers must have a trust domain") - } - - if spiffeId.Hostname() == "" { - panic("unable to determine trust domain from SPIFFE id: hostname was empty") - } - - //only preserve trust domain - spiffeId.Path = "" - controllerConfig.SpiffeIdTrustDomain = spiffeId - - if value, found := cfgmap["additionalTrustDomains"]; found { - if valArr, ok := value.([]any); ok { - var trustDomains []*url.URL - for _, trustDomain := range valArr { - if strTrustDomain, ok := trustDomain.(string); ok { - - if !strings.HasPrefix("spiffe://", strTrustDomain) { - strTrustDomain = "spiffe://" + strTrustDomain - } - - spiffeId, err = url.Parse(strTrustDomain) - - if err != nil { - panic(fmt.Sprintf("invalid entry in 'additionalTrustDomains', could not be parsed as a URI: %v", trustDomain)) - } - //only preserve trust domain - spiffeId.Path = "" - - trustDomains = append(trustDomains, spiffeId) - } else { - panic(fmt.Sprintf("invalid entry in 'additionalTrustDomains' expected a string: %v", trustDomain)) - } - } - - controllerConfig.AdditionalTrustDomains = trustDomains - } - } - - if value, found := cfgmap["trace"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["path"]; found { - handler, err := channel.NewTraceHandler(value.(string), controllerConfig.Id.Token) - if err != nil { - return nil, err - } - handler.AddDecoder(&channel.Decoder{}) - handler.AddDecoder(&ctrl_pb.Decoder{}) - handler.AddDecoder(&xgress.Decoder{}) - handler.AddDecoder(&mgmt_pb.Decoder{}) - controllerConfig.Trace.Handler = handler - } - } - } - - if value, found := cfgmap["profile"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["memory"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["path"]; found { - controllerConfig.Profile.Memory.Path = value.(string) - } - if value, found := submap["intervalMs"]; found { - controllerConfig.Profile.Memory.Interval = time.Duration(value.(int)) * time.Millisecond - } else { - controllerConfig.Profile.Memory.Interval = DefaultProfileMemoryInterval - } - } - } - if value, found := submap["cpu"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["path"]; found { - controllerConfig.Profile.CPU.Path = value.(string) - } - } - } - } - } - - if value, found := cfgmap["ctrl"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["listener"]; found { - listener, err := transport.ParseAddress(value.(string)) - if err != nil { - return nil, err - } - controllerConfig.Ctrl.Listener = listener - } else { - panic("controllerConfig must provide [ctrl/listener]") - } - - controllerConfig.Ctrl.Options = &CtrlOptions{ - Options: channel.DefaultOptions(), - PeerHeartbeatOptions: channel.DefaultHeartbeatOptions(), - RouterHeartbeatOptions: channel.DefaultHeartbeatOptions(), - } - - if value, found := submap["options"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - options, err := channel.LoadOptions(submap) - if err != nil { - return nil, err - } - - controllerConfig.Ctrl.Options.Options = options - - if val, found := submap["newListener"]; found { - if newListener, ok := val.(string); ok { - if newListener != "" { - if addr, err := transport.ParseAddress(newListener); err == nil { - controllerConfig.Ctrl.Options.NewListener = &addr - - if err := verifyNewListenerInServerCert(controllerConfig, addr); err != nil { - return nil, err - } - - } else { - return nil, fmt.Errorf("error loading newListener for [ctrl/options] (%v)", err) - } - } - } else { - return nil, errors.New("error loading newAddress for [ctrl/options] (must be a string)") - } - } - - if val, found := submap["advertiseAddress"]; found { - if advertiseAddr, ok := val.(string); ok { - if advertiseAddr != "" { - addr, err := transport.ParseAddress(advertiseAddr) - if err != nil { - return nil, errors.Wrapf(err, "error parsing value '%v' for [ctrl/options/advertiseAddress]", advertiseAddr) - } - controllerConfig.Ctrl.Options.AdvertiseAddress = &addr - if controllerConfig.Raft != nil { - controllerConfig.Raft.AdvertiseAddress = addr - } - } - } else { - return nil, errors.New("error loading advertiseAddress for [ctrl/options] (must be a string)") - } - } - - if value, found := submap["routerHeartbeats"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - options, err := channel.LoadHeartbeatOptions(submap) - if err != nil { - return nil, err - } - controllerConfig.Ctrl.Options.RouterHeartbeatOptions = options - } - } - - if value, found := submap["peerHeartbeats"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - options, err := channel.LoadHeartbeatOptions(submap) - if err != nil { - return nil, err - } - controllerConfig.Ctrl.Options.PeerHeartbeatOptions = options - } - } - - if err := controllerConfig.Ctrl.Options.Validate(); err != nil { - return nil, fmt.Errorf("error loading channel options for [ctrl/options] (%v)", err) - } - } - } - if controllerConfig.Raft != nil && controllerConfig.Raft.AdvertiseAddress == nil { - return nil, errors.New("[ctrl/options/advertiseAddress] is required when raft is enabled") - } - } else { - panic("controllerConfig [ctrl] section in unexpected format") - } - } else { - panic("controllerConfig must provide [ctrl]") - } - - controllerConfig.HealthChecks.BoltCheck.Interval = DefaultHealthChecksBoltCheckInterval - controllerConfig.HealthChecks.BoltCheck.Timeout = DefaultHealthChecksBoltCheckTimeout - controllerConfig.HealthChecks.BoltCheck.InitialDelay = DefaultHealthChecksBoltCheckInitialDelay - - if value, found := cfgmap["healthChecks"]; found { - if healthChecksMap, ok := value.(map[interface{}]interface{}); ok { - if value, found := healthChecksMap["boltCheck"]; found { - if boltMap, ok := value.(map[interface{}]interface{}); ok { - if value, found := boltMap["interval"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.HealthChecks.BoltCheck.Interval = val - } else { - return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.interval value '%v", value) - } - } - - if value, found := boltMap["timeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.HealthChecks.BoltCheck.Timeout = val - } else { - return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.timeout value '%v", value) - } - } - - if value, found := boltMap["initialDelay"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - controllerConfig.HealthChecks.BoltCheck.InitialDelay = val - } else { - return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.initialDelay value '%v", value) - } - } - } else { - pfxlog.Logger().Warn("invalid [healthChecks.bolt] stanza") - } - } - } else { - pfxlog.Logger().Warn("invalid [healthChecks] stanza") - } - } - - controllerConfig.CommandRateLimiter.Enabled = true - controllerConfig.CommandRateLimiter.QueueSize = command.DefaultLimiterSize - - if value, found := cfgmap["commandRateLimiter"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if value, found := submap["enabled"]; found { - controllerConfig.CommandRateLimiter.Enabled = strings.EqualFold("true", fmt.Sprintf("%v", value)) - } - - if value, found := submap["maxQueued"]; found { - if intVal, ok := value.(int); ok { - v := int64(intVal) - if v < command.MinLimiterSize { - return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be at least %v", value, command.MinLimiterSize) - } - if v > math.MaxUint32 { - return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be at most %v", value, int64(math.MaxUint32)) - } - controllerConfig.CommandRateLimiter.QueueSize = uint32(v) - } else { - return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be integer value", value) - } - } - } - } - - controllerConfig.TlsHandshakeRateLimiter.SetDefaults() - controllerConfig.TlsHandshakeRateLimiter.Enabled = DefaultTlsHandshakeRateLimiterEnabled - controllerConfig.TlsHandshakeRateLimiter.MaxSize = DefaultTlsHandshakeRateLimiterMaxWindow - controllerConfig.TlsHandshakeRateLimiter.QueueSizeMetric = TlsHandshakeRateLimiterMetricOutstandingCount - controllerConfig.TlsHandshakeRateLimiter.WindowSizeMetric = TlsHandshakeRateLimiterMetricCurrentWindowSize - controllerConfig.TlsHandshakeRateLimiter.WorkTimerMetric = TlsHandshakeRateLimiterMetricWorkTimer - - if value, found := cfgmap["tls"]; found { - if tlsMap, ok := value.(map[interface{}]interface{}); ok { - if value, found := tlsMap["handshakeTimeout"]; found { - if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { - transporttls.SetSharedListenerHandshakeTimeout(val) - } else { - return nil, errors.Wrapf(err, "failed to parse tls.handshakeTimeout value '%v", value) - } - } - if err = loadTlsHandshakeRateLimiterConfig(&controllerConfig.TlsHandshakeRateLimiter, tlsMap); err != nil { - return nil, err - } - } - } - - return controllerConfig, nil -} - -// isSelfSigned checks if the given certificate is self-signed. -func isSelfSigned(cert *x509.Certificate) (bool, error) { - // Check if the Issuer and Subject fields are equal - if cert.Issuer.String() != cert.Subject.String() { - return false, nil - } - - // Attempt to verify the certificate's signature with its own public key - err := cert.CheckSignatureFrom(cert) - if err != nil { - return false, err - } - - return true, nil -} - -func generateDefaultSpiffeId(id identity.Identity) (*url.URL, error) { - chain := id.CaPool().GetChain(id.Cert().Leaf) - - // chain is 0 or 1, no root possible - if len(chain) <= 1 { - return nil, fmt.Errorf("error generating default trust domain from root CA: no root CA detected after chain assembly from the root identity server cert and ca bundle") - } - - candidateRoot := chain[len(chain)-1] - - if candidateRoot == nil { - return nil, fmt.Errorf("encountered nil candidate root ca during default trust domain generation") - } - - if !candidateRoot.IsCA { - return nil, fmt.Errorf("candidate root CA is not flagged with the x509 CA flag") - } - - if selfSigned, _ := isSelfSigned(candidateRoot); !selfSigned { - return nil, errors.New("candidate root CA is not self signed") - } - - rawHash := sha1.Sum(candidateRoot.Raw) - - fingerprint := fmt.Sprintf("%x", rawHash) - idStr := "spiffe://" + fingerprint - - spiffeId, err := url.Parse(idStr) - - if err != nil { - return nil, fmt.Errorf("could not parse generated SPIFFE id [%s] as a URI: %w", idStr, err) - } - - return spiffeId, nil -} - -// GetSpiffeIdFromIdentity will search an Identity for a trust domain encoded as a spiffe:// URI SAN starting -// from the server cert and up its signing chain. Each certificate must contain 0 or 1 spiffe:// URI SAN. The first -// SPIFFE id looking up the chain back to the root CA is returned. If no SPIFFE id is encountered, nil is returned. -// Errors are returned for parsing and processing errors only. -func GetSpiffeIdFromIdentity(id identity.Identity) (*url.URL, error) { - tlsCerts := id.ServerCert() - - spiffeId, err := GetSpiffeIdFromTlsCertChain(tlsCerts) - - if err != nil { - return nil, fmt.Errorf("failed to acquire SPIFFE id from server certs: %w", err) - } - - if spiffeId != nil { - return spiffeId, nil - } - - if len(tlsCerts) > 0 { - chain := id.CaPool().GetChain(tlsCerts[0].Leaf) - - if len(chain) > 0 { - spiffeId, _ = GetSpiffeIdFromCertChain(chain) - } - } - - if spiffeId == nil { - return nil, errors.Errorf("SPIFFE id not found in identity") - } - - return spiffeId, nil -} - -// GetSpiffeIdFromCertChain cycles through a slice of certificates that goes from leaf up CAs. Each certificate -// must contain 0 or 1 spiffe:// URI SAN. The first encountered SPIFFE id looking up the chain back to the root CA is returned. -// If no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. -func GetSpiffeIdFromCertChain(certs []*x509.Certificate) (*url.URL, error) { - var spiffeId *url.URL - for _, cert := range certs { - var err error - spiffeId, err = GetSpiffeIdFromCert(cert) - - if err != nil { - return nil, fmt.Errorf("failed to determine SPIFFE ID from x509 certificate chain: %w", err) - } - - if spiffeId != nil { - return spiffeId, nil - } - } - - return nil, errors.New("failed to determine SPIFFE ID, no spiffe:// URI SANs found in x509 certificate chain") -} - -// GetSpiffeIdFromTlsCertChain will search a tls certificate chain for a trust domain encoded as a spiffe:// URI SAN. -// Each certificate must contain 0 or 1 spiffe:// URI SAN. The first SPIFFE id looking up the chain is returned. If -// no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. -func GetSpiffeIdFromTlsCertChain(tlsCerts []*tls.Certificate) (*url.URL, error) { - for _, tlsCert := range tlsCerts { - for i, rawCert := range tlsCert.Certificate { - cert, err := x509.ParseCertificate(rawCert) - - if err != nil { - return nil, fmt.Errorf("failed to parse TLS cert at index [%d]: %w", i, err) - } - - spiffeId, err := GetSpiffeIdFromCert(cert) - - if err != nil { - return nil, fmt.Errorf("failed to determine SPIFFE ID from TLS cert at index [%d]: %w", i, err) - } - - if spiffeId != nil { - return spiffeId, nil - } - } - } - - return nil, nil -} - -// GetSpiffeIdFromCert will search a x509 certificate for a trust domain encoded as a spiffe:// URI SAN. -// Each certificate must contain 0 or 1 spiffe:// URI SAN. The first SPIFFE id looking up the chain is returned. If -// no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. -func GetSpiffeIdFromCert(cert *x509.Certificate) (*url.URL, error) { - var spiffeId *url.URL - for _, uriSan := range cert.URIs { - if uriSan.Scheme == "spiffe" { - if spiffeId != nil { - return nil, fmt.Errorf("multiple URI SAN spiffe:// ids encountered, must only have one, encountered at least two: [%s] and [%s]", spiffeId.String(), uriSan.String()) - } - spiffeId = uriSan - } - } - - return spiffeId, nil -} - -func loadTlsHandshakeRateLimiterConfig(rateLimitConfig *command.AdaptiveRateLimiterConfig, cfgmap map[interface{}]interface{}) error { - if value, found := cfgmap["rateLimiter"]; found { - if submap, ok := value.(map[interface{}]interface{}); ok { - if err := command.LoadAdaptiveRateLimiterConfig(rateLimitConfig, submap); err != nil { - return err - } - if rateLimitConfig.MaxSize < TlsHandshakeRateLimiterMinSizeValue { - return errors.Errorf("invalid value %v for tls.rateLimiter.maxSize, must be at least %v", - rateLimitConfig.MaxSize, TlsHandshakeRateLimiterMinSizeValue) - } - if rateLimitConfig.MaxSize > TlsHandshakeRateLimiterMaxSizeValue { - return errors.Errorf("invalid value %v for tls.rateLimiter.maxSize, must be at most %v", - rateLimitConfig.MaxSize, TlsHandshakeRateLimiterMaxSizeValue) - } - - if rateLimitConfig.MinSize < TlsHandshakeRateLimiterMinSizeValue { - return errors.Errorf("invalid value %v for tls.rateLimiter.minSize, must be at least %v", - rateLimitConfig.MinSize, TlsHandshakeRateLimiterMinSizeValue) - } - if rateLimitConfig.MinSize > TlsHandshakeRateLimiterMaxSizeValue { - return errors.Errorf("invalid value %v for tls.rateLimiter.minSize, must be at most %v", - rateLimitConfig.MinSize, TlsHandshakeRateLimiterMaxSizeValue) - } - } else { - return errors.Errorf("invalid type for tls.rateLimiter, should be map instead of %T", value) - } - } - - return nil -} - -// verifyNewListenerInServerCert verifies that the hostname (ip/dns) for addr is present as an IP/DNS SAN in the first -// certificate provided in the controller's identity server certificates. This is to avoid scenarios where -// newListener propagated to routers who will never be able to verify the controller's certificates due to SAN issues. -func verifyNewListenerInServerCert(controllerConfig *Config, addr transport.Address) error { - addrSplits := strings.Split(addr.String(), ":") - if len(addrSplits) < 3 { - return errors.New("could not determine newListener's host value, expected at least three segments") - } - - host := addrSplits[1] - - serverCerts := controllerConfig.Id.Identity.ServerCert() - - if len(serverCerts) == 0 { - return errors.New("could not verify newListener value, server certificate for identity contains no certificates") - } - - hostFound := false - for _, serverCert := range serverCerts { - for _, dnsName := range serverCert.Leaf.DNSNames { - if dnsName == host { - hostFound = true - break - } - } - - if hostFound { - break - } - - if !hostFound { - for _, ipAddresses := range serverCert.Leaf.IPAddresses { - if host == ipAddresses.String() { - hostFound = true - break - } - } - } - - if hostFound { - break - } - } - - if !hostFound { - return fmt.Errorf("could not find newListener [%s] host value [%s] in first certificate for controller identity", addr.String(), host) - } - - return nil -} - -type CertValidatingIdentity struct { - identity.Identity -} - -func (self *CertValidatingIdentity) ClientTLSConfig() *tls.Config { - cfg := self.Identity.ClientTLSConfig() - cfg.VerifyConnection = self.VerifyConnection - return cfg -} - -func (self *CertValidatingIdentity) ServerTLSConfig() *tls.Config { - cfg := self.Identity.ServerTLSConfig() - cfg.VerifyConnection = self.VerifyConnection - return cfg -} - -func (self *CertValidatingIdentity) VerifyConnection(state tls.ConnectionState) error { - if len(state.PeerCertificates) == 0 { - return errors.New("no peer certificates provided") - } - log := pfxlog.Logger() - for _, cert := range state.PeerCertificates { - log.Infof("cert provided: CN: %v IsCA: %v", cert.Subject.CommonName, cert.IsCA) - } - - options := x509.VerifyOptions{ - Roots: self.Identity.CA(), - Intermediates: x509.NewCertPool(), - } - - for _, cert := range state.PeerCertificates[1:] { - options.Intermediates.AddCert(cert) - } - - result, err := state.PeerCertificates[0].Verify(options) - - if err != nil { - pfxlog.Logger().WithError(err).Error("got error validating cert") - return err - } - - log.Infof("got result: %v", result) - return nil -} diff --git a/controller/config/config.go b/controller/config/config.go index 1a4549d2f..274cef4d8 100644 --- a/controller/config/config.go +++ b/controller/config/config.go @@ -19,546 +19,922 @@ package config import ( "bytes" "crypto/sha1" + "crypto/tls" "crypto/x509" - "encoding/pem" + "encoding/json" "fmt" + "github.com/hashicorp/go-hclog" "github.com/michaelquigley/pfxlog" - nfpem "github.com/openziti/foundation/v2/pem" + "github.com/openziti/channel/v2" "github.com/openziti/identity" + "github.com/openziti/storage/boltz" + "github.com/openziti/transport/v2" + transporttls "github.com/openziti/transport/v2/tls" + "github.com/openziti/ziti/common/config" + "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/common/pb/mgmt_pb" "github.com/openziti/ziti/controller/command" + "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/router/xgress" "github.com/pkg/errors" - "net" + "gopkg.in/yaml.v2" + "math" "net/url" "os" - "reflect" - "strconv" "strings" - "sync" "time" ) const ( - DefaultEdgeApiActivityUpdateBatchSize = 250 - DefaultEdgeAPIActivityUpdateInterval = 90 * time.Second - MaxEdgeAPIActivityUpdateBatchSize = 10000 - MinEdgeAPIActivityUpdateBatchSize = 1 - MaxEdgeAPIActivityUpdateInterval = 10 * time.Minute - MinEdgeAPIActivityUpdateInterval = time.Millisecond + DefaultProfileMemoryInterval = 15 * time.Second + DefaultHealthChecksBoltCheckInterval = 30 * time.Second + DefaultHealthChecksBoltCheckTimeout = 20 * time.Second + DefaultHealthChecksBoltCheckInitialDelay = 30 * time.Second - DefaultEdgeSessionTimeout = 30 * time.Minute - MinEdgeSessionTimeout = 1 * time.Minute + DefaultRaftCommandHandlerMaxQueueSize = 1000 - MinEdgeEnrollmentDuration = 5 * time.Minute - DefaultEdgeEnrollmentDuration = 180 * time.Minute + // DefaultTlsHandshakeRateLimiterEnabled is whether the tls handshake rate limiter is enabled by default + DefaultTlsHandshakeRateLimiterEnabled = false - DefaultHttpIdleTimeout = 5000 * time.Millisecond - DefaultHttpReadTimeout = 5000 * time.Millisecond - DefaultHttpReadHeaderTimeout = 5000 * time.Millisecond - DefaultHttpWriteTimeout = 100000 * time.Millisecond + // TlsHandshakeRateLimiterMinSizeValue is the minimum size that can be configured for the tls handshake rate limiter + // window range + TlsHandshakeRateLimiterMinSizeValue = 5 - DefaultTotpDomain = "openziti.io" + // TlsHandshakeRateLimiterMaxSizeValue is the maximum size that can be configured for the tls handshake rate limiter + // window range + TlsHandshakeRateLimiterMaxSizeValue = 10000 - DefaultAuthRateLimiterEnabled = true - DefaultAuthRateLimiterMaxSize = 250 - DefaultAuthRateLimiterMinSize = 5 + // TlsHandshakeRateLimiterMetricOutstandingCount is the name of the metric tracking how many tasks are in process + TlsHandshakeRateLimiterMetricOutstandingCount = "tls_handshake_limiter.in_process" - AuthRateLimiterMinSizeValue = 5 - AuthRateLimiterMaxSizeValue = 1000 + // TlsHandshakeRateLimiterMetricCurrentWindowSize is the name of the metric tracking the current window size + TlsHandshakeRateLimiterMetricCurrentWindowSize = "tls_handshake_limiter.window_size" + + // TlsHandshakeRateLimiterMetricWorkTimer is the name of the metric tracking how long successful tasks are taking to complete + TlsHandshakeRateLimiterMetricWorkTimer = "tls_handshake_limiter.work_timer" + + // DefaultTlsHandshakeRateLimiterMaxWindow is the default max size for the tls handshake rate limiter + DefaultTlsHandshakeRateLimiterMaxWindow = 1000 ) -type Enrollment struct { - SigningCert identity.Identity - SigningCertConfig identity.Config - SigningCertCaPem []byte - EdgeIdentity EnrollmentOption - EdgeRouter EnrollmentOption +type Config struct { + Id *identity.TokenId + SpiffeIdTrustDomain *url.URL + AdditionalTrustDomains []*url.URL + + Raft *RaftConfig + Network *NetworkConfig + Edge *EdgeConfig + Db boltz.Db + Trace struct { + Handler *channel.TraceHandler + } + Profile struct { + Memory struct { + Path string + Interval time.Duration + } + CPU struct { + Path string + } + } + Ctrl struct { + Listener transport.Address + Options *CtrlOptions + } + HealthChecks struct { + BoltCheck struct { + Interval time.Duration + Timeout time.Duration + InitialDelay time.Duration + } + } + CommandRateLimiter command.RateLimiterConfig + TlsHandshakeRateLimiter command.AdaptiveRateLimiterConfig + Src map[interface{}]interface{} } -type EnrollmentOption struct { - Duration time.Duration +func (self *Config) ToJson() (string, error) { + jsonMap, err := config.ToJsonCompatibleMap(self.Src) + if err != nil { + return "", err + } + b, err := json.Marshal(jsonMap) + return string(b), err } -type Totp struct { - Hostname string +// CtrlOptions extends channel.Options to include support for additional, non-channel specific options +// (e.g. NewListener) +type CtrlOptions struct { + *channel.Options + NewListener *transport.Address + AdvertiseAddress *transport.Address + RouterHeartbeatOptions *channel.HeartbeatOptions + PeerHeartbeatOptions *channel.HeartbeatOptions } -type Api struct { - SessionTimeout time.Duration - ActivityUpdateBatchSize int - ActivityUpdateInterval time.Duration - - Listener string - Address string - IdentityCaPem []byte - HttpTimeouts HttpTimeouts +func (config *Config) Configure(sub config.Subconfig) error { + return sub.LoadConfig(config.Src) } -type Config struct { - Enabled bool - Api Api - Enrollment Enrollment - - caPems *bytes.Buffer - caPemsOnce sync.Once - Totp Totp - AuthRateLimiter command.AdaptiveRateLimiterConfig - caCerts []*x509.Certificate -} +func LoadConfig(path string) (*Config, error) { + cfgBytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } -type HttpTimeouts struct { - ReadTimeoutDuration time.Duration - ReadHeaderTimeoutDuration time.Duration - WriteTimeoutDuration time.Duration - IdleTimeoutsDuration time.Duration -} + cfgmap := make(map[interface{}]interface{}) + if err = yaml.NewDecoder(bytes.NewReader(cfgBytes)).Decode(&cfgmap); err != nil { + return nil, err + } + config.InjectEnv(cfgmap) + if value, found := cfgmap["v"]; found { + if value.(int) != 3 { + panic("config version mismatch: see docs for information on config updates") + } + } else { + panic("no config version: see docs for information on config") + } + + var identityConfig *identity.Config -func DefaultHttpTimeouts() *HttpTimeouts { - httpTimeouts := &HttpTimeouts{ - ReadTimeoutDuration: DefaultHttpReadTimeout, - ReadHeaderTimeoutDuration: DefaultHttpReadHeaderTimeout, - WriteTimeoutDuration: DefaultHttpWriteTimeout, - IdleTimeoutsDuration: DefaultHttpIdleTimeout, + if value, found := cfgmap["identity"]; found { + subMap := value.(map[interface{}]interface{}) + identityConfig, err = identity.NewConfigFromMapWithPathContext(subMap, "identity") + + if err != nil { + return nil, fmt.Errorf("could not parse root identity: %v", err) + } + + if identityConfig.ServerCert == "" && identityConfig.ServerKey == "" { + identityConfig.ServerCert = identityConfig.Cert + identityConfig.ServerKey = identityConfig.Key + } + } else { + return nil, fmt.Errorf("identity section not found") } - return httpTimeouts -} -func NewConfig() *Config { - return &Config{ - Enabled: false, - caPems: bytes.NewBuffer(nil), + controllerConfig := &Config{ + Network: DefaultNetworkConfig(), + Src: cfgmap, } -} -func (c *Config) SessionTimeoutDuration() time.Duration { - return c.Api.SessionTimeout -} + if id, err := identity.LoadIdentity(*identityConfig); err != nil { + return nil, fmt.Errorf("unable to load identity (%s)", err) + } else { + controllerConfig.Id = identity.NewIdentity(id) + + if err := controllerConfig.Id.WatchFiles(); err != nil { + pfxlog.Logger().Warn("could not enable file watching on identity: %w", err) + } + } -func (c *Config) CaPems() []byte { - c.caPemsOnce.Do(func() { - c.RefreshCas() - }) + if value, found := cfgmap["network"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if options, err := LoadNetworkConfig(submap); err == nil { + controllerConfig.Network = options + } else { + return nil, fmt.Errorf("invalid 'network' stanza (%s)", err) + } + } else { + pfxlog.Logger().Warn("invalid or empty 'network' stanza") + } + } - return c.caPems.Bytes() -} + if value, found := cfgmap["raft"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + controllerConfig.Raft = &RaftConfig{} + controllerConfig.Raft.CommandHandlerOptions.MaxQueueSize = DefaultRaftCommandHandlerMaxQueueSize -func (c *Config) CaCerts() []*x509.Certificate { - c.caPemsOnce.Do(func() { - c.RefreshCas() - }) + if value, found := submap["dataDir"]; found { + controllerConfig.Raft.DataDir = value.(string) + } else { + return nil, errors.Errorf("raft dataDir configuration missing") + } + if value, found := submap["minClusterSize"]; found { + controllerConfig.Raft.MinClusterSize = uint32(value.(int)) + } + if value, found := submap["bootstrapMembers"]; found { + if lst, ok := value.([]interface{}); ok { + for idx, val := range lst { + if member, ok := val.(string); ok { + controllerConfig.Raft.BootstrapMembers = append(controllerConfig.Raft.BootstrapMembers, member) + } else { + return nil, errors.Errorf("invalid bootstrapMembers value '%v'at index %v, should be array", idx, val) + } + } + } else { + return nil, errors.New("invalid bootstrapMembers value, should be array") + } + } - return c.caCerts -} + if value, found := submap["snapshotInterval"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.Raft.SnapshotInterval = &val + } else { + return nil, errors.Wrapf(err, "failed to parse raft.snapshotInterval value '%v", value) + } + } -// AddCaPems adds a byte array of certificates to the current buffered list of CAs. The certificates -// should be in PEM format separated by new lines. RefreshCas should be called after all -// calls to AddCaPems are completed. -func (c *Config) AddCaPems(caPems []byte) { - c.caPems.WriteString("\n") - c.caPems.Write(caPems) -} + if value, found := submap["commitTimeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.Raft.CommitTimeout = &val + } else { + return nil, errors.Wrapf(err, "failed to parse raft.commitTimeout value '%v", value) + } + } -func (c *Config) RefreshCas() { - c.caPems = CalculateCaPems(c.caPems) - c.caCerts = nfpem.PemBytesToCertificates(c.caPems.Bytes()) -} + if value, found := submap["electionTimeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.Raft.ElectionTimeout = &val + } else { + return nil, errors.Wrapf(err, "failed to parse raft.electionTimeout value '%v", value) + } + } -func (c *Config) loadTotpSection(edgeConfigMap map[any]any) error { - c.Totp = Totp{} - c.Totp.Hostname = DefaultTotpDomain + if value, found := submap["heartbeatTimeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.Raft.HeartbeatTimeout = &val + } else { + return nil, errors.Wrapf(err, "failed to parse raft.heartbeatTimeout value '%v", value) + } + } - if value, found := edgeConfigMap["totp"]; found { - if value == nil { - return nil + if value, found := submap["leaderLeaseTimeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.Raft.LeaderLeaseTimeout = &val + } else { + return nil, errors.Wrapf(err, "failed to parse raft.leaderLeaseTimeout value '%v", value) + } + } + + if value, found := submap["snapshotThreshold"]; found { + val := uint32(value.(int)) + controllerConfig.Raft.SnapshotThreshold = &val + } + + if value, found := submap["maxAppendEntries"]; found { + val := uint32(value.(int)) + controllerConfig.Raft.MaxAppendEntries = &val + } + + if value, found := submap["trailingLogs"]; found { + val := uint32(value.(int)) + controllerConfig.Raft.TrailingLogs = &val + } + + if value, found := submap["logLevel"]; found { + val := fmt.Sprintf("%v", value) + if hclog.LevelFromString(val) == hclog.NoLevel { + return nil, errors.Errorf("invalid value for raft.logLevel [%v]", val) + } + controllerConfig.Raft.LogLevel = &val + } + + if value, found := submap["logFile"]; found { + val := fmt.Sprintf("%v", value) + options := *hclog.DefaultOptions + f, err := os.Create(val) + if err != nil { + return nil, errors.Wrapf(err, "unable to open raft log file [%v]", val) + } + options.Output = f + if controllerConfig.Raft.LogLevel != nil { + options.Level = hclog.LevelFromString(*controllerConfig.Raft.LogLevel) + } + controllerConfig.Raft.Logger = hclog.New(&options) + } + + if value, found := cfgmap["commandHandler"]; found { + if chSubMap, ok := value.(map[interface{}]interface{}); ok { + if value, found := chSubMap["maxQueueSize"]; found { + controllerConfig.Raft.CommandHandlerOptions.MaxQueueSize = uint16(value.(int)) + } + } else { + return nil, errors.New("invalid commandHandler value, should be map") + } + } + } else { + return nil, errors.Errorf("invalid raft configuration") + } + } else if value, found := cfgmap["db"]; found { + str, err := db.Open(value.(string)) + if err != nil { + return nil, err } + controllerConfig.Db = str + } else { + panic("controllerConfig must provide [db] or [raft]") + } - totpMap := value.(map[interface{}]interface{}) + //SPIFFE Trust Domain + var spiffeId *url.URL + if controllerConfig.Raft != nil { + //HA setup, SPIFFE ID must come from certs + var err error + spiffeId, err = GetSpiffeIdFromIdentity(controllerConfig.Id.Identity) + if err != nil { + panic("error determining a trust domain from a SPIFFE id in the root identity for HA configuration, must have a spiffe:// URI SANs in the server certificate or along the signing CAs chain: " + err.Error()) + } + + if spiffeId == nil { + panic("unable to determine a trust domain from a SPIFFE id in the root identity for HA configuration, must have a spiffe:// URI SANs in the server certificate or along the signing CAs chain") + } + } else { + // Non-HA/legacy system, prefer SPIFFE id from certs, but fall back to configuration if necessary + spiffeId, _ = GetSpiffeIdFromIdentity(controllerConfig.Id.Identity) - if totpMap != nil { - if hostnameVal, found := totpMap["hostname"]; found { + if spiffeId == nil { + //for non HA setups allow the trust domain to come from the configuration root value `trustDomain` + if value, found := cfgmap["trustDomain"]; found { + trustDomain, ok := value.(string) - if hostnameVal == nil { - return nil + if !ok { + panic(fmt.Sprintf("could not parse [trustDomain], expected a string got [%T]", value)) } - if hostname, ok := hostnameVal.(string); ok { - testUrl := "https://" + hostname - parsedUrl, err := url.Parse(testUrl) + if trustDomain != "" { + if !strings.HasPrefix("spiffe://", trustDomain) { + trustDomain = "spiffe://" + trustDomain + } + + spiffeId, err = url.Parse(trustDomain) if err != nil { - return fmt.Errorf("could not parse URL: %w", err) + panic("could not parse [trustDomain] when used in a SPIFFE id URI [" + trustDomain + "], please make sure it is a valid URI hostname: " + err.Error()) } - if parsedUrl.Hostname() != hostname { - return fmt.Errorf("invalid hostname in [edge.totp.hostname]: %s", hostname) + if spiffeId == nil { + panic("could not parse [trustDomain] when used in a SPIFFE id URI [" + trustDomain + "]: spiffeId is nil and no error returned") } - c.Totp.Hostname = hostname - } else { - return fmt.Errorf("[edge.totp.hostname] must be a string") + if spiffeId.Scheme != "spiffe" { + panic("[trustDomain] does not have a spiffe scheme (spiffe://) has: " + spiffeId.Scheme) + } } } } + + //default a generated trust domain and spiffe id from the sha1 of the root ca + if spiffeId == nil { + spiffeId, err = generateDefaultSpiffeId(controllerConfig.Id.Identity) + + if err != nil { + panic("could not generate default trust domain: " + err.Error()) + } + + pfxlog.Logger().Warnf("this environment is using a default generated trust domain [%s], it is recommended that a trust domain is specified in configuration via URI SANs or the 'trustDomain' field", spiffeId.String()) + pfxlog.Logger().Warnf("this environment is using a default generated trust domain [%s], it is recommended that if network components have enrolled that the generated trust domain be added to the configuration field 'additionalTrustDomains' array when configuring a explicit trust domain", spiffeId.String()) + } } - return nil -} + if spiffeId == nil { + panic("unable to determine trust domain from SPIFFE id (spiffe:// URI SANs in server cert or signing CAs) or from configuration [trustDomain], controllers must have a trust domain") + } -func (c *Config) loadApiSection(edgeConfigMap map[interface{}]interface{}) error { - c.Api = Api{} - c.Api.HttpTimeouts = *DefaultHttpTimeouts() - var err error + if spiffeId.Hostname() == "" { + panic("unable to determine trust domain from SPIFFE id: hostname was empty") + } - c.Api.ActivityUpdateBatchSize = DefaultEdgeApiActivityUpdateBatchSize - c.Api.ActivityUpdateInterval = DefaultEdgeAPIActivityUpdateInterval + //only preserve trust domain + spiffeId.Path = "" + controllerConfig.SpiffeIdTrustDomain = spiffeId - if value, found := edgeConfigMap["api"]; found { - apiSubMap := value.(map[interface{}]interface{}) + if value, found := cfgmap["additionalTrustDomains"]; found { + if valArr, ok := value.([]any); ok { + var trustDomains []*url.URL + for _, trustDomain := range valArr { + if strTrustDomain, ok := trustDomain.(string); ok { - if val, ok := apiSubMap["address"]; ok { - if c.Api.Address, ok = val.(string); !ok { - return errors.Errorf("invalid type %t for [edge.api.address], must be string", val) - } + if !strings.HasPrefix("spiffe://", strTrustDomain) { + strTrustDomain = "spiffe://" + strTrustDomain + } - if c.Api.Address == "" { - return errors.Errorf("invalid type %t for [edge.api.address], must not be an empty string", val) - } + spiffeId, err = url.Parse(strTrustDomain) - if err := validateHostPortString(c.Api.Address); err != nil { - return errors.Errorf("invalid value %s for [edge.api.address]: %v", c.Api.Address, err) + if err != nil { + panic(fmt.Sprintf("invalid entry in 'additionalTrustDomains', could not be parsed as a URI: %v", trustDomain)) + } + //only preserve trust domain + spiffeId.Path = "" + + trustDomains = append(trustDomains, spiffeId) + } else { + panic(fmt.Sprintf("invalid entry in 'additionalTrustDomains' expected a string: %v", trustDomain)) + } } - } else { - return errors.New("required value [edge.api.address] is required") + + controllerConfig.AdditionalTrustDomains = trustDomains } + } - var durationValue = 0 * time.Second - if value, found := apiSubMap["sessionTimeout"]; found { - strValue := value.(string) - durationValue, err = time.ParseDuration(strValue) - if err != nil { - return errors.Errorf("error parsing [edge.api.sessionTimeout], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) + if value, found := cfgmap["trace"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["path"]; found { + handler, err := channel.NewTraceHandler(value.(string), controllerConfig.Id.Token) + if err != nil { + return nil, err + } + handler.AddDecoder(&channel.Decoder{}) + handler.AddDecoder(&ctrl_pb.Decoder{}) + handler.AddDecoder(&xgress.Decoder{}) + handler.AddDecoder(&mgmt_pb.Decoder{}) + controllerConfig.Trace.Handler = handler } } + } - if durationValue < MinEdgeSessionTimeout { - durationValue = DefaultEdgeSessionTimeout - pfxlog.Logger().Warnf("[edge.api.sessionTimeout] defaulted to %v", durationValue) + if value, found := cfgmap["profile"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["memory"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["path"]; found { + controllerConfig.Profile.Memory.Path = value.(string) + } + if value, found := submap["intervalMs"]; found { + controllerConfig.Profile.Memory.Interval = time.Duration(value.(int)) * time.Millisecond + } else { + controllerConfig.Profile.Memory.Interval = DefaultProfileMemoryInterval + } + } + } + if value, found := submap["cpu"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["path"]; found { + controllerConfig.Profile.CPU.Path = value.(string) + } + } + } } + } + + if value, found := cfgmap["ctrl"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["listener"]; found { + listener, err := transport.ParseAddress(value.(string)) + if err != nil { + return nil, err + } + controllerConfig.Ctrl.Listener = listener + } else { + panic("controllerConfig must provide [ctrl/listener]") + } + + controllerConfig.Ctrl.Options = &CtrlOptions{ + Options: channel.DefaultOptions(), + PeerHeartbeatOptions: channel.DefaultHeartbeatOptions(), + RouterHeartbeatOptions: channel.DefaultHeartbeatOptions(), + } + + if value, found := submap["options"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + options, err := channel.LoadOptions(submap) + if err != nil { + return nil, err + } + + controllerConfig.Ctrl.Options.Options = options + + if val, found := submap["newListener"]; found { + if newListener, ok := val.(string); ok { + if newListener != "" { + if addr, err := transport.ParseAddress(newListener); err == nil { + controllerConfig.Ctrl.Options.NewListener = &addr + + if err := verifyNewListenerInServerCert(controllerConfig, addr); err != nil { + return nil, err + } + + } else { + return nil, fmt.Errorf("error loading newListener for [ctrl/options] (%v)", err) + } + } + } else { + return nil, errors.New("error loading newAddress for [ctrl/options] (must be a string)") + } + } - c.Api.SessionTimeout = durationValue + if val, found := submap["advertiseAddress"]; found { + if advertiseAddr, ok := val.(string); ok { + if advertiseAddr != "" { + addr, err := transport.ParseAddress(advertiseAddr) + if err != nil { + return nil, errors.Wrapf(err, "error parsing value '%v' for [ctrl/options/advertiseAddress]", advertiseAddr) + } + controllerConfig.Ctrl.Options.AdvertiseAddress = &addr + if controllerConfig.Raft != nil { + controllerConfig.Raft.AdvertiseAddress = addr + } + } + } else { + return nil, errors.New("error loading advertiseAddress for [ctrl/options] (must be a string)") + } + } - if val, ok := apiSubMap["activityUpdateBatchSize"]; ok { - if c.Api.ActivityUpdateBatchSize, ok = val.(int); !ok { - return errors.Errorf("invalid type %v for apiSessions.activityUpdateBatchSize, must be int", reflect.TypeOf(val)) + if value, found := submap["routerHeartbeats"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + options, err := channel.LoadHeartbeatOptions(submap) + if err != nil { + return nil, err + } + controllerConfig.Ctrl.Options.RouterHeartbeatOptions = options + } + } + + if value, found := submap["peerHeartbeats"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + options, err := channel.LoadHeartbeatOptions(submap) + if err != nil { + return nil, err + } + controllerConfig.Ctrl.Options.PeerHeartbeatOptions = options + } + } + + if err := controllerConfig.Ctrl.Options.Validate(); err != nil { + return nil, fmt.Errorf("error loading channel options for [ctrl/options] (%v)", err) + } + } + } + if controllerConfig.Raft != nil && controllerConfig.Raft.AdvertiseAddress == nil { + return nil, errors.New("[ctrl/options/advertiseAddress] is required when raft is enabled") } + } else { + panic("controllerConfig [ctrl] section in unexpected format") } + } else { + panic("controllerConfig must provide [ctrl]") + } - if val, ok := apiSubMap["activityUpdateInterval"]; ok { - if strVal, ok := val.(string); !ok { - return errors.Errorf("invalid type %v for apiSessions.activityUpdateInterval, must be string duration", reflect.TypeOf(val)) - } else { - if c.Api.ActivityUpdateInterval, err = time.ParseDuration(strVal); err != nil { - return errors.Wrapf(err, "invalid value %v for apiSessions.activityUpdateInterval, must be string duration", val) + controllerConfig.HealthChecks.BoltCheck.Interval = DefaultHealthChecksBoltCheckInterval + controllerConfig.HealthChecks.BoltCheck.Timeout = DefaultHealthChecksBoltCheckTimeout + controllerConfig.HealthChecks.BoltCheck.InitialDelay = DefaultHealthChecksBoltCheckInitialDelay + + if value, found := cfgmap["healthChecks"]; found { + if healthChecksMap, ok := value.(map[interface{}]interface{}); ok { + if value, found := healthChecksMap["boltCheck"]; found { + if boltMap, ok := value.(map[interface{}]interface{}); ok { + if value, found := boltMap["interval"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.HealthChecks.BoltCheck.Interval = val + } else { + return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.interval value '%v", value) + } + } + + if value, found := boltMap["timeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.HealthChecks.BoltCheck.Timeout = val + } else { + return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.timeout value '%v", value) + } + } + + if value, found := boltMap["initialDelay"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + controllerConfig.HealthChecks.BoltCheck.InitialDelay = val + } else { + return nil, errors.Wrapf(err, "failed to parse healthChecks.bolt.initialDelay value '%v", value) + } + } + } else { + pfxlog.Logger().Warn("invalid [healthChecks.bolt] stanza") } } + } else { + pfxlog.Logger().Warn("invalid [healthChecks] stanza") } + } - if c.Api.ActivityUpdateBatchSize < MinEdgeAPIActivityUpdateBatchSize || c.Api.ActivityUpdateBatchSize > MaxEdgeAPIActivityUpdateBatchSize { - return errors.Errorf("invalid value %v for apiSessions.activityUpdateBatchSize, must be between %v and %v", c.Api.ActivityUpdateBatchSize, MinEdgeAPIActivityUpdateBatchSize, MaxEdgeAPIActivityUpdateBatchSize) - } + controllerConfig.CommandRateLimiter.Enabled = true + controllerConfig.CommandRateLimiter.QueueSize = command.DefaultLimiterSize - if c.Api.ActivityUpdateInterval < MinEdgeAPIActivityUpdateInterval || c.Api.ActivityUpdateInterval > MaxEdgeAPIActivityUpdateInterval { - return errors.Errorf("invalid value %v for apiSessions.activityUpdateInterval, must be between %vms and %vm", c.Api.ActivityUpdateInterval.String(), MinEdgeAPIActivityUpdateInterval.Milliseconds(), MaxEdgeAPIActivityUpdateInterval.Minutes()) + if value, found := cfgmap["commandRateLimiter"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if value, found := submap["enabled"]; found { + controllerConfig.CommandRateLimiter.Enabled = strings.EqualFold("true", fmt.Sprintf("%v", value)) + } + + if value, found := submap["maxQueued"]; found { + if intVal, ok := value.(int); ok { + v := int64(intVal) + if v < command.MinLimiterSize { + return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be at least %v", value, command.MinLimiterSize) + } + if v > math.MaxUint32 { + return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be at most %v", value, int64(math.MaxUint32)) + } + controllerConfig.CommandRateLimiter.QueueSize = uint32(v) + } else { + return nil, errors.Errorf("invalid value %v for commandRateLimiter, must be integer value", value) + } + } } + } - return nil + controllerConfig.TlsHandshakeRateLimiter.SetDefaults() + controllerConfig.TlsHandshakeRateLimiter.Enabled = DefaultTlsHandshakeRateLimiterEnabled + controllerConfig.TlsHandshakeRateLimiter.MaxSize = DefaultTlsHandshakeRateLimiterMaxWindow + controllerConfig.TlsHandshakeRateLimiter.QueueSizeMetric = TlsHandshakeRateLimiterMetricOutstandingCount + controllerConfig.TlsHandshakeRateLimiter.WindowSizeMetric = TlsHandshakeRateLimiterMetricCurrentWindowSize + controllerConfig.TlsHandshakeRateLimiter.WorkTimerMetric = TlsHandshakeRateLimiterMetricWorkTimer + + if value, found := cfgmap["tls"]; found { + if tlsMap, ok := value.(map[interface{}]interface{}); ok { + if value, found := tlsMap["handshakeTimeout"]; found { + if val, err := time.ParseDuration(fmt.Sprintf("%v", value)); err == nil { + transporttls.SetSharedListenerHandshakeTimeout(val) + } else { + return nil, errors.Wrapf(err, "failed to parse tls.handshakeTimeout value '%v", value) + } + } + if err = loadTlsHandshakeRateLimiterConfig(&controllerConfig.TlsHandshakeRateLimiter, tlsMap); err != nil { + return nil, err + } + } + } - } else { - return errors.New("required configuration section [edge.api] missing") + edgeConfig, err := LoadEdgeConfigFromMap(cfgmap) + if err != nil { + return nil, err } -} + controllerConfig.Edge = edgeConfig -func validateHostPortString(address string) error { - address = strings.TrimSpace(address) + return controllerConfig, nil +} - if address == "" { - return errors.New("must not be an empty string or unspecified") +// isSelfSigned checks if the given certificate is self-signed. +func isSelfSigned(cert *x509.Certificate) (bool, error) { + // Check if the Issuer and Subject fields are equal + if cert.Issuer.String() != cert.Subject.String() { + return false, nil } - host, port, err := net.SplitHostPort(address) - + // Attempt to verify the certificate's signature with its own public key + err := cert.CheckSignatureFrom(cert) if err != nil { - return errors.Errorf("could not split host and port: %v", err) + return false, err } - if host == "" { - return errors.New("host must be specified") + return true, nil +} + +func generateDefaultSpiffeId(id identity.Identity) (*url.URL, error) { + chain := id.CaPool().GetChain(id.Cert().Leaf) + + // chain is 0 or 1, no root possible + if len(chain) <= 1 { + return nil, fmt.Errorf("error generating default trust domain from root CA: no root CA detected after chain assembly from the root identity server cert and ca bundle") } - if port == "" { - return errors.New("port must be specified") + candidateRoot := chain[len(chain)-1] + + if candidateRoot == nil { + return nil, fmt.Errorf("encountered nil candidate root ca during default trust domain generation") } - if port, err := strconv.ParseInt(port, 10, 32); err != nil { - return errors.New("invalid port, must be a integer") - } else if port < 1 || port > 65535 { - return errors.New("invalid port, must 1-65535") + if !candidateRoot.IsCA { + return nil, fmt.Errorf("candidate root CA is not flagged with the x509 CA flag") } - return nil -} + if selfSigned, _ := isSelfSigned(candidateRoot); !selfSigned { + return nil, errors.New("candidate root CA is not self signed") + } -func (c *Config) loadEnrollmentSection(edgeConfigMap map[interface{}]interface{}) error { - c.Enrollment = Enrollment{} - var err error + rawHash := sha1.Sum(candidateRoot.Raw) - if value, found := edgeConfigMap["enrollment"]; found { - enrollmentSubMap := value.(map[interface{}]interface{}) + fingerprint := fmt.Sprintf("%x", rawHash) + idStr := "spiffe://" + fingerprint - if value, found := enrollmentSubMap["signingCert"]; found { - signingCertSubMap := value.(map[interface{}]interface{}) - c.Enrollment.SigningCertConfig = identity.Config{} + spiffeId, err := url.Parse(idStr) - if value, found := signingCertSubMap["cert"]; found { - c.Enrollment.SigningCertConfig.Cert = value.(string) - certPem, err := os.ReadFile(c.Enrollment.SigningCertConfig.Cert) - if err != nil { - pfxlog.Logger().WithError(err).Panic("unable to read [edge.enrollment.cert]") - } - //The signer is a valid trust anchor - _, _ = c.caPems.WriteString("\n") - _, _ = c.caPems.Write(certPem) + if err != nil { + return nil, fmt.Errorf("could not parse generated SPIFFE id [%s] as a URI: %w", idStr, err) + } - } else { - return fmt.Errorf("required configuration value [edge.enrollment.cert] is missing") - } + return spiffeId, nil +} - if value, found := signingCertSubMap["key"]; found { - c.Enrollment.SigningCertConfig.Key = value.(string) - } else { - return fmt.Errorf("required configuration value [edge.enrollment.key] is missing") - } +// GetSpiffeIdFromIdentity will search an Identity for a trust domain encoded as a spiffe:// URI SAN starting +// from the server cert and up its signing chain. Each certificate must contain 0 or 1 spiffe:// URI SAN. The first +// SPIFFE id looking up the chain back to the root CA is returned. If no SPIFFE id is encountered, nil is returned. +// Errors are returned for parsing and processing errors only. +func GetSpiffeIdFromIdentity(id identity.Identity) (*url.URL, error) { + tlsCerts := id.ServerCert() - if value, found := signingCertSubMap["ca"]; found { - c.Enrollment.SigningCertConfig.CA = value.(string) + spiffeId, err := GetSpiffeIdFromTlsCertChain(tlsCerts) - if c.Enrollment.SigningCertCaPem, err = os.ReadFile(c.Enrollment.SigningCertConfig.CA); err != nil { - return fmt.Errorf("could not read file CA file from [edge.enrollment.signingCert.ca]") - } + if err != nil { + return nil, fmt.Errorf("failed to acquire SPIFFE id from server certs: %w", err) + } - _, _ = c.caPems.WriteString("\n") - _, _ = c.caPems.Write(c.Enrollment.SigningCertCaPem) - } //not an error if the signing certificate's CA is already represented in the root [identity.ca] + if spiffeId != nil { + return spiffeId, nil + } - if c.Enrollment.SigningCert, err = identity.LoadIdentity(c.Enrollment.SigningCertConfig); err != nil { - return fmt.Errorf("error loading [edge.enrollment.signingCert]: %s", err) - } else { - if err := c.Enrollment.SigningCert.WatchFiles(); err != nil { - pfxlog.Logger().Warn("could not enable file watching on enrollment signing cert: %w", err) - } - } + if len(tlsCerts) > 0 { + chain := id.CaPool().GetChain(tlsCerts[0].Leaf) - } else { - return errors.New("required configuration section [edge.enrollment.signingCert] missing") + if len(chain) > 0 { + spiffeId, _ = GetSpiffeIdFromCertChain(chain) } + } - if value, found := enrollmentSubMap["edgeIdentity"]; found { - edgeIdentitySubMap := value.(map[interface{}]interface{}) - - edgeIdentityDuration := 0 * time.Second - if value, found := edgeIdentitySubMap["duration"]; found { - strValue := value.(string) - var err error - edgeIdentityDuration, err = time.ParseDuration(strValue) - - if err != nil { - return errors.Errorf("error parsing [edge.enrollment.edgeIdentity.duration], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) - } - } + if spiffeId == nil { + return nil, errors.Errorf("SPIFFE id not found in identity") + } - if edgeIdentityDuration < MinEdgeEnrollmentDuration { - edgeIdentityDuration = DefaultEdgeEnrollmentDuration - } + return spiffeId, nil +} - c.Enrollment.EdgeIdentity = EnrollmentOption{Duration: edgeIdentityDuration} +// GetSpiffeIdFromCertChain cycles through a slice of certificates that goes from leaf up CAs. Each certificate +// must contain 0 or 1 spiffe:// URI SAN. The first encountered SPIFFE id looking up the chain back to the root CA is returned. +// If no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. +func GetSpiffeIdFromCertChain(certs []*x509.Certificate) (*url.URL, error) { + var spiffeId *url.URL + for _, cert := range certs { + var err error + spiffeId, err = GetSpiffeIdFromCert(cert) + + if err != nil { + return nil, fmt.Errorf("failed to determine SPIFFE ID from x509 certificate chain: %w", err) + } - } else { - return errors.New("required configuration section [edge.enrollment.edgeIdentity] missing") + if spiffeId != nil { + return spiffeId, nil } + } - if value, found := enrollmentSubMap["edgeRouter"]; found { - edgeRouterSubMap := value.(map[interface{}]interface{}) + return nil, errors.New("failed to determine SPIFFE ID, no spiffe:// URI SANs found in x509 certificate chain") +} - edgeRouterDuration := 0 * time.Second - if value, found := edgeRouterSubMap["duration"]; found { - strValue := value.(string) - var err error - edgeRouterDuration, err = time.ParseDuration(strValue) +// GetSpiffeIdFromTlsCertChain will search a tls certificate chain for a trust domain encoded as a spiffe:// URI SAN. +// Each certificate must contain 0 or 1 spiffe:// URI SAN. The first SPIFFE id looking up the chain is returned. If +// no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. +func GetSpiffeIdFromTlsCertChain(tlsCerts []*tls.Certificate) (*url.URL, error) { + for _, tlsCert := range tlsCerts { + for i, rawCert := range tlsCert.Certificate { + cert, err := x509.ParseCertificate(rawCert) - if err != nil { - return errors.Errorf("error parsing [edge.enrollment.edgeRouter.duration], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) - } + if err != nil { + return nil, fmt.Errorf("failed to parse TLS cert at index [%d]: %w", i, err) } - if edgeRouterDuration < MinEdgeEnrollmentDuration { - edgeRouterDuration = DefaultEdgeEnrollmentDuration - } + spiffeId, err := GetSpiffeIdFromCert(cert) - c.Enrollment.EdgeRouter = EnrollmentOption{Duration: edgeRouterDuration} + if err != nil { + return nil, fmt.Errorf("failed to determine SPIFFE ID from TLS cert at index [%d]: %w", i, err) + } - } else { - return errors.New("required configuration section [edge.enrollment.edgeRouter] missing") + if spiffeId != nil { + return spiffeId, nil + } } - - } else { - return errors.New("required configuration section [edge.enrollment] missing") } - return nil + return nil, nil } -func (c *Config) loadAuthRateLimiterConfig(cfgmap map[interface{}]interface{}) error { - c.AuthRateLimiter.SetDefaults() +// GetSpiffeIdFromCert will search a x509 certificate for a trust domain encoded as a spiffe:// URI SAN. +// Each certificate must contain 0 or 1 spiffe:// URI SAN. The first SPIFFE id looking up the chain is returned. If +// no SPIFFE id is encountered, nil is returned. Errors are returned for parsing and processing errors only. +func GetSpiffeIdFromCert(cert *x509.Certificate) (*url.URL, error) { + var spiffeId *url.URL + for _, uriSan := range cert.URIs { + if uriSan.Scheme == "spiffe" { + if spiffeId != nil { + return nil, fmt.Errorf("multiple URI SAN spiffe:// ids encountered, must only have one, encountered at least two: [%s] and [%s]", spiffeId.String(), uriSan.String()) + } + spiffeId = uriSan + } + } - c.AuthRateLimiter.Enabled = DefaultAuthRateLimiterEnabled - c.AuthRateLimiter.MaxSize = DefaultAuthRateLimiterMaxSize - c.AuthRateLimiter.MinSize = DefaultAuthRateLimiterMinSize + return spiffeId, nil +} - if value, found := cfgmap["authRateLimiter"]; found { +func loadTlsHandshakeRateLimiterConfig(rateLimitConfig *command.AdaptiveRateLimiterConfig, cfgmap map[interface{}]interface{}) error { + if value, found := cfgmap["rateLimiter"]; found { if submap, ok := value.(map[interface{}]interface{}); ok { - if err := command.LoadAdaptiveRateLimiterConfig(&c.AuthRateLimiter, submap); err != nil { + if err := command.LoadAdaptiveRateLimiterConfig(rateLimitConfig, submap); err != nil { return err } - if c.AuthRateLimiter.MaxSize < AuthRateLimiterMinSizeValue { - return errors.Errorf("invalid value %v for authRateLimiter.maxSize, must be at least %v", - c.AuthRateLimiter.MaxSize, AuthRateLimiterMinSizeValue) + if rateLimitConfig.MaxSize < TlsHandshakeRateLimiterMinSizeValue { + return errors.Errorf("invalid value %v for tls.rateLimiter.maxSize, must be at least %v", + rateLimitConfig.MaxSize, TlsHandshakeRateLimiterMinSizeValue) } - if c.AuthRateLimiter.MaxSize > AuthRateLimiterMaxSizeValue { - return errors.Errorf("invalid value %v for authRateLimiter.maxSize, must be at most %v", - c.AuthRateLimiter.MaxSize, AuthRateLimiterMaxSizeValue) + if rateLimitConfig.MaxSize > TlsHandshakeRateLimiterMaxSizeValue { + return errors.Errorf("invalid value %v for tls.rateLimiter.maxSize, must be at most %v", + rateLimitConfig.MaxSize, TlsHandshakeRateLimiterMaxSizeValue) } - if c.AuthRateLimiter.MinSize < AuthRateLimiterMinSizeValue { - return errors.Errorf("invalid value %v for authRateLimiter.minSize, must be at least %v", - c.AuthRateLimiter.MinSize, AuthRateLimiterMinSizeValue) + if rateLimitConfig.MinSize < TlsHandshakeRateLimiterMinSizeValue { + return errors.Errorf("invalid value %v for tls.rateLimiter.minSize, must be at least %v", + rateLimitConfig.MinSize, TlsHandshakeRateLimiterMinSizeValue) } - if c.AuthRateLimiter.MinSize > AuthRateLimiterMaxSizeValue { - return errors.Errorf("invalid value %v for authRateLimiter.minSize, must be at most %v", - c.AuthRateLimiter.MinSize, AuthRateLimiterMaxSizeValue) + if rateLimitConfig.MinSize > TlsHandshakeRateLimiterMaxSizeValue { + return errors.Errorf("invalid value %v for tls.rateLimiter.minSize, must be at most %v", + rateLimitConfig.MinSize, TlsHandshakeRateLimiterMaxSizeValue) } } else { - return errors.Errorf("invalid type for authRateLimiter, should be map instead of %T", value) + return errors.Errorf("invalid type for tls.rateLimiter, should be map instead of %T", value) } } return nil } -func LoadFromMap(configMap map[interface{}]interface{}) (*Config, error) { - edgeConfig := NewConfig() - - var edgeConfigMap map[interface{}]interface{} - - if val, ok := configMap["edge"]; ok && val != nil { - if edgeConfigMap, ok = val.(map[interface{}]interface{}); !ok { - return nil, fmt.Errorf("expected map as edge configuration") - } - } else { - return edgeConfig, nil +// verifyNewListenerInServerCert verifies that the hostname (ip/dns) for addr is present as an IP/DNS SAN in the first +// certificate provided in the controller's identity server certificates. This is to avoid scenarios where +// newListener propagated to routers who will never be able to verify the controller's certificates due to SAN issues. +func verifyNewListenerInServerCert(controllerConfig *Config, addr transport.Address) error { + addrSplits := strings.Split(addr.String(), ":") + if len(addrSplits) < 3 { + return errors.New("could not determine newListener's host value, expected at least three segments") } - edgeConfig.Enabled = configMap != nil + host := addrSplits[1] - if !edgeConfig.Enabled { - return edgeConfig, nil + serverCerts := controllerConfig.Id.Identity.ServerCert() + + if len(serverCerts) == 0 { + return errors.New("could not verify newListener value, server certificate for identity contains no certificates") } - var err error + hostFound := false + for _, serverCert := range serverCerts { + for _, dnsName := range serverCert.Leaf.DNSNames { + if dnsName == host { + hostFound = true + break + } + } - if err = edgeConfig.loadApiSection(edgeConfigMap); err != nil { - return nil, err - } + if hostFound { + break + } - if err = edgeConfig.loadTotpSection(edgeConfigMap); err != nil { - return nil, err - } + if !hostFound { + for _, ipAddresses := range serverCert.Leaf.IPAddresses { + if host == ipAddresses.String() { + hostFound = true + break + } + } + } - if err = edgeConfig.loadEnrollmentSection(edgeConfigMap); err != nil { - return nil, err + if hostFound { + break + } } - if err = edgeConfig.loadAuthRateLimiterConfig(edgeConfigMap); err != nil { - return nil, err + if !hostFound { + return fmt.Errorf("could not find newListener [%s] host value [%s] in first certificate for controller identity", addr.String(), host) } - return edgeConfig, nil + return nil } -// CalculateCaPems takes the supplied caPems buffer as a set of PEM Certificates separated by new lines. Duplicate -// certificates are removed, and the result is returned as a bytes.Buffer of PEM Certificates separated by new lines. -func CalculateCaPems(caPems *bytes.Buffer) *bytes.Buffer { - caPemMap := map[string][]byte{} - - newCaPems := bytes.Buffer{} - blocksToProcess := caPems.Bytes() - - for len(blocksToProcess) != 0 { - var block *pem.Block - block, blocksToProcess = pem.Decode(blocksToProcess) +type CertValidatingIdentity struct { + identity.Identity +} - if block != nil { +func (self *CertValidatingIdentity) ClientTLSConfig() *tls.Config { + cfg := self.Identity.ClientTLSConfig() + cfg.VerifyConnection = self.VerifyConnection + return cfg +} - if block.Type != "CERTIFICATE" { - pfxlog.Logger(). - WithField("type", block.Type). - WithField("block", string(pem.EncodeToMemory(block))). - Warn("encountered an invalid PEM block type loading configured CAs, block will be ignored") - continue - } +func (self *CertValidatingIdentity) ServerTLSConfig() *tls.Config { + cfg := self.Identity.ServerTLSConfig() + cfg.VerifyConnection = self.VerifyConnection + return cfg +} - cert, err := x509.ParseCertificate(block.Bytes) +func (self *CertValidatingIdentity) VerifyConnection(state tls.ConnectionState) error { + if len(state.PeerCertificates) == 0 { + return errors.New("no peer certificates provided") + } + log := pfxlog.Logger() + for _, cert := range state.PeerCertificates { + log.Infof("cert provided: CN: %v IsCA: %v", cert.Subject.CommonName, cert.IsCA) + } - if err != nil { - pfxlog.Logger(). - WithField("type", block.Type). - WithField("block", string(pem.EncodeToMemory(block))). - WithError(err). - Warn("block could not be parsed as a certificate, block will be ignored") - continue - } - - if !cert.IsCA { - pfxlog.Logger(). - WithField("type", block.Type). - WithField("block", string(pem.EncodeToMemory(block))). - Warn("block is not a CA, block will be ignored") - continue - } - // #nosec - hash := sha1.Sum(block.Bytes) - fingerprint := toHex(hash[:]) - newPem := pem.EncodeToMemory(block) - caPemMap[fingerprint] = newPem - } else { - blocksToProcess = nil - } + options := x509.VerifyOptions{ + Roots: self.Identity.CA(), + Intermediates: x509.NewCertPool(), } - for _, caPem := range caPemMap { - _, _ = newCaPems.WriteString("\n") - _, _ = newCaPems.Write(caPem) + for _, cert := range state.PeerCertificates[1:] { + options.Intermediates.AddCert(cert) } - return &newCaPems -} + result, err := state.PeerCertificates[0].Verify(options) -// toHex takes a byte array returns a hex formatted fingerprint -func toHex(data []byte) string { - var buf bytes.Buffer - for i, b := range data { - if i > 0 { - _, _ = fmt.Fprintf(&buf, ":") - } - _, _ = fmt.Fprintf(&buf, "%02x", b) + if err != nil { + pfxlog.Logger().WithError(err).Error("got error validating cert") + return err } - return strings.ToUpper(buf.String()) + + log.Infof("got result: %v", result) + return nil } diff --git a/controller/config/config_edge.go b/controller/config/config_edge.go new file mode 100644 index 000000000..dadb63cea --- /dev/null +++ b/controller/config/config_edge.go @@ -0,0 +1,564 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package config + +import ( + "bytes" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "fmt" + "github.com/michaelquigley/pfxlog" + nfpem "github.com/openziti/foundation/v2/pem" + "github.com/openziti/identity" + "github.com/openziti/ziti/controller/command" + "github.com/pkg/errors" + "net" + "net/url" + "os" + "reflect" + "strconv" + "strings" + "sync" + "time" +) + +const ( + DefaultEdgeApiActivityUpdateBatchSize = 250 + DefaultEdgeAPIActivityUpdateInterval = 90 * time.Second + MaxEdgeAPIActivityUpdateBatchSize = 10000 + MinEdgeAPIActivityUpdateBatchSize = 1 + MaxEdgeAPIActivityUpdateInterval = 10 * time.Minute + MinEdgeAPIActivityUpdateInterval = time.Millisecond + + DefaultEdgeSessionTimeout = 30 * time.Minute + MinEdgeSessionTimeout = 1 * time.Minute + + MinEdgeEnrollmentDuration = 5 * time.Minute + DefaultEdgeEnrollmentDuration = 180 * time.Minute + + DefaultHttpIdleTimeout = 5000 * time.Millisecond + DefaultHttpReadTimeout = 5000 * time.Millisecond + DefaultHttpReadHeaderTimeout = 5000 * time.Millisecond + DefaultHttpWriteTimeout = 100000 * time.Millisecond + + DefaultTotpDomain = "openziti.io" + + DefaultAuthRateLimiterEnabled = true + DefaultAuthRateLimiterMaxSize = 250 + DefaultAuthRateLimiterMinSize = 5 + + AuthRateLimiterMinSizeValue = 5 + AuthRateLimiterMaxSizeValue = 1000 +) + +type Enrollment struct { + SigningCert identity.Identity + SigningCertConfig identity.Config + SigningCertCaPem []byte + EdgeIdentity EnrollmentOption + EdgeRouter EnrollmentOption +} + +type EnrollmentOption struct { + Duration time.Duration +} + +type Totp struct { + Hostname string +} + +type Api struct { + SessionTimeout time.Duration + ActivityUpdateBatchSize int + ActivityUpdateInterval time.Duration + + Listener string + Address string + IdentityCaPem []byte + HttpTimeouts HttpTimeouts +} + +type EdgeConfig struct { + Enabled bool + Api Api + Enrollment Enrollment + + caPems *bytes.Buffer + caPemsOnce sync.Once + Totp Totp + AuthRateLimiter command.AdaptiveRateLimiterConfig + caCerts []*x509.Certificate +} + +type HttpTimeouts struct { + ReadTimeoutDuration time.Duration + ReadHeaderTimeoutDuration time.Duration + WriteTimeoutDuration time.Duration + IdleTimeoutsDuration time.Duration +} + +func DefaultHttpTimeouts() *HttpTimeouts { + httpTimeouts := &HttpTimeouts{ + ReadTimeoutDuration: DefaultHttpReadTimeout, + ReadHeaderTimeoutDuration: DefaultHttpReadHeaderTimeout, + WriteTimeoutDuration: DefaultHttpWriteTimeout, + IdleTimeoutsDuration: DefaultHttpIdleTimeout, + } + return httpTimeouts +} + +func NewEdgeConfig() *EdgeConfig { + return &EdgeConfig{ + Enabled: false, + caPems: bytes.NewBuffer(nil), + } +} + +func (c *EdgeConfig) SessionTimeoutDuration() time.Duration { + return c.Api.SessionTimeout +} + +func (c *EdgeConfig) CaPems() []byte { + c.caPemsOnce.Do(func() { + c.RefreshCas() + }) + + return c.caPems.Bytes() +} + +func (c *EdgeConfig) CaCerts() []*x509.Certificate { + c.caPemsOnce.Do(func() { + c.RefreshCas() + }) + + return c.caCerts +} + +// AddCaPems adds a byte array of certificates to the current buffered list of CAs. The certificates +// should be in PEM format separated by new lines. RefreshCas should be called after all +// calls to AddCaPems are completed. +func (c *EdgeConfig) AddCaPems(caPems []byte) { + c.caPems.WriteString("\n") + c.caPems.Write(caPems) +} + +func (c *EdgeConfig) RefreshCas() { + c.caPems = CalculateCaPems(c.caPems) + c.caCerts = nfpem.PemBytesToCertificates(c.caPems.Bytes()) +} + +func (c *EdgeConfig) loadTotpSection(edgeConfigMap map[any]any) error { + c.Totp = Totp{} + c.Totp.Hostname = DefaultTotpDomain + + if value, found := edgeConfigMap["totp"]; found { + if value == nil { + return nil + } + + totpMap := value.(map[interface{}]interface{}) + + if totpMap != nil { + if hostnameVal, found := totpMap["hostname"]; found { + + if hostnameVal == nil { + return nil + } + + if hostname, ok := hostnameVal.(string); ok { + testUrl := "https://" + hostname + parsedUrl, err := url.Parse(testUrl) + + if err != nil { + return fmt.Errorf("could not parse URL: %w", err) + } + + if parsedUrl.Hostname() != hostname { + return fmt.Errorf("invalid hostname in [edge.totp.hostname]: %s", hostname) + } + + c.Totp.Hostname = hostname + } else { + return fmt.Errorf("[edge.totp.hostname] must be a string") + } + } + } + } + + return nil +} + +func (c *EdgeConfig) loadApiSection(edgeConfigMap map[interface{}]interface{}) error { + c.Api = Api{} + c.Api.HttpTimeouts = *DefaultHttpTimeouts() + var err error + + c.Api.ActivityUpdateBatchSize = DefaultEdgeApiActivityUpdateBatchSize + c.Api.ActivityUpdateInterval = DefaultEdgeAPIActivityUpdateInterval + + if value, found := edgeConfigMap["api"]; found { + apiSubMap := value.(map[interface{}]interface{}) + + if val, ok := apiSubMap["address"]; ok { + if c.Api.Address, ok = val.(string); !ok { + return errors.Errorf("invalid type %t for [edge.api.address], must be string", val) + } + + if c.Api.Address == "" { + return errors.Errorf("invalid type %t for [edge.api.address], must not be an empty string", val) + } + + if err := validateHostPortString(c.Api.Address); err != nil { + return errors.Errorf("invalid value %s for [edge.api.address]: %v", c.Api.Address, err) + } + } else { + return errors.New("required value [edge.api.address] is required") + } + + var durationValue = 0 * time.Second + if value, found := apiSubMap["sessionTimeout"]; found { + strValue := value.(string) + durationValue, err = time.ParseDuration(strValue) + if err != nil { + return errors.Errorf("error parsing [edge.api.sessionTimeout], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) + } + } + + if durationValue < MinEdgeSessionTimeout { + durationValue = DefaultEdgeSessionTimeout + pfxlog.Logger().Warnf("[edge.api.sessionTimeout] defaulted to %v", durationValue) + } + + c.Api.SessionTimeout = durationValue + + if val, ok := apiSubMap["activityUpdateBatchSize"]; ok { + if c.Api.ActivityUpdateBatchSize, ok = val.(int); !ok { + return errors.Errorf("invalid type %v for apiSessions.activityUpdateBatchSize, must be int", reflect.TypeOf(val)) + } + } + + if val, ok := apiSubMap["activityUpdateInterval"]; ok { + if strVal, ok := val.(string); !ok { + return errors.Errorf("invalid type %v for apiSessions.activityUpdateInterval, must be string duration", reflect.TypeOf(val)) + } else { + if c.Api.ActivityUpdateInterval, err = time.ParseDuration(strVal); err != nil { + return errors.Wrapf(err, "invalid value %v for apiSessions.activityUpdateInterval, must be string duration", val) + } + } + } + + if c.Api.ActivityUpdateBatchSize < MinEdgeAPIActivityUpdateBatchSize || c.Api.ActivityUpdateBatchSize > MaxEdgeAPIActivityUpdateBatchSize { + return errors.Errorf("invalid value %v for apiSessions.activityUpdateBatchSize, must be between %v and %v", c.Api.ActivityUpdateBatchSize, MinEdgeAPIActivityUpdateBatchSize, MaxEdgeAPIActivityUpdateBatchSize) + } + + if c.Api.ActivityUpdateInterval < MinEdgeAPIActivityUpdateInterval || c.Api.ActivityUpdateInterval > MaxEdgeAPIActivityUpdateInterval { + return errors.Errorf("invalid value %v for apiSessions.activityUpdateInterval, must be between %vms and %vm", c.Api.ActivityUpdateInterval.String(), MinEdgeAPIActivityUpdateInterval.Milliseconds(), MaxEdgeAPIActivityUpdateInterval.Minutes()) + } + + return nil + + } else { + return errors.New("required configuration section [edge.api] missing") + } +} + +func validateHostPortString(address string) error { + address = strings.TrimSpace(address) + + if address == "" { + return errors.New("must not be an empty string or unspecified") + } + + host, port, err := net.SplitHostPort(address) + + if err != nil { + return errors.Errorf("could not split host and port: %v", err) + } + + if host == "" { + return errors.New("host must be specified") + } + + if port == "" { + return errors.New("port must be specified") + } + + if port, err := strconv.ParseInt(port, 10, 32); err != nil { + return errors.New("invalid port, must be a integer") + } else if port < 1 || port > 65535 { + return errors.New("invalid port, must 1-65535") + } + + return nil +} + +func (c *EdgeConfig) loadEnrollmentSection(edgeConfigMap map[interface{}]interface{}) error { + c.Enrollment = Enrollment{} + var err error + + if value, found := edgeConfigMap["enrollment"]; found { + enrollmentSubMap := value.(map[interface{}]interface{}) + + if value, found := enrollmentSubMap["signingCert"]; found { + signingCertSubMap := value.(map[interface{}]interface{}) + c.Enrollment.SigningCertConfig = identity.Config{} + + if value, found := signingCertSubMap["cert"]; found { + c.Enrollment.SigningCertConfig.Cert = value.(string) + certPem, err := os.ReadFile(c.Enrollment.SigningCertConfig.Cert) + if err != nil { + pfxlog.Logger().WithError(err).Panic("unable to read [edge.enrollment.cert]") + } + //The signer is a valid trust anchor + _, _ = c.caPems.WriteString("\n") + _, _ = c.caPems.Write(certPem) + + } else { + return fmt.Errorf("required configuration value [edge.enrollment.cert] is missing") + } + + if value, found := signingCertSubMap["key"]; found { + c.Enrollment.SigningCertConfig.Key = value.(string) + } else { + return fmt.Errorf("required configuration value [edge.enrollment.key] is missing") + } + + if value, found := signingCertSubMap["ca"]; found { + c.Enrollment.SigningCertConfig.CA = value.(string) + + if c.Enrollment.SigningCertCaPem, err = os.ReadFile(c.Enrollment.SigningCertConfig.CA); err != nil { + return fmt.Errorf("could not read file CA file from [edge.enrollment.signingCert.ca]") + } + + _, _ = c.caPems.WriteString("\n") + _, _ = c.caPems.Write(c.Enrollment.SigningCertCaPem) + } //not an error if the signing certificate's CA is already represented in the root [identity.ca] + + if c.Enrollment.SigningCert, err = identity.LoadIdentity(c.Enrollment.SigningCertConfig); err != nil { + return fmt.Errorf("error loading [edge.enrollment.signingCert]: %s", err) + } else { + if err := c.Enrollment.SigningCert.WatchFiles(); err != nil { + pfxlog.Logger().Warn("could not enable file watching on enrollment signing cert: %w", err) + } + } + + } else { + return errors.New("required configuration section [edge.enrollment.signingCert] missing") + } + + if value, found := enrollmentSubMap["edgeIdentity"]; found { + edgeIdentitySubMap := value.(map[interface{}]interface{}) + + edgeIdentityDuration := 0 * time.Second + if value, found := edgeIdentitySubMap["duration"]; found { + strValue := value.(string) + var err error + edgeIdentityDuration, err = time.ParseDuration(strValue) + + if err != nil { + return errors.Errorf("error parsing [edge.enrollment.edgeIdentity.duration], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) + } + } + + if edgeIdentityDuration < MinEdgeEnrollmentDuration { + edgeIdentityDuration = DefaultEdgeEnrollmentDuration + } + + c.Enrollment.EdgeIdentity = EnrollmentOption{Duration: edgeIdentityDuration} + + } else { + return errors.New("required configuration section [edge.enrollment.edgeIdentity] missing") + } + + if value, found := enrollmentSubMap["edgeRouter"]; found { + edgeRouterSubMap := value.(map[interface{}]interface{}) + + edgeRouterDuration := 0 * time.Second + if value, found := edgeRouterSubMap["duration"]; found { + strValue := value.(string) + var err error + edgeRouterDuration, err = time.ParseDuration(strValue) + + if err != nil { + return errors.Errorf("error parsing [edge.enrollment.edgeRouter.duration], invalid duration string %s, cannot parse as duration (e.g. 1m): %v", strValue, err) + } + } + + if edgeRouterDuration < MinEdgeEnrollmentDuration { + edgeRouterDuration = DefaultEdgeEnrollmentDuration + } + + c.Enrollment.EdgeRouter = EnrollmentOption{Duration: edgeRouterDuration} + + } else { + return errors.New("required configuration section [edge.enrollment.edgeRouter] missing") + } + + } else { + return errors.New("required configuration section [edge.enrollment] missing") + } + + return nil +} + +func (c *EdgeConfig) loadAuthRateLimiterConfig(cfgmap map[interface{}]interface{}) error { + c.AuthRateLimiter.SetDefaults() + + c.AuthRateLimiter.Enabled = DefaultAuthRateLimiterEnabled + c.AuthRateLimiter.MaxSize = DefaultAuthRateLimiterMaxSize + c.AuthRateLimiter.MinSize = DefaultAuthRateLimiterMinSize + + if value, found := cfgmap["authRateLimiter"]; found { + if submap, ok := value.(map[interface{}]interface{}); ok { + if err := command.LoadAdaptiveRateLimiterConfig(&c.AuthRateLimiter, submap); err != nil { + return err + } + if c.AuthRateLimiter.MaxSize < AuthRateLimiterMinSizeValue { + return errors.Errorf("invalid value %v for authRateLimiter.maxSize, must be at least %v", + c.AuthRateLimiter.MaxSize, AuthRateLimiterMinSizeValue) + } + if c.AuthRateLimiter.MaxSize > AuthRateLimiterMaxSizeValue { + return errors.Errorf("invalid value %v for authRateLimiter.maxSize, must be at most %v", + c.AuthRateLimiter.MaxSize, AuthRateLimiterMaxSizeValue) + } + + if c.AuthRateLimiter.MinSize < AuthRateLimiterMinSizeValue { + return errors.Errorf("invalid value %v for authRateLimiter.minSize, must be at least %v", + c.AuthRateLimiter.MinSize, AuthRateLimiterMinSizeValue) + } + if c.AuthRateLimiter.MinSize > AuthRateLimiterMaxSizeValue { + return errors.Errorf("invalid value %v for authRateLimiter.minSize, must be at most %v", + c.AuthRateLimiter.MinSize, AuthRateLimiterMaxSizeValue) + } + } else { + return errors.Errorf("invalid type for authRateLimiter, should be map instead of %T", value) + } + } + + return nil +} + +func LoadEdgeConfigFromMap(configMap map[interface{}]interface{}) (*EdgeConfig, error) { + edgeConfig := NewEdgeConfig() + + var edgeConfigMap map[interface{}]interface{} + + if val, ok := configMap["edge"]; ok && val != nil { + if edgeConfigMap, ok = val.(map[interface{}]interface{}); !ok { + return nil, fmt.Errorf("expected map as edge configuration") + } + } else { + return edgeConfig, nil + } + + edgeConfig.Enabled = configMap != nil + + if !edgeConfig.Enabled { + return edgeConfig, nil + } + + var err error + + if err = edgeConfig.loadApiSection(edgeConfigMap); err != nil { + return nil, err + } + + if err = edgeConfig.loadTotpSection(edgeConfigMap); err != nil { + return nil, err + } + + if err = edgeConfig.loadEnrollmentSection(edgeConfigMap); err != nil { + return nil, err + } + + if err = edgeConfig.loadAuthRateLimiterConfig(edgeConfigMap); err != nil { + return nil, err + } + + return edgeConfig, nil +} + +// CalculateCaPems takes the supplied caPems buffer as a set of PEM Certificates separated by new lines. Duplicate +// certificates are removed, and the result is returned as a bytes.Buffer of PEM Certificates separated by new lines. +func CalculateCaPems(caPems *bytes.Buffer) *bytes.Buffer { + caPemMap := map[string][]byte{} + + newCaPems := bytes.Buffer{} + blocksToProcess := caPems.Bytes() + + for len(blocksToProcess) != 0 { + var block *pem.Block + block, blocksToProcess = pem.Decode(blocksToProcess) + + if block != nil { + + if block.Type != "CERTIFICATE" { + pfxlog.Logger(). + WithField("type", block.Type). + WithField("block", string(pem.EncodeToMemory(block))). + Warn("encountered an invalid PEM block type loading configured CAs, block will be ignored") + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + + if err != nil { + pfxlog.Logger(). + WithField("type", block.Type). + WithField("block", string(pem.EncodeToMemory(block))). + WithError(err). + Warn("block could not be parsed as a certificate, block will be ignored") + continue + } + + if !cert.IsCA { + pfxlog.Logger(). + WithField("type", block.Type). + WithField("block", string(pem.EncodeToMemory(block))). + Warn("block is not a CA, block will be ignored") + continue + } + // #nosec + hash := sha1.Sum(block.Bytes) + fingerprint := toHex(hash[:]) + newPem := pem.EncodeToMemory(block) + caPemMap[fingerprint] = newPem + } else { + blocksToProcess = nil + } + } + + for _, caPem := range caPemMap { + _, _ = newCaPems.WriteString("\n") + _, _ = newCaPems.Write(caPem) + } + + return &newCaPems +} + +// toHex takes a byte array returns a hex formatted fingerprint +func toHex(data []byte) string { + var buf bytes.Buffer + for i, b := range data { + if i > 0 { + _, _ = fmt.Fprintf(&buf, ":") + } + _, _ = fmt.Fprintf(&buf, "%02x", b) + } + return strings.ToUpper(buf.String()) +} diff --git a/controller/config/config_test.go b/controller/config/config_edge_test.go similarity index 100% rename from controller/config/config_test.go rename to controller/config/config_edge_test.go diff --git a/controller/network/options.go b/controller/config/config_network.go similarity index 97% rename from controller/network/options.go rename to controller/config/config_network.go index 791d4f1b1..abdf017aa 100644 --- a/controller/network/options.go +++ b/controller/config/config_network.go @@ -14,7 +14,7 @@ limitations under the License. */ -package network +package config import ( "github.com/pkg/errors" @@ -45,7 +45,7 @@ const ( OptionsRouterCommMaxWorkers = 10_000 ) -type Options struct { +type NetworkConfig struct { CreateCircuitRetries uint32 CycleSeconds uint32 EnableLegacyLinkMgmt bool @@ -67,8 +67,8 @@ type Options struct { } } -func DefaultOptions() *Options { - options := &Options{ +func DefaultNetworkConfig() *NetworkConfig { + options := &NetworkConfig{ CreateCircuitRetries: DefaultOptionsCreateCircuitRetries, CycleSeconds: DefaultOptionsCycleSeconds, EnableLegacyLinkMgmt: DefaultOptionsEnableLegacyLinkMgmt, @@ -98,8 +98,8 @@ func DefaultOptions() *Options { return options } -func LoadOptions(src map[interface{}]interface{}) (*Options, error) { - options := DefaultOptions() +func LoadNetworkConfig(src map[interface{}]interface{}) (*NetworkConfig, error) { + options := DefaultNetworkConfig() if value, found := src["cycleSeconds"]; found { if cycleSeconds, ok := value.(int); ok { diff --git a/controller/config/config_raft.go b/controller/config/config_raft.go new file mode 100644 index 000000000..b9064f397 --- /dev/null +++ b/controller/config/config_raft.go @@ -0,0 +1,31 @@ +package config + +import ( + "github.com/hashicorp/go-hclog" + "github.com/openziti/transport/v2" + "time" +) + +type RaftConfig struct { + Recover bool + DataDir string + MinClusterSize uint32 + AdvertiseAddress transport.Address + BootstrapMembers []string + CommandHandlerOptions struct { + MaxQueueSize uint16 + } + + SnapshotInterval *time.Duration + SnapshotThreshold *uint32 + TrailingLogs *uint32 + MaxAppendEntries *uint32 + + ElectionTimeout *time.Duration + CommitTimeout *time.Duration + HeartbeatTimeout *time.Duration + LeaderLeaseTimeout *time.Duration + + LogLevel *string + Logger hclog.Logger +} diff --git a/controller/controller.go b/controller/controller.go index dd62ba50f..82131540d 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -25,7 +25,8 @@ import ( "github.com/openziti/transport/v2" "github.com/openziti/transport/v2/tls" "github.com/openziti/ziti/common/capabilities" - "github.com/openziti/ziti/common/config" + "github.com/openziti/ziti/controller/config" + "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/event" "github.com/openziti/ziti/controller/events" "github.com/openziti/ziti/controller/handler_peer_ctrl" @@ -68,7 +69,8 @@ import ( ) type Controller struct { - config *Config + config *config.Config + env *env.AppEnv network *network.Network raftController *raft.Controller localDispatcher *command.LocalDispatcher @@ -94,10 +96,6 @@ type Controller struct { apiDataOnce sync.Once } -func (c *Controller) GetConfig() *Config { - return c.config -} - func (c *Controller) GetPeerSigners() []*x509.Certificate { if c.raftController == nil || c.raftController.Mesh == nil { return nil @@ -131,11 +129,15 @@ func (c *Controller) GetId() *identity.TokenId { return c.config.Id } +func (c *Controller) GetConfig() *config.Config { + return c.config +} + func (c *Controller) GetMetricsRegistry() metrics.Registry { return c.metricsRegistry } -func (c *Controller) GetOptions() *network.Options { +func (c *Controller) GetOptions() *config.NetworkConfig { return c.config.Network } @@ -186,7 +188,7 @@ func (c *Controller) GetCloseNotify() <-chan struct{} { return c.shutdownC } -func (c *Controller) GetRaftConfig() *raft.Config { +func (c *Controller) GetRaftConfig() *config.RaftConfig { return c.config.Raft } @@ -195,15 +197,14 @@ func (c *Controller) GetCommandRateLimiterConfig() command.RateLimiterConfig { } func (c *Controller) RenderJsonConfig() (string, error) { - jsonMap, err := config.ToJsonCompatibleMap(c.config.src) - if err != nil { - return "", err - } - b, err := json.Marshal(jsonMap) - return string(b), err + return c.config.ToJson() } -func NewController(cfg *Config, versionProvider versions.VersionProvider) (*Controller, error) { +func (c *Controller) GetEnv() *env.AppEnv { + return c.env +} + +func NewController(cfg *config.Config, versionProvider versions.VersionProvider) (*Controller, error) { metricRegistry := metrics.NewRegistry(cfg.Id.Token, nil) shutdownC := make(chan struct{}) @@ -233,7 +234,14 @@ func NewController(cfg *Config, versionProvider versions.VersionProvider) (*Cont c.registerXts() - if n, err := network.NewNetwork(c); err == nil { + appEnv, err := env.NewAppEnv(c) + if err != nil { + return nil, err + } + + c.env = appEnv + + if n, err := network.NewNetwork(c, appEnv); err == nil { c.network = n } else { return nil, err @@ -380,7 +388,7 @@ func (c *Controller) Run() error { func (c *Controller) getEventHandlerConfigs() []*events.EventHandlerConfig { var result []*events.EventHandlerConfig - if e, ok := c.config.src["events"]; ok { + if e, ok := c.config.Src["events"]; ok { if em, ok := e.(map[interface{}]interface{}); ok { for id, v := range em { if config, ok := v.(map[interface{}]interface{}); ok { @@ -522,7 +530,7 @@ func (c *Controller) routerDispatchCallback(evt *event.ClusterEvent) { } func (c *Controller) getMigrationDb() (*string, error) { - val, found := c.config.src["db"] + val, found := c.config.Src["db"] if !found { return nil, nil } diff --git a/controller/env/appenv.go b/controller/env/appenv.go index 42f7aa4c7..f040aa4c2 100644 --- a/controller/env/appenv.go +++ b/controller/env/appenv.go @@ -49,7 +49,6 @@ import ( "github.com/openziti/ziti/common" "github.com/openziti/ziti/common/cert" "github.com/openziti/ziti/common/eid" - "github.com/openziti/ziti/controller" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/config" @@ -84,8 +83,8 @@ const ( ) type AppEnv struct { + Stores *db.Stores Managers *model.Managers - Config *config.Config Versions *ziti.Versions @@ -228,7 +227,7 @@ func (ae *AppEnv) ValidateServiceAccessToken(token string, apiSessionId *string) return nil, err } - if revocation != nil && tokenRevocation.CreatedAt.After(serviceAccessClaims.IssuedAt.Time) { + if revocation != nil && revocation.CreatedAt.After(serviceAccessClaims.IssuedAt.Time) { return nil, errors.New("service access token has been revoked by identity") } @@ -251,7 +250,7 @@ func (ae *AppEnv) GetApiClientCsrSigner() cert.Signer { return ae.ApiClientCsrSigner } -func (ae *AppEnv) GetHostController() model.HostController { +func (ae *AppEnv) GetHostController() HostController { return ae.HostController } @@ -260,19 +259,19 @@ func (ae *AppEnv) GetManagers() *model.Managers { } func (ae *AppEnv) GetConfig() *config.Config { - return ae.Config + return ae.HostController.GetConfig() } func (ae *AppEnv) GetServerJwtSigner() jwtsigner.Signer { return ae.serverSigner } -func (ae *AppEnv) GetDbProvider() network.DbProvider { - return ae.HostController.GetNetwork() +func (ae *AppEnv) GetDb() boltz.Db { + return ae.HostController.GetDb() } func (ae *AppEnv) GetStores() *db.Stores { - return ae.HostController.GetNetwork().GetStores() + return ae.Stores } func (ae *AppEnv) GetAuthRegistry() model.AuthRegistry { @@ -288,14 +287,36 @@ func (ae *AppEnv) IsEdgeRouterOnline(id string) bool { } func (ae *AppEnv) GetMetricsRegistry() metrics.Registry { - return ae.HostController.GetNetwork().GetMetricsRegistry() + return ae.HostController.GetMetricsRegistry() } func (ae *AppEnv) GetFingerprintGenerator() cert.FingerprintGenerator { return ae.FingerprintGenerator } +func (ae *AppEnv) GetRaftInfo() (string, string, string) { + return ae.HostController.GetRaftInfo() +} + +func (ae *AppEnv) GetApiAddresses() (map[string][]event.ApiAddress, []byte) { + return ae.HostController.GetApiAddresses() +} + +func (ae *AppEnv) GetCloseNotifyChannel() <-chan struct{} { + return ae.HostController.GetCloseNotifyChannel() +} + +func (ae *AppEnv) GetPeerSigners() []*x509.Certificate { + return ae.HostController.GetPeerSigners() +} + +func (ae *AppEnv) GetCommandDispatcher() command.Dispatcher { + return ae.HostController.GetCommandDispatcher() +} + type HostController interface { + GetConfig() *config.Config + GetEnv() *AppEnv RegisterAgentBindHandler(bindHandler channel.BindHandler) RegisterXctrl(x xctrl.Xctrl) error RegisterXmgmt(x xmgmt.Xmgmt) error @@ -306,13 +327,15 @@ type HostController interface { Identity() identity.Identity IsRaftEnabled() bool IsRaftLeader() bool + GetDb() boltz.Db + GetCommandDispatcher() command.Dispatcher GetPeerSigners() []*x509.Certificate GetEventDispatcher() event.Dispatcher GetRaftIndex() uint64 GetPeerAddresses() []string GetRaftInfo() (string, string, string) GetApiAddresses() (map[string][]event.ApiAddress, []byte) - GetConfig() *controller.Config + GetMetricsRegistry() metrics.Registry } type Schemes struct { @@ -621,7 +644,12 @@ func ProcessAuthQueries(ae *AppEnv, rc *response.RequestContext) { } } -func NewAppEnv(c *config.Config, host HostController) *AppEnv { +func NewAppEnv(host HostController) (*AppEnv, error) { + stores, err := db.InitStores(host.GetDb(), host.GetCommandDispatcher().GetRateLimiter()) + if err != nil { + return nil, err + } + clientSpec, err := loads.Embedded(clientServer.SwaggerJSON, clientServer.FlatSwaggerJSON) if err != nil { pfxlog.Logger().Fatalln(err) @@ -638,8 +666,10 @@ func NewAppEnv(c *config.Config, host HostController) *AppEnv { managementApi := managementOperations.NewZitiEdgeManagementAPI(managementSpec) managementApi.ServeError = ServeError + c := host.GetConfig().Edge + ae := &AppEnv{ - Config: c, + Stores: stores, Versions: &ziti.Versions{ Api: "1.0.0", EnrollmentApi: "1.0.0", @@ -659,10 +689,11 @@ func NewAppEnv(c *config.Config, host HostController) *AppEnv { WorkTimerMetric: metricAuthLimiterWorkTimer, QueueSizeMetric: metricAuthLimiterCurrentQueuedCount, WindowSizeMetric: metricAuthLimiterCurrentWindowSize, - }, host.GetNetwork().GetMetricsRegistry(), host.GetCloseNotifyChannel()), + }, host.GetMetricsRegistry(), host.GetCloseNotifyChannel()), + TraceManager: NewTraceManager(host.GetCloseNotifyChannel()), } - ae.identityRefreshMeter = ae.GetHostController().GetNetwork().GetMetricsRegistry().Meter("identity.refresh") + ae.identityRefreshMeter = host.GetMetricsRegistry().Meter("identity.refresh") clientApi.APIAuthorizer = authorizer{} managementApi.APIAuthorizer = authorizer{} @@ -711,9 +742,12 @@ func NewAppEnv(c *config.Config, host HostController) *AppEnv { managementApi.ZtSessionAuth = clientApi.ZtSessionAuth managementApi.Oauth2Auth = clientApi.Oauth2Auth - ae.ApiClientCsrSigner = cert.NewClientSigner(ae.Config.Enrollment.SigningCert.Cert().Leaf, ae.Config.Enrollment.SigningCert.Cert().PrivateKey) - ae.ApiServerCsrSigner = cert.NewServerSigner(ae.Config.Enrollment.SigningCert.Cert().Leaf, ae.Config.Enrollment.SigningCert.Cert().PrivateKey) - ae.ControlClientCsrSigner = cert.NewClientSigner(ae.Config.Enrollment.SigningCert.Cert().Leaf, ae.Config.Enrollment.SigningCert.Cert().PrivateKey) + if host.GetConfig().Edge.Enabled { + enrollmentCert := host.GetConfig().Edge.Enrollment.SigningCert.Cert() + ae.ApiClientCsrSigner = cert.NewClientSigner(enrollmentCert.Leaf, enrollmentCert.PrivateKey) + ae.ApiServerCsrSigner = cert.NewServerSigner(enrollmentCert.Leaf, enrollmentCert.PrivateKey) + ae.ControlClientCsrSigner = cert.NewClientSigner(enrollmentCert.Leaf, enrollmentCert.PrivateKey) + } ae.FingerprintGenerator = cert.NewFingerprintGenerator() @@ -722,13 +756,16 @@ func NewAppEnv(c *config.Config, host HostController) *AppEnv { log.WithField("cause", err).Fatal("could not load schemas") } - return ae + ae.Managers = model.NewManagers() + ae.Managers.Init(ae) + + return ae, nil } func (ae *AppEnv) InitPersistence() error { var err error - stores := ae.HostController.GetNetwork().GetStores() + stores := ae.GetStores() stores.EventualEventer.AddListener(db.EventualEventAddedName, func(i ...interface{}) { if len(i) == 0 { @@ -737,7 +774,7 @@ func (ae *AppEnv) InitPersistence() error { } if event, ok := i[0].(*db.EventualEventAdded); ok { - gauge := ae.GetHostController().GetNetwork().GetMetricsRegistry().Gauge(EventualEventsGauge) + gauge := ae.GetHostController().GetMetricsRegistry().Gauge(EventualEventsGauge) gauge.Update(event.Total) } else { pfxlog.Logger().Errorf("could not update metrics for %s gauge on add, event argument was %T expected *EventualEventAdded", EventualEventsGauge, i[0]) @@ -750,15 +787,14 @@ func (ae *AppEnv) InitPersistence() error { } if event, ok := i[0].(*db.EventualEventRemoved); ok { - gauge := ae.GetHostController().GetNetwork().GetMetricsRegistry().Gauge(EventualEventsGauge) + gauge := ae.GetHostController().GetMetricsRegistry().Gauge(EventualEventsGauge) gauge.Update(event.Total) } else { pfxlog.Logger().Errorf("could not update metrics for %s gauge on remove, event argument was %T expected *EventualEventRemoved", EventualEventsGauge, i[0]) } }) - ae.Managers = model.InitEntityManagers(ae) - ae.GetHostController().GetNetwork().GetEventDispatcher().(*events.Dispatcher).InitializeEdgeEvents(stores) + ae.GetHostController().GetEventDispatcher().(*events.Dispatcher).InitializeEdgeEvents(stores) db.ServiceEvents.AddServiceEventHandler(ae.HandleServiceEvent) stores.Identity.AddEntityIdListener(ae.IdentityRefreshMap.Remove, boltz.EntityDeletedAsync) @@ -933,7 +969,7 @@ func (ae *AppEnv) IsAllowed(responderFunc func(ae *AppEnv, rc *response.RequestC responderFunc(ae, rc) if !rc.StartTime.IsZero() { - timer := ae.GetHostController().GetNetwork().GetMetricsRegistry().Timer(getMetricTimerName(rc.Request)) + timer := ae.GetHostController().GetMetricsRegistry().Timer(getMetricTimerName(rc.Request)) timer.UpdateSince(rc.StartTime) } else { pfxlog.Logger().WithFields(map[string]interface{}{ @@ -965,5 +1001,5 @@ func (ae *AppEnv) OidcIssuer() string { } func (ae *AppEnv) RootIssuer() string { - return "https://" + ae.Config.Api.Address + return "https://" + ae.GetConfig().Edge.Api.Address } diff --git a/controller/env/broker.go b/controller/env/broker.go index 5e6be2c78..1d8a832ff 100644 --- a/controller/env/broker.go +++ b/controller/env/broker.go @@ -25,7 +25,7 @@ import ( "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/event" - "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/model" "go.etcd.io/bbolt" "sync" ) @@ -104,7 +104,7 @@ func (broker *Broker) GetReceiveHandlers() []channel.TypedReceiveHandler { return broker.routerSyncStrategy.GetReceiveHandlers() } -func (broker *Broker) RouterConnected(router *network.Router) { +func (broker *Broker) RouterConnected(router *model.Router) { go func() { fingerprint := "" if router != nil && router.Fingerprint != nil { @@ -128,7 +128,7 @@ func (broker *Broker) RouterConnected(router *network.Router) { }() } -func (broker *Broker) RouterDisconnected(r *network.Router) { +func (broker *Broker) RouterDisconnected(r *model.Router) { go func() { pfxlog.Logger().WithField("routerId", r.Id). WithField("routerName", r.Name). @@ -162,7 +162,7 @@ func (broker *Broker) apiSessionCertificateDeleted(entity *db.ApiSessionCertific func (broker *Broker) apiSessionCertificateHandler(delete bool, apiSessionCert *db.ApiSessionCertificate) { var apiSession *db.ApiSession var err error - err = broker.ae.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err = broker.ae.GetDb().View(func(tx *bbolt.Tx) error { apiSession, err = broker.ae.GetStores().ApiSession.LoadById(tx, apiSessionCert.ApiSessionId) return err }) diff --git a/controller/env/sync.go b/controller/env/sync.go index c5d32bfd4..a325fa888 100644 --- a/controller/env/sync.go +++ b/controller/env/sync.go @@ -22,7 +22,6 @@ import ( "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/network" "sync" ) @@ -68,8 +67,8 @@ type RouterSyncStrategy interface { // This is intended for API Session but additional state is possible. Implementations may bind additional // handlers to the channel. type RouterConnectionHandler interface { - RouterConnected(edgeRouter *model.EdgeRouter, router *network.Router) - RouterDisconnected(router *network.Router) + RouterConnected(edgeRouter *model.EdgeRouter, router *model.Router) + RouterDisconnected(router *model.Router) GetReceiveHandlers() []channel.TypedReceiveHandler } diff --git a/controller/events/dispatcher_router.go b/controller/events/dispatcher_router.go index cb9cd74be..585ddb0d4 100644 --- a/controller/events/dispatcher_router.go +++ b/controller/events/dispatcher_router.go @@ -18,6 +18,7 @@ package events import ( "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "reflect" @@ -70,15 +71,15 @@ type routerEventAdapter struct { *Dispatcher } -func (self *routerEventAdapter) RouterConnected(r *network.Router) { +func (self *routerEventAdapter) RouterConnected(r *model.Router) { self.routerChange(event.RouterOnline, r, true) } -func (self *routerEventAdapter) RouterDisconnected(r *network.Router) { +func (self *routerEventAdapter) RouterDisconnected(r *model.Router) { self.routerChange(event.RouterOffline, r, false) } -func (self *routerEventAdapter) routerChange(eventType event.RouterEventType, r *network.Router, online bool) { +func (self *routerEventAdapter) routerChange(eventType event.RouterEventType, r *model.Router, online bool) { evt := &event.RouterEvent{ Namespace: event.RouterEventsNs, EventType: eventType, diff --git a/controller/events/dispatcher_terminator.go b/controller/events/dispatcher_terminator.go index a089f4d4c..894b4668c 100644 --- a/controller/events/dispatcher_terminator.go +++ b/controller/events/dispatcher_terminator.go @@ -21,6 +21,7 @@ import ( "github.com/openziti/storage/boltz" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/xt" "github.com/pkg/errors" @@ -127,15 +128,15 @@ type terminatorEventAdapter struct { Dispatcher *Dispatcher } -func (self *terminatorEventAdapter) RouterConnected(r *network.Router) { +func (self *terminatorEventAdapter) RouterConnected(r *model.Router) { self.routerChange(event.TerminatorRouterOnline, r) } -func (self *terminatorEventAdapter) RouterDisconnected(r *network.Router) { +func (self *terminatorEventAdapter) RouterDisconnected(r *model.Router) { self.routerChange(event.TerminatorRouterOffline, r) } -func (self *terminatorEventAdapter) routerChange(eventType event.TerminatorEventType, r *network.Router) { +func (self *terminatorEventAdapter) routerChange(eventType event.TerminatorEventType, r *model.Router) { var terminators []*db.Terminator err := self.Network.GetDb().View(func(tx *bbolt.Tx) error { cursor := self.Network.GetStores().Router.GetRelatedEntitiesCursor(tx, r.Id, db.EntityTypeTerminators, true) @@ -175,12 +176,12 @@ func (self *terminatorEventAdapter) terminatorDeleted(terminator *db.Terminator) } func (self *terminatorEventAdapter) terminatorChanged(eventType event.TerminatorEventType, terminator *db.Terminator) { - terminator = self.Network.Services.NotifyTerminatorChanged(terminator) + terminator = self.Network.Service.NotifyTerminatorChanged(terminator) self.createTerminatorEvent(eventType, terminator) } func (self *terminatorEventAdapter) createTerminatorEvent(eventType event.TerminatorEventType, terminator *db.Terminator) { - service, _ := self.Network.Services.Read(terminator.Service) + service, _ := self.Network.Service.Read(terminator.Service) totalTerminators := -1 usableDefaultTerminators := -1 diff --git a/controller/handler_ctrl/base.go b/controller/handler_ctrl/base.go index 9614431c8..e6259d647 100644 --- a/controller/handler_ctrl/base.go +++ b/controller/handler_ctrl/base.go @@ -19,11 +19,12 @@ package handler_ctrl import ( "github.com/openziti/channel/v2" "github.com/openziti/ziti/controller/change" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" ) type baseHandler struct { - router *network.Router + router *model.Router network *network.Network } diff --git a/controller/handler_ctrl/bind.go b/controller/handler_ctrl/bind.go index a71508d22..56552cde5 100644 --- a/controller/handler_ctrl/bind.go +++ b/controller/handler_ctrl/bind.go @@ -18,6 +18,7 @@ package handler_ctrl import ( "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "github.com/sirupsen/logrus" "time" @@ -34,12 +35,12 @@ import ( type bindHandler struct { heartbeatOptions *channel.HeartbeatOptions - router *network.Router + router *model.Router network *network.Network xctrls []xctrl.Xctrl } -func newBindHandler(heartbeatOptions *channel.HeartbeatOptions, router *network.Router, network *network.Network, xctrls []xctrl.Xctrl) channel.BindHandler { +func newBindHandler(heartbeatOptions *channel.HeartbeatOptions, router *model.Router, network *network.Network, xctrls []xctrl.Xctrl) channel.BindHandler { return &bindHandler{ heartbeatOptions: heartbeatOptions, router: router, diff --git a/controller/handler_ctrl/circuit_confirmation.go b/controller/handler_ctrl/circuit_confirmation.go index 13ca3a711..b41168f98 100644 --- a/controller/handler_ctrl/circuit_confirmation.go +++ b/controller/handler_ctrl/circuit_confirmation.go @@ -20,6 +20,7 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" @@ -28,10 +29,10 @@ import ( type circuitConfirmationHandler struct { n *network.Network - r *network.Router + r *model.Router } -func newCircuitConfirmationHandler(n *network.Network, r *network.Router) *circuitConfirmationHandler { +func newCircuitConfirmationHandler(n *network.Network, r *model.Router) *circuitConfirmationHandler { return &circuitConfirmationHandler{n, r} } @@ -56,10 +57,10 @@ func (self *circuitConfirmationHandler) HandleReceive(msg *channel.Message, _ ch } } -func (self *circuitConfirmationHandler) checkCircuitMaxIdle(circuit *network.Circuit, confirm *ctrl_pb.CircuitConfirmation) { +func (self *circuitConfirmationHandler) checkCircuitMaxIdle(circuit *model.Circuit, confirm *ctrl_pb.CircuitConfirmation) { log := logrus.WithField("routerId", self.r.Id).WithField("circuitId", circuit.Id) - service, _ := self.n.Services.Read(circuit.ServiceId) + service, _ := self.n.Service.Read(circuit.ServiceId) if service == nil { log.Info("service for circuit gone, removing idle circuit") if err := self.n.RemoveCircuit(circuit.Id, true); err != nil { diff --git a/controller/handler_ctrl/circuit_request.go b/controller/handler_ctrl/circuit_request.go index b4267e4e5..5bbf61c87 100644 --- a/controller/handler_ctrl/circuit_request.go +++ b/controller/handler_ctrl/circuit_request.go @@ -17,25 +17,26 @@ package handler_ctrl import ( + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/xt" "google.golang.org/protobuf/proto" "time" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" + "github.com/openziti/identity" "github.com/openziti/ziti/common/ctrl_msg" "github.com/openziti/ziti/common/logcontext" "github.com/openziti/ziti/common/pb/ctrl_pb" - "github.com/openziti/identity" + "github.com/openziti/ziti/controller/network" ) type circuitRequestHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newCircuitRequestHandler(r *network.Router, network *network.Network) *circuitRequestHandler { +func newCircuitRequestHandler(r *model.Router, network *network.Network) *circuitRequestHandler { return &circuitRequestHandler{r: r, network: network} } @@ -55,8 +56,8 @@ func (h *circuitRequestHandler) HandleReceive(msg *channel.Message, ch channel.C go func() { id := &identity.TokenId{Token: request.IngressId, Data: request.PeerData} service := request.Service - if _, err := h.network.Managers.Services.Read(service); err != nil { - if id, _ := h.network.Managers.Services.GetIdForName(service); id != "" { + if _, err := h.network.Managers.Service.Read(service); err != nil { + if id, _ := h.network.Managers.Service.GetIdForName(service); id != "" { service = id } } @@ -96,7 +97,7 @@ func (h *circuitRequestHandler) HandleReceive(msg *channel.Message, ch channel.C } } -func (h *circuitRequestHandler) newCircuitCreateParms(serviceId string, sourceRouter *network.Router, clientId *identity.TokenId) network.CreateCircuitParams { +func (h *circuitRequestHandler) newCircuitCreateParms(serviceId string, sourceRouter *model.Router, clientId *identity.TokenId) model.CreateCircuitParams { return &circuitParams{ serviceId: serviceId, sourceRouter: sourceRouter, @@ -108,7 +109,7 @@ func (h *circuitRequestHandler) newCircuitCreateParms(serviceId string, sourceRo type circuitParams struct { serviceId string - sourceRouter *network.Router + sourceRouter *model.Router clientId *identity.TokenId ctx logcontext.Context deadline time.Time @@ -118,7 +119,7 @@ func (self *circuitParams) GetServiceId() string { return self.serviceId } -func (self *circuitParams) GetSourceRouter() *network.Router { +func (self *circuitParams) GetSourceRouter() *model.Router { return self.sourceRouter } diff --git a/controller/handler_ctrl/close.go b/controller/handler_ctrl/close.go index 21f18b185..8c8f15c16 100644 --- a/controller/handler_ctrl/close.go +++ b/controller/handler_ctrl/close.go @@ -19,15 +19,16 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" ) type closeHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newCloseHandler(r *network.Router, network *network.Network) *closeHandler { +func newCloseHandler(r *model.Router, network *network.Network) *closeHandler { return &closeHandler{r: r, network: network} } diff --git a/controller/handler_ctrl/create_terminator.go b/controller/handler_ctrl/create_terminator.go index eb244cece..2947ecfd3 100644 --- a/controller/handler_ctrl/create_terminator.go +++ b/controller/handler_ctrl/create_terminator.go @@ -20,9 +20,10 @@ import ( "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" "math" ) @@ -31,7 +32,7 @@ type createTerminatorHandler struct { baseHandler } -func newCreateTerminatorHandler(network *network.Network, router *network.Router) *createTerminatorHandler { +func newCreateTerminatorHandler(network *network.Network, router *model.Router) *createTerminatorHandler { return &createTerminatorHandler{ baseHandler: baseHandler{ network: network, @@ -60,7 +61,7 @@ func (self *createTerminatorHandler) handleCreateTerminator(msg *channel.Message return } - terminator := &network.Terminator{ + terminator := &model.Terminator{ Service: request.ServiceId, Router: self.router.Id, Binding: request.Binding, @@ -72,7 +73,7 @@ func (self *createTerminatorHandler) handleCreateTerminator(msg *channel.Message Cost: uint16(request.Cost), } - if err := self.network.Terminators.Create(terminator, self.newChangeContext(ch, "fabric.create.terminator")); err == nil { + if err := self.network.Terminator.Create(terminator, self.newChangeContext(ch, "fabric.create.terminator")); err == nil { pfxlog.Logger().Infof("created terminator [t/%s]", terminator.Id) handler_common.SendSuccess(msg, ch, terminator.Id) } else { diff --git a/controller/handler_ctrl/decommission.go b/controller/handler_ctrl/decommission.go index 9215281cc..c18a8a6b3 100644 --- a/controller/handler_ctrl/decommission.go +++ b/controller/handler_ctrl/decommission.go @@ -21,6 +21,7 @@ import ( "github.com/openziti/channel/v2" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" ) @@ -28,7 +29,7 @@ type decommissionRouterHandler struct { baseHandler } -func newDecommissionRouterHandler(router *network.Router, network *network.Network) *decommissionRouterHandler { +func newDecommissionRouterHandler(router *model.Router, network *network.Network) *decommissionRouterHandler { return &decommissionRouterHandler{ baseHandler: baseHandler{ router: router, @@ -46,7 +47,7 @@ func (self *decommissionRouterHandler) HandleReceive(msg *channel.Message, ch ch log = log.WithField("routerId", self.router.Id) go func() { - if err := self.network.Routers.Delete(self.router.Id, self.newChangeContext(ch, "decommission.router")); err == nil { + if err := self.network.Router.Delete(self.router.Id, self.newChangeContext(ch, "decommission.router")); err == nil { // we don't send success because deleting the router will close the router connection log.Debug("router decommission successful") } else { diff --git a/controller/handler_ctrl/dequiesce_router.go b/controller/handler_ctrl/dequiesce_router.go index b599fbb97..14aa66cda 100644 --- a/controller/handler_ctrl/dequiesce_router.go +++ b/controller/handler_ctrl/dequiesce_router.go @@ -19,16 +19,17 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" ) type dequiesceRouterHandler struct { baseHandler } -func newDequiesceRouterHandler(router *network.Router, network *network.Network) *dequiesceRouterHandler { +func newDequiesceRouterHandler(router *model.Router, network *network.Network) *dequiesceRouterHandler { return &dequiesceRouterHandler{ baseHandler: baseHandler{ router: router, @@ -46,7 +47,7 @@ func (self *dequiesceRouterHandler) HandleReceive(msg *channel.Message, ch chann log = log.WithField("routerId", self.router.Id) go func() { - if err := self.network.Routers.DequiesceRouter(self.router, self.newChangeContext(ch, "dequiesce.router")); err == nil { + if err := self.network.Router.DequiesceRouter(self.router, self.newChangeContext(ch, "dequiesce.router")); err == nil { handler_common.SendSuccess(msg, ch, "router dequiesced") log.Debug("router dequiesce successful") } else { diff --git a/controller/handler_ctrl/fault.go b/controller/handler_ctrl/fault.go index b3f2a5b34..42e7523ab 100644 --- a/controller/handler_ctrl/fault.go +++ b/controller/handler_ctrl/fault.go @@ -23,6 +23,7 @@ import ( "github.com/openziti/channel/v2/protobufs" "github.com/openziti/ziti/common/pb/ctrl_pb" "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" @@ -30,11 +31,11 @@ import ( ) type faultHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newFaultHandler(r *network.Router, network *network.Network) *faultHandler { +func newFaultHandler(r *model.Router, network *network.Network) *faultHandler { return &faultHandler{r: r, network: network} } diff --git a/controller/handler_ctrl/link_connected.go b/controller/handler_ctrl/link_connected.go index 38793cb39..3945a19b7 100644 --- a/controller/handler_ctrl/link_connected.go +++ b/controller/handler_ctrl/link_connected.go @@ -19,17 +19,18 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) type linkConnectedHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newLinkConnectedHandler(r *network.Router, network *network.Network) *linkConnectedHandler { +func newLinkConnectedHandler(r *model.Router, network *network.Network) *linkConnectedHandler { return &linkConnectedHandler{r: r, network: network} } diff --git a/controller/handler_ctrl/quiesce_router.go b/controller/handler_ctrl/quiesce_router.go index f0b04bdec..e26895666 100644 --- a/controller/handler_ctrl/quiesce_router.go +++ b/controller/handler_ctrl/quiesce_router.go @@ -19,16 +19,17 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" ) type quiesceRouterHandler struct { baseHandler } -func newQuiesceRouterHandler(router *network.Router, network *network.Network) *quiesceRouterHandler { +func newQuiesceRouterHandler(router *model.Router, network *network.Network) *quiesceRouterHandler { return &quiesceRouterHandler{ baseHandler: baseHandler{ router: router, @@ -46,7 +47,7 @@ func (self *quiesceRouterHandler) HandleReceive(msg *channel.Message, ch channel log = log.WithField("routerId", self.router.Id) go func() { - if err := self.network.Routers.QuiesceRouter(self.router, self.newChangeContext(ch, "quiesce.router")); err == nil { + if err := self.network.Router.QuiesceRouter(self.router, self.newChangeContext(ch, "quiesce.router")); err == nil { handler_common.SendSuccess(msg, ch, "router quiesced") log.Debug("router quiesce successful") } else { diff --git a/controller/handler_ctrl/remove_terminator.go b/controller/handler_ctrl/remove_terminator.go index f0d539b3d..e5dd2c47a 100644 --- a/controller/handler_ctrl/remove_terminator.go +++ b/controller/handler_ctrl/remove_terminator.go @@ -19,9 +19,10 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) @@ -29,7 +30,7 @@ type removeTerminatorHandler struct { baseHandler } -func newRemoveTerminatorHandler(network *network.Network, router *network.Router) *removeTerminatorHandler { +func newRemoveTerminatorHandler(network *network.Network, router *model.Router) *removeTerminatorHandler { return &removeTerminatorHandler{ baseHandler: baseHandler{ router: router, @@ -57,13 +58,13 @@ func (self *removeTerminatorHandler) HandleReceive(msg *channel.Message, ch chan func (self *removeTerminatorHandler) handleRemoveTerminator(msg *channel.Message, ch channel.Channel, request *ctrl_pb.RemoveTerminatorRequest) { log := pfxlog.ContextLogger(ch.Label()) - terminator, err := self.network.Terminators.Read(request.TerminatorId) + terminator, err := self.network.Terminator.Read(request.TerminatorId) if err != nil { handler_common.SendFailure(msg, ch, err.Error()) return } - if err := self.network.Terminators.Delete(request.TerminatorId, self.newChangeContext(ch, "fabric.remove.terminator")); err == nil { + if err := self.network.Terminator.Delete(request.TerminatorId, self.newChangeContext(ch, "fabric.remove.terminator")); err == nil { log. WithField("routerId", ch.Id()). WithField("serviceId", terminator.Service). diff --git a/controller/handler_ctrl/remove_terminators.go b/controller/handler_ctrl/remove_terminators.go index 18689b3c4..4812c583c 100644 --- a/controller/handler_ctrl/remove_terminators.go +++ b/controller/handler_ctrl/remove_terminators.go @@ -22,6 +22,7 @@ import ( "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" "github.com/openziti/ziti/controller/command" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) @@ -30,7 +31,7 @@ type removeTerminatorsHandler struct { baseHandler } -func newRemoveTerminatorsHandler(network *network.Network, router *network.Router) *removeTerminatorsHandler { +func newRemoveTerminatorsHandler(network *network.Network, router *model.Router) *removeTerminatorsHandler { return &removeTerminatorsHandler{ baseHandler: baseHandler{ router: router, @@ -58,7 +59,7 @@ func (self *removeTerminatorsHandler) HandleReceive(msg *channel.Message, ch cha func (self *removeTerminatorsHandler) handleRemoveTerminators(msg *channel.Message, ch channel.Channel, request *ctrl_pb.RemoveTerminatorsRequest) { log := pfxlog.ContextLogger(ch.Label()) - if err := self.network.Terminators.DeleteBatch(request.TerminatorIds, self.newChangeContext(ch, "fabric.remove.terminators.batch")); err == nil { + if err := self.network.Terminator.DeleteBatch(request.TerminatorIds, self.newChangeContext(ch, "fabric.remove.terminators.batch")); err == nil { log. WithField("routerId", ch.Id()). WithField("terminatorIds", request.TerminatorIds). diff --git a/controller/handler_ctrl/route_result.go b/controller/handler_ctrl/route_result.go index 1f8d399a5..e02a426ce 100644 --- a/controller/handler_ctrl/route_result.go +++ b/controller/handler_ctrl/route_result.go @@ -20,20 +20,21 @@ import ( "bytes" "encoding/binary" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" - "github.com/openziti/ziti/controller/xt" "github.com/openziti/ziti/common/ctrl_msg" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/xt" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" ) type routeResultHandler struct { network *network.Network - r *network.Router + r *model.Router } -func newRouteResultHandler(network *network.Network, r *network.Router) *routeResultHandler { +func newRouteResultHandler(network *network.Network, r *model.Router) *routeResultHandler { return &routeResultHandler{ network: network, r: r, diff --git a/controller/handler_ctrl/router_link.go b/controller/handler_ctrl/router_link.go index 8df485e0d..ca93742ce 100644 --- a/controller/handler_ctrl/router_link.go +++ b/controller/handler_ctrl/router_link.go @@ -20,16 +20,17 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) type routerLinkHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newRouterLinkHandler(r *network.Router, network *network.Network) *routerLinkHandler { +func newRouterLinkHandler(r *model.Router, network *network.Network) *routerLinkHandler { return &routerLinkHandler{r: r, network: network} } diff --git a/controller/handler_ctrl/update_terminator.go b/controller/handler_ctrl/update_terminator.go index 38af85abf..5da283d1a 100644 --- a/controller/handler_ctrl/update_terminator.go +++ b/controller/handler_ctrl/update_terminator.go @@ -19,12 +19,13 @@ package handler_ctrl import ( "fmt" "github.com/openziti/channel/v2" + "github.com/openziti/ziti/common/handler_common" + "github.com/openziti/ziti/common/pb/ctrl_pb" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/controller/xt" - "github.com/openziti/ziti/common/handler_common" - "github.com/openziti/ziti/common/pb/ctrl_pb" log "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" "math" @@ -34,7 +35,7 @@ type updateTerminatorHandler struct { baseHandler } -func newUpdateTerminatorHandler(network *network.Network, router *network.Router) *updateTerminatorHandler { +func newUpdateTerminatorHandler(network *network.Network, router *model.Router) *updateTerminatorHandler { return &updateTerminatorHandler{ baseHandler: baseHandler{ router: router, @@ -58,7 +59,7 @@ func (self *updateTerminatorHandler) HandleReceive(msg *channel.Message, ch chan } func (self *updateTerminatorHandler) handleUpdateTerminator(msg *channel.Message, ch channel.Channel, request *ctrl_pb.UpdateTerminatorRequest) { - terminator, err := self.network.Terminators.Read(request.TerminatorId) + terminator, err := self.network.Terminator.Read(request.TerminatorId) if err != nil { handler_common.SendFailure(msg, ch, err.Error()) return @@ -97,7 +98,7 @@ func (self *updateTerminatorHandler) handleUpdateTerminator(msg *channel.Message checker[db.FieldTerminatorPrecedence] = struct{}{} } - if err := self.network.Terminators.Update(terminator, checker, self.newChangeContext(ch, "fabric.update.terminator")); err != nil { + if err := self.network.Terminator.Update(terminator, checker, self.newChangeContext(ch, "fabric.update.terminator")); err != nil { handler_common.SendFailure(msg, ch, err.Error()) return } diff --git a/controller/handler_ctrl/verify_router.go b/controller/handler_ctrl/verify_router.go index e7971216e..8d6afeea5 100644 --- a/controller/handler_ctrl/verify_router.go +++ b/controller/handler_ctrl/verify_router.go @@ -19,18 +19,19 @@ package handler_ctrl import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/common/handler_common" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) type verifyRouterHandler struct { - r *network.Router + r *model.Router network *network.Network } -func newVerifyRouterHandler(r *network.Router, network *network.Network) *verifyRouterHandler { +func newVerifyRouterHandler(r *model.Router, network *network.Network) *verifyRouterHandler { return &verifyRouterHandler{r: r, network: network} } diff --git a/controller/handler_edge_ctrl/common.go b/controller/handler_edge_ctrl/common.go index 8d9ba0344..a9173f026 100644 --- a/controller/handler_edge_ctrl/common.go +++ b/controller/handler_edge_ctrl/common.go @@ -110,10 +110,10 @@ type baseSessionRequestContext struct { handler requestHandler msg *channel.Message err controllerError - sourceRouter *network.Router + sourceRouter *model.Router session *model.Session apiSession *model.ApiSession - service *model.Service + service *model.EdgeService newSession bool logContext logcontext.Context env model.Env @@ -426,11 +426,11 @@ func (self *baseSessionRequestContext) loadService() { } } -func (self *baseSessionRequestContext) verifyTerminator(terminatorId string, binding string) *network.Terminator { +func (self *baseSessionRequestContext) verifyTerminator(terminatorId string, binding string) *model.Terminator { if self.err == nil { - var terminator *network.Terminator + var terminator *model.Terminator var err error - terminator, err = self.handler.getNetwork().Terminators.Read(terminatorId) + terminator, err = self.handler.getNetwork().Terminator.Read(terminatorId) if err != nil { if boltz.IsErrNotFoundErr(err) { @@ -498,7 +498,7 @@ func (self *baseSessionRequestContext) verifyTerminatorId(id string) { } } -func (self *baseSessionRequestContext) updateTerminator(terminator *network.Terminator, request UpdateTerminatorRequest, ctx *change.Context) { +func (self *baseSessionRequestContext) updateTerminator(terminator *model.Terminator, request UpdateTerminatorRequest, ctx *change.Context) { if self.err == nil { checker := fields.UpdatedFieldsMap{} @@ -526,11 +526,11 @@ func (self *baseSessionRequestContext) updateTerminator(terminator *network.Term checker[db.FieldTerminatorPrecedence] = struct{}{} } - self.err = internalError(self.handler.getNetwork().Terminators.Update(terminator, checker, ctx)) + self.err = internalError(self.handler.getNetwork().Terminator.Update(terminator, checker, ctx)) } } -func (self *baseSessionRequestContext) newCircuitCreateParms(serviceId string, peerData map[uint32][]byte) network.CreateCircuitParams { +func (self *baseSessionRequestContext) newCircuitCreateParms(serviceId string, peerData map[uint32][]byte) model.CreateCircuitParams { return &sessionCircuitParams{ serviceId: serviceId, sourceRouter: self.sourceRouter, @@ -541,7 +541,7 @@ func (self *baseSessionRequestContext) newCircuitCreateParms(serviceId string, p } } -func (self *baseSessionRequestContext) newTunnelCircuitCreateParms(serviceId string, peerData map[uint32][]byte) network.CreateCircuitParams { +func (self *baseSessionRequestContext) newTunnelCircuitCreateParms(serviceId string, peerData map[uint32][]byte) model.CreateCircuitParams { return &tunnelCircuitParams{ serviceId: serviceId, sourceRouter: self.sourceRouter, @@ -552,10 +552,10 @@ func (self *baseSessionRequestContext) newTunnelCircuitCreateParms(serviceId str } } -type circuitParamsFactory = func(serviceId string, peerData map[uint32][]byte) network.CreateCircuitParams +type circuitParamsFactory = func(serviceId string, peerData map[uint32][]byte) model.CreateCircuitParams -func (self *baseSessionRequestContext) createCircuit(terminatorInstanceId string, peerData map[uint32][]byte, paramsFactory circuitParamsFactory) (*network.Circuit, map[uint32][]byte) { - var circuit *network.Circuit +func (self *baseSessionRequestContext) createCircuit(terminatorInstanceId string, peerData map[uint32][]byte, paramsFactory circuitParamsFactory) (*model.Circuit, map[uint32][]byte) { + var circuit *model.Circuit returnPeerData := map[uint32][]byte{} if self.err == nil { @@ -606,7 +606,7 @@ func (self *baseSessionRequestContext) createCircuit(terminatorInstanceId string type sessionCircuitParams struct { serviceId string - sourceRouter *network.Router + sourceRouter *model.Router clientId *identity.TokenId logCtx logcontext.Context deadline time.Time @@ -617,7 +617,7 @@ func (self *sessionCircuitParams) GetServiceId() string { return self.serviceId } -func (self *sessionCircuitParams) GetSourceRouter() *network.Router { +func (self *sessionCircuitParams) GetSourceRouter() *model.Router { return self.sourceRouter } @@ -651,7 +651,7 @@ func (self *sessionCircuitParams) GetDeadline() time.Time { type tunnelCircuitParams struct { serviceId string - sourceRouter *network.Router + sourceRouter *model.Router clientId *identity.TokenId logCtx logcontext.Context deadline time.Time @@ -662,7 +662,7 @@ func (self *tunnelCircuitParams) GetServiceId() string { return self.serviceId } -func (self *tunnelCircuitParams) GetSourceRouter() *network.Router { +func (self *tunnelCircuitParams) GetSourceRouter() *model.Router { return self.sourceRouter } diff --git a/controller/handler_edge_ctrl/common_tunnel.go b/controller/handler_edge_ctrl/common_tunnel.go index d4bb4b0db..c69fabe31 100644 --- a/controller/handler_edge_ctrl/common_tunnel.go +++ b/controller/handler_edge_ctrl/common_tunnel.go @@ -152,7 +152,7 @@ func (self *baseTunnelRequestContext) ensureApiSessionLocking(configTypes []stri IPAddress: self.handler.getChannel().Underlay().GetRemoteAddr().String(), } - err := self.handler.getAppEnv().GetDbProvider().GetDb().Update(self.newTunnelChangeContext().NewMutateContext(), func(ctx boltz.MutateContext) error { + err := self.handler.getAppEnv().GetDb().Update(self.newTunnelChangeContext().NewMutateContext(), func(ctx boltz.MutateContext) error { var err error apiSession.Id, err = self.handler.getAppEnv().GetManagers().ApiSession.Create(ctx, apiSession, nil) if err != nil { diff --git a/controller/handler_edge_ctrl/create_circuit.go b/controller/handler_edge_ctrl/create_circuit.go index 9e6161b29..c02498bd0 100644 --- a/controller/handler_edge_ctrl/create_circuit.go +++ b/controller/handler_edge_ctrl/create_circuit.go @@ -24,7 +24,7 @@ import ( "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/env" - "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/model" "google.golang.org/protobuf/proto" ) @@ -77,7 +77,7 @@ func (self *createCircuitHandler) HandleReceiveCreateCircuitV1(msg *channel.Mess self.CreateCircuit(ctx, self.CreateCircuitV1Response) } -func (self *createCircuitHandler) CreateCircuitV1Response(circuitInfo *network.Circuit, peerData map[uint32][]byte) (*channel.Message, error) { +func (self *createCircuitHandler) CreateCircuitV1Response(circuitInfo *model.Circuit, peerData map[uint32][]byte) (*channel.Message, error) { response := &edge_ctrl_pb.CreateCircuitResponse{ CircuitId: circuitInfo.Id, Address: circuitInfo.Path.IngressId, @@ -108,7 +108,7 @@ func (self *createCircuitHandler) HandleReceiveCreateCircuitV2(msg *channel.Mess self.CreateCircuit(ctx, self.CreateCircuitV2Response) } -func (self *createCircuitHandler) CreateCircuitV2Response(circuitInfo *network.Circuit, peerData map[uint32][]byte) (*channel.Message, error) { +func (self *createCircuitHandler) CreateCircuitV2Response(circuitInfo *model.Circuit, peerData map[uint32][]byte) (*channel.Message, error) { response := &ctrl_msg.CreateCircuitResponse{ CircuitId: circuitInfo.Id, Address: circuitInfo.Path.IngressId, @@ -148,7 +148,7 @@ func (self *createCircuitHandler) CreateCircuit(ctx *CreateCircuitRequestContext } } -type createCircuitResponseFactory func(*network.Circuit, map[uint32][]byte) (*channel.Message, error) +type createCircuitResponseFactory func(*model.Circuit, map[uint32][]byte) (*channel.Message, error) var _ CreateCircuitRequest = (*edge_ctrl_pb.CreateCircuitRequest)(nil) diff --git a/controller/handler_edge_ctrl/create_terminator.go b/controller/handler_edge_ctrl/create_terminator.go index 0b4d8e0ac..9ddc7bc62 100644 --- a/controller/handler_edge_ctrl/create_terminator.go +++ b/controller/handler_edge_ctrl/create_terminator.go @@ -27,7 +27,6 @@ import ( "github.com/openziti/ziti/controller/idgen" "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" "math" ) @@ -94,7 +93,7 @@ func (self *createTerminatorHandler) CreateTerminator(ctx *CreateTerminatorReque id := idgen.NewUUIDString() - terminator := &network.Terminator{ + terminator := &model.Terminator{ BaseEntity: models.BaseEntity{ Id: id, IsSystem: true, @@ -116,7 +115,7 @@ func (self *createTerminatorHandler) CreateTerminator(ctx *CreateTerminatorReque Entity: terminator, Context: ctx.newChangeContext(), } - if err := self.appEnv.GetHostController().GetNetwork().Managers.Command.Dispatch(cmd); err != nil { + if err := self.appEnv.GetHostController().GetNetwork().Managers.Dispatcher.Dispatch(cmd); err != nil { self.returnError(ctx, internalError(err)) return } diff --git a/controller/handler_edge_ctrl/create_terminator_v2.go b/controller/handler_edge_ctrl/create_terminator_v2.go index 16c78ee2f..285808dad 100644 --- a/controller/handler_edge_ctrl/create_terminator_v2.go +++ b/controller/handler_edge_ctrl/create_terminator_v2.go @@ -29,7 +29,6 @@ import ( "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" @@ -107,7 +106,7 @@ func (self *createTerminatorV2Handler) CreateTerminatorV2(ctx *CreateTerminatorV return } - terminator, _ := self.getNetwork().Terminators.Read(ctx.req.Address) + terminator, _ := self.getNetwork().Terminator.Read(ctx.req.Address) if terminator != nil { if ctx.err = ctx.validateExistingTerminator(terminator, logger); ctx.err != nil { self.returnError(ctx, edge_ctrl_pb.CreateTerminatorResult_FailedIdConflict, ctx.err, logger) @@ -118,7 +117,7 @@ func (self *createTerminatorV2Handler) CreateTerminatorV2(ctx *CreateTerminatorV if terminator.Precedence != ctx.req.GetXtPrecedence() || terminator.Cost != uint16(ctx.req.Cost) { terminator.Precedence = ctx.req.GetXtPrecedence() terminator.Cost = uint16(ctx.req.Cost) - err := self.appEnv.GetHostController().GetNetwork().Terminators.Update(terminator, fields.UpdatedFieldsMap{ + err := self.appEnv.GetManagers().Terminator.Update(terminator, fields.UpdatedFieldsMap{ db.FieldTerminatorPrecedence: struct{}{}, db.FieldTerminatorCost: struct{}{}, }, ctx.newChangeContext()) @@ -129,7 +128,7 @@ func (self *createTerminatorV2Handler) CreateTerminatorV2(ctx *CreateTerminatorV } } } else { - terminator = &network.Terminator{ + terminator = &model.Terminator{ BaseEntity: models.BaseEntity{ Id: ctx.req.Address, IsSystem: true, @@ -153,9 +152,9 @@ func (self *createTerminatorV2Handler) CreateTerminatorV2(ctx *CreateTerminatorV } createStart := time.Now() - if err := self.appEnv.GetHostController().GetNetwork().Managers.Command.Dispatch(cmd); err != nil { + if err := self.appEnv.GetHostController().GetNetwork().Managers.Dispatcher.Dispatch(cmd); err != nil { // terminator might have been created while we were trying to create. - if terminator, _ = self.getNetwork().Terminators.Read(ctx.req.Address); terminator != nil { + if terminator, _ = self.getNetwork().Terminator.Read(ctx.req.Address); terminator != nil { if validateError := ctx.validateExistingTerminator(terminator, logger); validateError != nil { self.returnError(ctx, edge_ctrl_pb.CreateTerminatorResult_FailedIdConflict, validateError, logger) return @@ -218,7 +217,7 @@ func (self *CreateTerminatorV2RequestContext) GetSessionToken() string { return self.req.SessionToken } -func (self *CreateTerminatorV2RequestContext) validateExistingTerminator(terminator *network.Terminator, log *logrus.Entry) controllerError { +func (self *CreateTerminatorV2RequestContext) validateExistingTerminator(terminator *model.Terminator, log *logrus.Entry) controllerError { if terminator.Binding != common.EdgeBinding { log.WithField("binding", common.EdgeBinding). WithField("conflictingBinding", terminator.Binding). diff --git a/controller/handler_edge_ctrl/create_tunnel_terminator.go b/controller/handler_edge_ctrl/create_tunnel_terminator.go index 4cdd28103..90ba711ab 100644 --- a/controller/handler_edge_ctrl/create_tunnel_terminator.go +++ b/controller/handler_edge_ctrl/create_tunnel_terminator.go @@ -24,8 +24,8 @@ import ( "github.com/openziti/ziti/common/pb/edge_ctrl_pb" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/env" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" @@ -103,14 +103,14 @@ func (self *createTunnelTerminatorHandler) CreateTerminator(ctx *CreateTunnelTer return } - terminator, _ := self.getNetwork().Terminators.Read(ctx.req.Address) + terminator, _ := self.getNetwork().Terminator.Read(ctx.req.Address) if terminator != nil { if err := ctx.validateExistingTerminator(terminator, logger); err != nil { self.returnError(ctx, err) return } } else { - terminator = &network.Terminator{ + terminator = &model.Terminator{ BaseEntity: models.BaseEntity{ Id: ctx.req.Address, IsSystem: true, @@ -127,10 +127,9 @@ func (self *createTunnelTerminatorHandler) CreateTerminator(ctx *CreateTunnelTer HostId: ctx.session.IdentityId, } - n := self.appEnv.GetHostController().GetNetwork() - if err := n.Terminators.Create(terminator, ctx.newTunnelChangeContext()); err != nil { + if err := self.appEnv.Managers.Terminator.Create(terminator, ctx.newTunnelChangeContext()); err != nil { // terminator might have been created while we were trying to create. - if terminator, _ = self.getNetwork().Terminators.Read(ctx.req.Address); terminator != nil { + if terminator, _ = self.getNetwork().Terminator.Read(ctx.req.Address); terminator != nil { if validateError := ctx.validateExistingTerminator(terminator, logger); validateError != nil { self.returnError(ctx, validateError) return @@ -179,7 +178,7 @@ type CreateTunnelTerminatorRequestContext struct { req *edge_ctrl_pb.CreateTunnelTerminatorRequest } -func (self *CreateTunnelTerminatorRequestContext) validateExistingTerminator(terminator *network.Terminator, log *logrus.Entry) controllerError { +func (self *CreateTunnelTerminatorRequestContext) validateExistingTerminator(terminator *model.Terminator, log *logrus.Entry) controllerError { if terminator.Binding != common.TunnelBinding { log.WithField("binding", common.TunnelBinding). WithField("conflictingBinding", terminator.Binding). diff --git a/controller/handler_edge_ctrl/remove_terminator.go b/controller/handler_edge_ctrl/remove_terminator.go index 1b09c8e18..9679e76bd 100644 --- a/controller/handler_edge_ctrl/remove_terminator.go +++ b/controller/handler_edge_ctrl/remove_terminator.go @@ -88,7 +88,7 @@ func (self *removeTerminatorHandler) RemoveTerminator(ctx *RemoveTerminatorReque // to check the session again here. The session may already be deleted, and if it is, we don't // currently have a way to verify that it's associated. Also, with idempotent terminators, a // terminator may belong to a series of sessions. - err := self.getNetwork().Terminators.Delete(ctx.req.TerminatorId, ctx.newChangeContext()) + err := self.getNetwork().Terminator.Delete(ctx.req.TerminatorId, ctx.newChangeContext()) if err != nil { self.returnError(ctx, internalError(err)) return diff --git a/controller/handler_edge_ctrl/remove_tunnel_terminator.go b/controller/handler_edge_ctrl/remove_tunnel_terminator.go index 4a8b273b2..76661b80f 100644 --- a/controller/handler_edge_ctrl/remove_tunnel_terminator.go +++ b/controller/handler_edge_ctrl/remove_tunnel_terminator.go @@ -69,7 +69,7 @@ func (self *removeTunnelTerminatorHandler) RemoveTerminator(ctx *RemoveTunnelTer logger = logger.WithField("serviceId", t.Service) - err := self.getNetwork().Terminators.Delete(ctx.terminatorId, ctx.newTunnelChangeContext()) + err := self.getNetwork().Terminator.Delete(ctx.terminatorId, ctx.newTunnelChangeContext()) if err != nil { self.returnError(ctx, internalError(err)) return diff --git a/controller/handler_edge_ctrl/validate_sessions.go b/controller/handler_edge_ctrl/validate_sessions.go index 8327a5046..31581c3cf 100644 --- a/controller/handler_edge_ctrl/validate_sessions.go +++ b/controller/handler_edge_ctrl/validate_sessions.go @@ -62,7 +62,7 @@ func (self *validateSessionsHandler) validateSessions(req *edge_ctrl_pb.Validate var invalidTokens []string - err := self.getAppEnv().GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := self.getAppEnv().GetDb().View(func(tx *bbolt.Tx) error { for _, token := range req.SessionTokens { if tokenIndex.Read(tx, []byte(token)) == nil { invalidTokens = append(invalidTokens, token) diff --git a/controller/handler_mgmt/inspect.go b/controller/handler_mgmt/inspect.go index c0bcd6c29..7c1e1ecdb 100644 --- a/controller/handler_mgmt/inspect.go +++ b/controller/handler_mgmt/inspect.go @@ -44,7 +44,7 @@ func (handler *inspectHandler) HandleReceive(msg *channel.Message, ch channel.Ch response.Success = false response.Errors = append(response.Errors, fmt.Sprintf("%v: %v", handler.network.GetAppId(), err)) } else { - result := handler.network.Managers.Inspections.Inspect(request.AppRegex, request.RequestedValues) + result := handler.network.Inspections.Inspect(request.AppRegex, request.RequestedValues) response.Success = result.Success response.Errors = result.Errors for _, val := range result.Results { diff --git a/controller/handler_mgmt/stream_toggle_pipe_traces.go b/controller/handler_mgmt/stream_toggle_pipe_traces.go index d8dda7799..88597b50e 100644 --- a/controller/handler_mgmt/stream_toggle_pipe_traces.go +++ b/controller/handler_mgmt/stream_toggle_pipe_traces.go @@ -18,6 +18,7 @@ package handler_mgmt import ( "fmt" + "github.com/openziti/ziti/controller/model" "sync" "time" @@ -141,7 +142,7 @@ func getApplyResults(resultChan chan trace.ToggleApplyResult, verbosity trace.To } } -func handleResponse(router *network.Router, mgmtReq *channel.Message, msgsCh chan<- *remoteToggleResult, waitGroup *sync.WaitGroup) { +func handleResponse(router *model.Router, mgmtReq *channel.Message, msgsCh chan<- *remoteToggleResult, waitGroup *sync.WaitGroup) { defer waitGroup.Done() msg := channel.NewMessage(int32(ctrl_pb.ContentType_TogglePipeTracesRequestType), mgmtReq.Body) diff --git a/controller/handler_mgmt/validate_terminators.go b/controller/handler_mgmt/validate_terminators.go index 07023c320..4f2ec6702 100644 --- a/controller/handler_mgmt/validate_terminators.go +++ b/controller/handler_mgmt/validate_terminators.go @@ -47,7 +47,7 @@ func (handler *validateTerminatorsHandler) HandleReceive(msg *channel.Message, c var terminatorCount uint64 if err = proto.Unmarshal(msg.Body, request); err == nil { - terminatorCount, err = handler.network.Managers.Terminators.ValidateTerminators(request.Filter, request.FixInvalid, func(detail *mgmt_pb.TerminatorDetail) { + terminatorCount, err = handler.network.Managers.Terminator.ValidateTerminators(request.Filter, request.FixInvalid, func(detail *mgmt_pb.TerminatorDetail) { if !ch.IsClosed() { if sendErr := protobufs.MarshalTyped(detail).WithTimeout(15 * time.Second).SendAndWaitForWire(ch); sendErr != nil { log.WithError(sendErr).Error("send of terminator detail failed, closing channel") diff --git a/controller/internal/policy/service_policy_enforcer.go b/controller/internal/policy/service_policy_enforcer.go index 0a624f016..8afc5a9bb 100644 --- a/controller/internal/policy/service_policy_enforcer.go +++ b/controller/internal/policy/service_policy_enforcer.go @@ -77,7 +77,7 @@ func (enforcer *ServicePolicyEnforcer) handleServiceEvent(event *db.ServiceEvent var sessionsToDelete []string - err := enforcer.appEnv.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := enforcer.appEnv.GetDb().View(func(tx *bbolt.Tx) error { identity, err := enforcer.appEnv.GetStores().Identity.LoadById(tx, event.IdentityId) if err != nil { return err @@ -132,7 +132,7 @@ func (enforcer *ServicePolicyEnforcer) Run() error { } var sessionsToRemove []string - err = enforcer.appEnv.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err = enforcer.appEnv.GetDb().View(func(tx *bbolt.Tx) error { for _, session := range result.Sessions { apiSession, err := enforcer.appEnv.GetStores().ApiSession.LoadById(tx, session.ApiSessionId) if err != nil { diff --git a/controller/internal/routes/authenticate_router.go b/controller/internal/routes/authenticate_router.go index fa0fe0e4d..5efefb12e 100644 --- a/controller/internal/routes/authenticate_router.go +++ b/controller/internal/routes/authenticate_router.go @@ -222,7 +222,7 @@ func (ro *AuthRouter) authHandler(ae *env.AppEnv, rc *response.RequestContext, h env.ProcessAuthQueries(ae, rc) - apiSession := MapToCurrentApiSessionRestModel(ae, rc, ae.Config.SessionTimeoutDuration()) + apiSession := MapToCurrentApiSessionRestModel(ae, rc, ae.GetConfig().Edge.SessionTimeoutDuration()) //re-calc session headers as they were not set when ApiSession == NIL response.AddSessionHeaders(rc) diff --git a/controller/internal/routes/base_router.go b/controller/internal/routes/base_router.go index 271f75600..288285928 100644 --- a/controller/internal/routes/base_router.go +++ b/controller/internal/routes/base_router.go @@ -20,19 +20,18 @@ import ( "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_model" - edgeApiError "github.com/openziti/ziti/controller/apierror" - "github.com/openziti/ziti/controller/env" - "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/response" + "github.com/openziti/foundation/v2/errorz" + "github.com/openziti/storage/ast" + "github.com/openziti/storage/boltz" "github.com/openziti/ziti/controller/api" "github.com/openziti/ziti/controller/apierror" + edgeApiError "github.com/openziti/ziti/controller/apierror" "github.com/openziti/ziti/controller/change" + "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" - "github.com/openziti/foundation/v2/errorz" - "github.com/openziti/storage/ast" - "github.com/openziti/storage/boltz" + "github.com/openziti/ziti/controller/response" "github.com/pkg/errors" "go.etcd.io/bbolt" "net/http" @@ -490,9 +489,9 @@ func ListAssociationWithHandler[E models.Entity, A models.Entity](ae *env.AppEnv } func ListTerminatorAssociations(ae *env.AppEnv, rc *response.RequestContext, - lister models.EntityRetriever[*model.Service], - associationLoader *network.TerminatorManager, - mapper func(ae *env.AppEnv, _ *response.RequestContext, terminator *network.Terminator) (interface{}, error)) { + lister models.EntityRetriever[*model.EdgeService], + associationLoader *model.TerminatorManager, + mapper func(ae *env.AppEnv, _ *response.RequestContext, terminator *model.Terminator) (interface{}, error)) { ListAssociations(rc, func(rc *response.RequestContext, id string, queryOptions *PublicQueryOptions) (*QueryResult, error) { // validate that the submitted query is only using public symbols. The query options may contain an final // query which has been modified with additional filters @@ -501,7 +500,7 @@ func ListTerminatorAssociations(ae *env.AppEnv, rc *response.RequestContext, return nil, err } - result := models.EntityListResult[*network.Terminator]{ + result := models.EntityListResult[*model.Terminator]{ Loader: associationLoader, } err = lister.PreparedListAssociatedWithHandler(id, associationLoader.GetStore().GetEntityType(), query, result.Collect) diff --git a/controller/internal/routes/ca_router.go b/controller/internal/routes/ca_router.go index 59decd23f..bb50f20a2 100644 --- a/controller/internal/routes/ca_router.go +++ b/controller/internal/routes/ca_router.go @@ -266,7 +266,7 @@ func (r *CaRouter) generateJwt(ae *env.AppEnv, rc *response.RequestContext) { claims := &ziti.EnrollmentClaims{ EnrollmentMethod: method, RegisteredClaims: jwt.RegisteredClaims{ - Issuer: fmt.Sprintf(`https://%s/`, ae.Config.Api.Address), + Issuer: fmt.Sprintf(`https://%s/`, ae.GetConfig().Edge.Api.Address), Subject: ca.Id, ID: ca.Id, }, diff --git a/controller/internal/routes/current_api_session_router.go b/controller/internal/routes/current_api_session_router.go index 94eab91de..7b49675d0 100644 --- a/controller/internal/routes/current_api_session_router.go +++ b/controller/internal/routes/current_api_session_router.go @@ -85,7 +85,7 @@ func (router *CurrentSessionRouter) Register(ae *env.AppEnv) { } func (router *CurrentSessionRouter) Detail(ae *env.AppEnv, rc *response.RequestContext) { - apiSession := MapToCurrentApiSessionRestModel(ae, rc, ae.Config.SessionTimeoutDuration()) + apiSession := MapToCurrentApiSessionRestModel(ae, rc, ae.GetConfig().Edge.SessionTimeoutDuration()) rc.Respond(rest_model.CurrentAPISessionDetailEnvelope{Data: apiSession, Meta: &rest_model.Meta{}}, http.StatusOK) } @@ -205,7 +205,7 @@ func (nsr *ApiSessionCertificateCreateResponder) RespondWithCreatedId(id string, ID: sessionCert.Id, }, Certificate: &certString, - Cas: string(nsr.ae.Config.CaPems()), + Cas: string(nsr.ae.GetConfig().Edge.CaPems()), }, Meta: &rest_model.Meta{}, } diff --git a/controller/internal/routes/current_identity_authenticator_router.go b/controller/internal/routes/current_identity_authenticator_router.go index 014e0a4f6..b56f88e43 100644 --- a/controller/internal/routes/current_identity_authenticator_router.go +++ b/controller/internal/routes/current_identity_authenticator_router.go @@ -23,12 +23,12 @@ import ( clientCurrentApiSession "github.com/openziti/edge-api/rest_client_api_server/operations/current_api_session" managementCurrentApiSession "github.com/openziti/edge-api/rest_management_api_server/operations/current_api_session" "github.com/openziti/edge-api/rest_model" + "github.com/openziti/foundation/v2/errorz" + "github.com/openziti/storage/boltz" "github.com/openziti/ziti/controller/env" + "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/internal/permissions" "github.com/openziti/ziti/controller/response" - "github.com/openziti/ziti/controller/fields" - "github.com/openziti/foundation/v2/errorz" - "github.com/openziti/storage/boltz" ) func init() { @@ -214,7 +214,7 @@ func (r *CurrentIdentityAuthenticatorRouter) Extend(ae *env.AppEnv, rc *response } rc.RespondWithOk(&rest_model.IdentityExtendCerts{ - Ca: string(ae.Config.CaPems()), + Ca: string(ae.GetConfig().Edge.CaPems()), ClientCert: string(certPem), }, &rest_model.Meta{}) } diff --git a/controller/internal/routes/database_router.go b/controller/internal/routes/database_router.go index 7b8032828..b737d7d9b 100644 --- a/controller/internal/routes/database_router.go +++ b/controller/internal/routes/database_router.go @@ -26,8 +26,8 @@ import ( "github.com/openziti/ziti/controller/apierror" "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/internal/permissions" - "github.com/openziti/ziti/controller/response" "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/response" "net/http" "sync" "sync/atomic" @@ -176,5 +176,5 @@ func (r *DatabaseRouter) runDataIntegrityCheck(ae *env.AppEnv, rc *response.Requ } } - r.integrityCheck.err = ae.GetDbProvider().GetStores().CheckIntegrity(ae.GetDbProvider().GetDb(), rc.NewChangeContext().GetContext(), fixErrors, errorHandler) + r.integrityCheck.err = ae.GetStores().CheckIntegrity(ae.GetDb(), rc.NewChangeContext().GetContext(), fixErrors, errorHandler) } diff --git a/controller/internal/routes/enroll_router.go b/controller/internal/routes/enroll_router.go index 20e522305..27d0f6852 100644 --- a/controller/internal/routes/enroll_router.go +++ b/controller/internal/routes/enroll_router.go @@ -99,7 +99,7 @@ func (ro *EnrollRouter) getCaCerts(ae *env.AppEnv, rc *response.RequestContext) // Decode each PEM block in the input and append the ASN.1 // DER bytes for each certificate therein to the data slice. - input := ae.Config.CaPems() + input := ae.GetConfig().Edge.CaPems() var data []byte for len(input) > 0 { diff --git a/controller/internal/routes/protocol_router.go b/controller/internal/routes/protocol_router.go index d967ba65d..3b44d9d3e 100644 --- a/controller/internal/routes/protocol_router.go +++ b/controller/internal/routes/protocol_router.go @@ -49,7 +49,7 @@ func (router *ProtocolRouter) Register(ae *env.AppEnv) { func (router *ProtocolRouter) List(ae *env.AppEnv, rc *response.RequestContext) { data := rest_model.ListProtocols{ "https": rest_model.Protocol{ - Address: &ae.Config.Api.Address, + Address: &ae.GetConfig().Edge.Api.Address, }, } rc.RespondWithOk(data, &rest_model.Meta{}) diff --git a/controller/internal/routes/router_api_model.go b/controller/internal/routes/router_api_model.go index 30e92ddff..f0c41b701 100644 --- a/controller/internal/routes/router_api_model.go +++ b/controller/internal/routes/router_api_model.go @@ -5,11 +5,11 @@ import ( "github.com/go-openapi/strfmt" "github.com/openziti/edge-api/rest_model" + "github.com/openziti/foundation/v2/stringz" "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/response" "github.com/openziti/ziti/controller/models" - "github.com/openziti/foundation/v2/stringz" + "github.com/openziti/ziti/controller/response" ) const EntityNameTransitRouter = "transit-routers" diff --git a/controller/internal/routes/service_api_model.go b/controller/internal/routes/service_api_model.go index de20a82d1..427163f08 100644 --- a/controller/internal/routes/service_api_model.go +++ b/controller/internal/routes/service_api_model.go @@ -54,8 +54,8 @@ func (factory *ServiceLinkFactoryIml) Links(entity models.Entity) rest_model.Lin return links } -func MapCreateServiceToModel(service *rest_model.ServiceCreate) *model.Service { - ret := &model.Service{ +func MapCreateServiceToModel(service *rest_model.ServiceCreate) *model.EdgeService { + ret := &model.EdgeService{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), }, @@ -70,8 +70,8 @@ func MapCreateServiceToModel(service *rest_model.ServiceCreate) *model.Service { return ret } -func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *model.Service { - ret := &model.Service{ +func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *model.EdgeService { + ret := &model.EdgeService{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), Id: id, @@ -87,8 +87,8 @@ func MapUpdateServiceToModel(id string, service *rest_model.ServiceUpdate) *mode return ret } -func MapPatchServiceToModel(id string, service *rest_model.ServicePatch) *model.Service { - ret := &model.Service{ +func MapPatchServiceToModel(id string, service *rest_model.ServicePatch) *model.EdgeService { + ret := &model.EdgeService{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(service.Tags), Id: id, diff --git a/controller/internal/routes/service_router.go b/controller/internal/routes/service_router.go index abbcd84f8..2f3566b6b 100644 --- a/controller/internal/routes/service_router.go +++ b/controller/internal/routes/service_router.go @@ -264,15 +264,15 @@ func (r *ServiceRouter) Patch(ae *env.AppEnv, rc *response.RequestContext, param } func (r *ServiceRouter) listServiceEdgeRouterPolicies(ae *env.AppEnv, rc *response.RequestContext) { - ListAssociationWithHandler[*model.Service, *model.ServiceEdgeRouterPolicy](ae, rc, ae.Managers.EdgeService, ae.Managers.ServiceEdgeRouterPolicy, MapServiceEdgeRouterPolicyToRestEntity) + ListAssociationWithHandler[*model.EdgeService, *model.ServiceEdgeRouterPolicy](ae, rc, ae.Managers.EdgeService, ae.Managers.ServiceEdgeRouterPolicy, MapServiceEdgeRouterPolicyToRestEntity) } func (r *ServiceRouter) listServicePolicies(ae *env.AppEnv, rc *response.RequestContext) { - ListAssociationWithHandler[*model.Service, *model.ServicePolicy](ae, rc, ae.Managers.EdgeService, ae.Managers.ServicePolicy, MapServicePolicyToRestEntity) + ListAssociationWithHandler[*model.EdgeService, *model.ServicePolicy](ae, rc, ae.Managers.EdgeService, ae.Managers.ServicePolicy, MapServicePolicyToRestEntity) } func (r *ServiceRouter) listConfigs(ae *env.AppEnv, rc *response.RequestContext) { - ListAssociationWithHandler[*model.Service, *model.Config](ae, rc, ae.Managers.EdgeService, ae.Managers.Config, MapConfigToRestEntity) + ListAssociationWithHandler[*model.EdgeService, *model.Config](ae, rc, ae.Managers.EdgeService, ae.Managers.Config, MapConfigToRestEntity) } func (r *ServiceRouter) listManagementTerminators(ae *env.AppEnv, rc *response.RequestContext) { diff --git a/controller/internal/routes/summary_router.go b/controller/internal/routes/summary_router.go index a2e991b7a..0f2b36dda 100644 --- a/controller/internal/routes/summary_router.go +++ b/controller/internal/routes/summary_router.go @@ -48,7 +48,7 @@ func (r *SummaryRouter) Register(ae *env.AppEnv) { } func (r *SummaryRouter) List(ae *env.AppEnv, rc *response.RequestContext) { - data, err := ae.GetStores().GetEntityCounts(ae.GetDbProvider().GetDb()) + data, err := ae.GetStores().GetEntityCounts(ae.GetDb()) if err != nil { rc.RespondWithError(err) } else { diff --git a/controller/internal/routes/terminator_api_model.go b/controller/internal/routes/terminator_api_model.go index bc9d0bab3..1390d3ae8 100644 --- a/controller/internal/routes/terminator_api_model.go +++ b/controller/internal/routes/terminator_api_model.go @@ -20,21 +20,22 @@ import ( "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_model" - "github.com/openziti/ziti/controller/env" - "github.com/openziti/ziti/controller/response" + "github.com/openziti/foundation/v2/stringz" "github.com/openziti/ziti/controller/api" + "github.com/openziti/ziti/controller/env" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/response" "github.com/openziti/ziti/controller/xt" - "github.com/openziti/foundation/v2/stringz" ) const EntityNameTerminator = "terminators" var TerminatorLinkFactory = NewBasicLinkFactory(EntityNameTerminator) -func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *network.Terminator { - ret := &network.Terminator{ +func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), }, @@ -54,8 +55,8 @@ func MapCreateTerminatorToModel(terminator *rest_model.TerminatorCreate) *networ return ret } -func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpdate) *network.Terminator { - ret := &network.Terminator{ +func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpdate) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), Id: id, @@ -74,8 +75,8 @@ func MapUpdateTerminatorToModel(id string, terminator *rest_model.TerminatorUpda return ret } -func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch) *network.Terminator { - ret := &network.Terminator{ +func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch) *model.Terminator { + ret := &model.Terminator{ BaseEntity: models.BaseEntity{ Tags: TagsOrDefault(terminator.Tags), Id: id, @@ -96,7 +97,7 @@ func MapPatchTerminatorToModel(id string, terminator *rest_model.TerminatorPatch type TerminatorModelMapper struct{} -func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, terminator *network.Terminator) (interface{}, error) { +func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, terminator *model.Terminator) (interface{}, error) { restModel, err := MapTerminatorToRestModel(n, terminator) if err != nil { @@ -108,18 +109,18 @@ func (TerminatorModelMapper) ToApi(n *network.Network, _ api.RequestContext, ter return restModel, nil } -func MapTerminatorToRestEntity(ae *env.AppEnv, _ *response.RequestContext, terminator *network.Terminator) (interface{}, error) { +func MapTerminatorToRestEntity(ae *env.AppEnv, _ *response.RequestContext, terminator *model.Terminator) (interface{}, error) { return MapTerminatorToRestModel(ae.GetHostController().GetNetwork(), terminator) } -func MapTerminatorToRestModel(n *network.Network, terminator *network.Terminator) (*rest_model.TerminatorDetail, error) { +func MapTerminatorToRestModel(n *network.Network, terminator *model.Terminator) (*rest_model.TerminatorDetail, error) { - service, err := n.Managers.Services.Read(terminator.Service) + service, err := n.Managers.Service.Read(terminator.Service) if err != nil { return nil, err } - router, err := n.Managers.Routers.Read(terminator.Router) + router, err := n.Managers.Router.Read(terminator.Router) if err != nil { return nil, err } @@ -155,11 +156,11 @@ func MapTerminatorToRestModel(n *network.Network, terminator *network.Terminator return ret, nil } -func MapClientTerminatorToRestEntity(ae *env.AppEnv, _ *response.RequestContext, terminator *network.Terminator) (interface{}, error) { +func MapClientTerminatorToRestEntity(ae *env.AppEnv, _ *response.RequestContext, terminator *model.Terminator) (interface{}, error) { return MapLimitedTerminatorToRestModel(ae, terminator) } -func MapLimitedTerminatorToRestModel(ae *env.AppEnv, terminator *network.Terminator) (*rest_model.TerminatorClientDetail, error) { +func MapLimitedTerminatorToRestModel(ae *env.AppEnv, terminator *model.Terminator) (*rest_model.TerminatorClientDetail, error) { service, err := ae.Managers.EdgeService.Read(terminator.Service) if err != nil { return nil, err diff --git a/controller/internal/routes/terminator_router.go b/controller/internal/routes/terminator_router.go index c0ec80e6a..ad9925903 100644 --- a/controller/internal/routes/terminator_router.go +++ b/controller/internal/routes/terminator_router.go @@ -19,12 +19,12 @@ package routes import ( "github.com/go-openapi/runtime/middleware" "github.com/openziti/edge-api/rest_management_api_server/operations/terminator" + "github.com/openziti/ziti/controller/api_impl" "github.com/openziti/ziti/controller/env" + "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/internal/permissions" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/response" - "github.com/openziti/ziti/controller/api_impl" - "github.com/openziti/ziti/controller/fields" - "github.com/openziti/ziti/controller/network" ) func init() { @@ -69,11 +69,11 @@ func (r *TerminatorRouter) Register(ae *env.AppEnv) { } func (r *TerminatorRouter) List(ae *env.AppEnv, rc *response.RequestContext) { - api_impl.ListWithHandler[*network.Terminator](ae.GetHostController().GetNetwork(), rc, ae.Managers.Terminator, TerminatorModelMapper{}) + api_impl.ListWithHandler[*model.Terminator](ae.GetHostController().GetNetwork(), rc, ae.Managers.Terminator, TerminatorModelMapper{}) } func (r *TerminatorRouter) Detail(ae *env.AppEnv, rc *response.RequestContext) { - api_impl.DetailWithHandler[*network.Terminator](ae.GetHostController().GetNetwork(), rc, ae.Managers.Terminator, TerminatorModelMapper{}) + api_impl.DetailWithHandler[*model.Terminator](ae.GetHostController().GetNetwork(), rc, ae.Managers.Terminator, TerminatorModelMapper{}) } func (r *TerminatorRouter) Create(ae *env.AppEnv, rc *response.RequestContext, params terminator.CreateTerminatorParams) { diff --git a/controller/model/api_session_certificate_manager.go b/controller/model/api_session_certificate_manager.go index 8f125df80..2875ccfbe 100644 --- a/controller/model/api_session_certificate_manager.go +++ b/controller/model/api_session_certificate_manager.go @@ -89,8 +89,6 @@ func (self *ApiSessionCertificateManager) CreateFromCSR(apiSessionId string, lif PEM: string(certPem), } - self.env.GetHostController().GetConfig() - return self.Create(entity, ctx) } diff --git a/controller/model/api_session_manager.go b/controller/model/api_session_manager.go index 642066bb0..8c59bc4d8 100644 --- a/controller/model/api_session_manager.go +++ b/controller/model/api_session_manager.go @@ -36,7 +36,11 @@ func NewApiSessionManager(env Env) *ApiSessionManager { baseEntityManager: newBaseEntityManager[*ApiSession, *db.ApiSession](env, env.GetStores().ApiSession), } - manager.HeartbeatCollector = NewHeartbeatCollector(env, env.GetConfig().Api.ActivityUpdateBatchSize, env.GetConfig().Api.ActivityUpdateInterval, manager.heartbeatFlush) + manager.HeartbeatCollector = NewHeartbeatCollector( + env, + env.GetConfig().Edge.Api.ActivityUpdateBatchSize, + env.GetConfig().Edge.Api.ActivityUpdateInterval, + manager.heartbeatFlush) manager.impl = manager @@ -54,7 +58,7 @@ func (self *ApiSessionManager) newModelEntity() *ApiSession { func (self *ApiSessionManager) Create(ctx boltz.MutateContext, entity *ApiSession, sessionCerts []*ApiSessionCertificate) (string, error) { var apiSessionId string - err := self.env.GetDbProvider().GetDb().Update(ctx, func(ctx boltz.MutateContext) error { + err := self.env.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { var err error apiSessionId, err = self.CreateInCtx(ctx, entity, sessionCerts) return err @@ -215,7 +219,7 @@ func (self *ApiSessionManager) Stream(query string, collect func(*ApiSession, er return fmt.Errorf("could not parse query for streaming api sessions: %v", err) } - return self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + return self.env.GetDb().View(func(tx *bbolt.Tx) error { for cursor := self.Store.IterateIds(tx, filter); cursor.IsValid(); cursor.Next() { current := cursor.Current() @@ -235,7 +239,7 @@ func (self *ApiSessionManager) StreamIds(query string, collect func(string, erro return fmt.Errorf("could not parse query for streaming api sessions ids: %v", err) } - return self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + return self.env.GetDb().View(func(tx *bbolt.Tx) error { for cursor := self.Store.IterateIds(tx, filter); cursor.IsValid(); cursor.Next() { current := cursor.Current() if err := collect(string(current), err); err != nil { @@ -290,7 +294,7 @@ func (self *ApiSessionManager) VisitFingerprintsForApiSession(tx *bbolt.Tx, iden } func (self *ApiSessionManager) DeleteByIdentityId(identityId string, changeCtx *change.Context) error { - return self.GetEnv().GetDbProvider().GetDb().Update(changeCtx.NewMutateContext(), func(ctx boltz.MutateContext) error { + return self.GetEnv().GetDb().Update(changeCtx.NewMutateContext(), func(ctx boltz.MutateContext) error { query := fmt.Sprintf(`%s = "%s"`, db.FieldApiSessionIdentity, identityId) return self.Store.DeleteWhere(ctx, query) }) diff --git a/controller/model/api_session_model.go b/controller/model/api_session_model.go index 329dfa6e0..4796dd113 100644 --- a/controller/model/api_session_model.go +++ b/controller/model/api_session_model.go @@ -77,8 +77,8 @@ func (entity *ApiSession) fillFrom(env Env, tx *bbolt.Tx, boltApiSession *db.Api entity.IPAddress = boltApiSession.IPAddress entity.MfaRequired = boltApiSession.MfaRequired entity.MfaComplete = boltApiSession.MfaComplete - entity.ExpiresAt = entity.UpdatedAt.Add(env.GetConfig().Api.SessionTimeout) - entity.ExpirationDuration = env.GetConfig().Api.SessionTimeout + entity.ExpiresAt = entity.UpdatedAt.Add(env.GetConfig().Edge.Api.SessionTimeout) + entity.ExpirationDuration = env.GetConfig().Edge.Api.SessionTimeout entity.LastActivityAt = boltApiSession.LastActivityAt entity.AuthenticatorId = boltApiSession.AuthenticatorId diff --git a/controller/model/auth_policy_manager.go b/controller/model/auth_policy_manager.go index 8b0a4aaa2..73bf111f4 100644 --- a/controller/model/auth_policy_manager.go +++ b/controller/model/auth_policy_manager.go @@ -26,7 +26,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "google.golang.org/protobuf/proto" ) @@ -37,7 +36,7 @@ func NewAuthPolicyManager(env Env) *AuthPolicyManager { } manager.impl = manager - network.RegisterManagerDecoder[*AuthPolicy](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*AuthPolicy](env, manager) return manager } @@ -47,7 +46,7 @@ type AuthPolicyManager struct { } func (self *AuthPolicyManager) Create(entity *AuthPolicy, ctx *change.Context) error { - return network.DispatchCreate[*AuthPolicy](self, entity, ctx) + return DispatchCreate[*AuthPolicy](self, entity, ctx) } func (self *AuthPolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[*AuthPolicy], ctx boltz.MutateContext) error { @@ -69,7 +68,7 @@ func (self *AuthPolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[*Aut } func (self *AuthPolicyManager) Update(entity *AuthPolicy, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*AuthPolicy](self, entity, checker, ctx) + return DispatchUpdate[*AuthPolicy](self, entity, checker, ctx) } func (self *AuthPolicyManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*AuthPolicy], ctx boltz.MutateContext) error { diff --git a/controller/model/authenticator_manager.go b/controller/model/authenticator_manager.go index 5cefa3d66..dff717a41 100644 --- a/controller/model/authenticator_manager.go +++ b/controller/model/authenticator_manager.go @@ -35,7 +35,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" @@ -59,7 +58,7 @@ func NewAuthenticatorManager(env Env) *AuthenticatorManager { manager.impl = manager - network.RegisterManagerDecoder[*Authenticator](env.GetHostController().GetNetwork().GetManagers(), manager) + RegisterManagerDecoder[*Authenticator](env, manager) return manager } @@ -85,7 +84,7 @@ func (self *AuthenticatorManager) Authorize(authContext AuthContext) (AuthResult func (self *AuthenticatorManager) ReadFingerprints(authenticatorId string) ([]string, error) { var authenticator *db.Authenticator - err := self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := self.env.GetDb().View(func(tx *bbolt.Tx) error { var err error authenticator, err = self.authStore.LoadById(tx, authenticatorId) return err @@ -107,7 +106,7 @@ func (self *AuthenticatorManager) Read(id string) (*Authenticator, error) { } func (self *AuthenticatorManager) Create(entity *Authenticator, ctx *change.Context) error { - return network.DispatchCreate[*Authenticator](self, entity, ctx) + return DispatchCreate[*Authenticator](self, entity, ctx) } func (self *AuthenticatorManager) ApplyCreate(cmd *command.CreateEntityCommand[*Authenticator], ctx boltz.MutateContext) error { @@ -214,7 +213,7 @@ func (self *AuthenticatorManager) ApplyUpdate(cmd *command.UpdateEntityCommand[* func (self *AuthenticatorManager) getRootPool() *x509.CertPool { roots := x509.NewCertPool() - roots.AppendCertsFromPEM(self.env.GetConfig().CaPems()) + roots.AppendCertsFromPEM(self.env.GetConfig().Edge.CaPems()) err := self.env.GetManagers().Ca.Stream("isVerified = true", func(ca *Ca, err error) error { if err != nil { @@ -457,7 +456,7 @@ func (self *AuthenticatorManager) ExtendCertForIdentity(identityId string, authe } caPool := x509.NewCertPool() - config := self.env.GetConfig() + config := self.env.GetConfig().Edge caPool.AddCert(config.Enrollment.SigningCert.Cert().Leaf) validClientCert.NotBefore = time.Now().Add(-1 * time.Hour) @@ -602,7 +601,7 @@ func (self *AuthenticatorManager) VerifyExtendCertForIdentity(apiSessionId, iden PEM: verifyCertPem, } - return self.env.GetDbProvider().GetDb().Update(ctx.NewMutateContext(), func(mutateCtx boltz.MutateContext) error { + return self.env.GetDb().Update(ctx.NewMutateContext(), func(mutateCtx boltz.MutateContext) error { if err = self.env.GetStores().ApiSessionCertificate.Create(mutateCtx, sessionCert); err != nil { return err } @@ -675,7 +674,7 @@ func getCaId(env Env, auth *AuthenticatorCert) string { cert := certs[0] caId := "" - err := env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := env.GetDb().View(func(tx *bbolt.Tx) error { for cursor := env.GetStores().Ca.IterateIds(tx, ast.BoolNodeTrue); cursor.IsValid(); cursor.Next() { ca, err := env.GetStores().Ca.LoadById(tx, string(cursor.Current())) if err != nil { diff --git a/controller/model/authenticator_mod_ext_jwt.go b/controller/model/authenticator_mod_ext_jwt.go index 079788ed4..5b980d269 100644 --- a/controller/model/authenticator_mod_ext_jwt.go +++ b/controller/model/authenticator_mod_ext_jwt.go @@ -502,7 +502,7 @@ func (a *AuthModuleExtJwt) onExternalSignerCreate(args ...interface{}) { func (a *AuthModuleExtJwt) onExternalSignerUpdate(signer *db.ExternalJwtSigner) { //read on update because patches can pass partial data - err := a.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := a.env.GetDb().View(func(tx *bbolt.Tx) error { var err error signer, _, err = a.env.GetStores().ExternalJwtSigner.FindById(tx, signer.Id) return err @@ -559,7 +559,7 @@ func (a *AuthModuleExtJwt) onExternalSignerDelete(signer *db.ExternalJwtSigner) } func (a *AuthModuleExtJwt) loadExistingSigners() { - err := a.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := a.env.GetDb().View(func(tx *bbolt.Tx) error { ids, _, err := a.env.GetStores().ExternalJwtSigner.QueryIds(tx, "") if err != nil { diff --git a/controller/model/base_manager.go b/controller/model/base_manager.go index 72cc992b5..56f0c0738 100644 --- a/controller/model/base_manager.go +++ b/controller/model/base_manager.go @@ -18,12 +18,12 @@ package model import ( "github.com/michaelquigley/pfxlog" + "github.com/openziti/storage/ast" + "github.com/openziti/storage/boltz" "github.com/openziti/ziti/common/eid" "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/models" - "github.com/openziti/storage/ast" - "github.com/openziti/storage/boltz" "github.com/pkg/errors" "go.etcd.io/bbolt" ) @@ -55,7 +55,7 @@ type baseEntityManager[ME edgeEntity[PE], PE boltz.ExtEntity] struct { } func (self *baseEntityManager[ME, PE]) Dispatch(command command.Command) error { - return self.env.GetManagers().Command.Dispatch(command) + return self.env.GetManagers().Dispatcher.Dispatch(command) } func (self *baseEntityManager[ME, PE]) GetEntityTypeId() string { @@ -69,7 +69,7 @@ func (self *baseEntityManager[ME, PE]) GetStore() boltz.EntityStore[PE] { } func (self *baseEntityManager[ME, PE]) GetDb() boltz.Db { - return self.env.GetDbProvider().GetDb() + return self.env.GetDb() } func (self *baseEntityManager[ME, PE]) GetEnv() Env { diff --git a/controller/model/ca_manager.go b/controller/model/ca_manager.go index a8119c146..ed0d0419f 100644 --- a/controller/model/ca_manager.go +++ b/controller/model/ca_manager.go @@ -26,7 +26,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "strings" @@ -38,7 +37,7 @@ func NewCaManager(env Env) *CaManager { } manager.impl = manager - network.RegisterManagerDecoder[*Ca](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*Ca](env, manager) return manager } @@ -52,7 +51,7 @@ func (self *CaManager) newModelEntity() *Ca { } func (self *CaManager) Create(entity *Ca, ctx *change.Context) error { - return network.DispatchCreate[*Ca](self, entity, ctx) + return DispatchCreate[*Ca](self, entity, ctx) } func (self *CaManager) ApplyCreate(cmd *command.CreateEntityCommand[*Ca], ctx boltz.MutateContext) error { @@ -64,7 +63,7 @@ func (self *CaManager) Update(entity *Ca, checker fields.UpdatedFields, ctx *cha if checker != nil { checker.RemoveFields(db.FieldCaIsVerified) } - return network.DispatchUpdate[*Ca](self, entity, checker, ctx) + return DispatchUpdate[*Ca](self, entity, checker, ctx) } func (self *CaManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Ca], ctx boltz.MutateContext) error { @@ -115,7 +114,7 @@ func (self *CaManager) Verified(ca *Ca, ctx *change.Context) error { checker := &fields.UpdatedFieldsMap{ db.FieldCaIsVerified: struct{}{}, } - return network.DispatchUpdate[*Ca](self, ca, checker, ctx) + return DispatchUpdate[*Ca](self, ca, checker, ctx) } func (self *CaManager) Query(query string) (*CaListResult, error) { @@ -133,7 +132,7 @@ func (self *CaManager) Stream(query string, collect func(*Ca, error) error) erro return fmt.Errorf("could not parse query for streaming cas: %v", err) } - return self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + return self.env.GetDb().View(func(tx *bbolt.Tx) error { for cursor := self.Store.IterateIds(tx, filter); cursor.IsValid(); cursor.Next() { current := cursor.Current() diff --git a/controller/network/circuit.go b/controller/model/circuit.go similarity index 79% rename from controller/network/circuit.go rename to controller/model/circuit.go index 9362b29dd..cb63aea00 100644 --- a/controller/network/circuit.go +++ b/controller/model/circuit.go @@ -14,13 +14,13 @@ limitations under the License. */ -package network +package model import ( "github.com/openziti/identity" "github.com/openziti/storage/objectz" + "github.com/openziti/ziti/common/datastructures" "github.com/openziti/ziti/common/logcontext" - "github.com/openziti/ziti/controller/idgen" "github.com/openziti/ziti/controller/xt" "github.com/orcaman/concurrent-map/v2" "sync/atomic" @@ -68,10 +68,6 @@ func (self *Circuit) IsSystemEntity() bool { return false } -func (self *Circuit) cost(minRouterCost uint16) int64 { - return self.Path.cost(minRouterCost) -} - func (self *Circuit) HasRouter(routerId string) bool { if self == nil || self.Path == nil { return false @@ -91,19 +87,17 @@ func (self *Circuit) IsEndpointRouter(routerId string) bool { return self.Path.Nodes[0].Id == routerId || self.Path.Nodes[len(self.Path.Nodes)-1].Id == routerId } -type circuitController struct { - circuits cmap.ConcurrentMap[string, *Circuit] - idGenerator idgen.Generator - store *objectz.ObjectStore[*Circuit] +type CircuitManager struct { + circuits cmap.ConcurrentMap[string, *Circuit] + store *objectz.ObjectStore[*Circuit] } -func newCircuitController() *circuitController { - result := &circuitController{ - circuits: cmap.New[*Circuit](), - idGenerator: idgen.NewGenerator(), +func NewCircuitController() *CircuitManager { + result := &CircuitManager{ + circuits: cmap.New[*Circuit](), } result.store = objectz.NewObjectStore[*Circuit](func() objectz.ObjectIterator[*Circuit] { - return IterateCMap(result.circuits) + return datastructures.IterateCMap(result.circuits) }) result.store.AddStringSymbol("id", func(entity *Circuit) *string { return &entity.Id @@ -127,22 +121,22 @@ func newCircuitController() *circuitController { return result } -func (self *circuitController) nextCircuitId() (string, error) { - return self.idGenerator.NextAlphaNumericPrefixedId() +func (self *CircuitManager) GetStore() *objectz.ObjectStore[*Circuit] { + return self.store } -func (self *circuitController) add(circuit *Circuit) { +func (self *CircuitManager) Add(circuit *Circuit) { self.circuits.Set(circuit.Id, circuit) } -func (self *circuitController) get(id string) (*Circuit, bool) { +func (self *CircuitManager) Get(id string) (*Circuit, bool) { if circuit, found := self.circuits.Get(id); found { return circuit, true } return nil, false } -func (self *circuitController) all() []*Circuit { +func (self *CircuitManager) All() []*Circuit { var circuits []*Circuit self.circuits.IterCb(func(_ string, circuit *Circuit) { circuits = append(circuits, circuit) @@ -150,7 +144,7 @@ func (self *circuitController) all() []*Circuit { return circuits } -func (self *circuitController) remove(circuit *Circuit) { +func (self *CircuitManager) Remove(circuit *Circuit) { self.circuits.Remove(circuit.Id) } diff --git a/controller/network/command.go b/controller/model/command.go similarity index 52% rename from controller/network/command.go rename to controller/model/command.go index 591d5b486..4390cca21 100644 --- a/controller/network/command.go +++ b/controller/model/command.go @@ -14,26 +14,32 @@ limitations under the License. */ -package network +package model import ( + "github.com/openziti/ziti/common/pb/cmd_pb" + "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" + "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/idgen" "github.com/openziti/ziti/controller/ioc" - "github.com/openziti/ziti/common/pb/cmd_pb" + "github.com/openziti/ziti/controller/models" "google.golang.org/protobuf/proto" ) -func newCommandManager(managers *Managers) *CommandManager { +func newCommandManager(env Env, registry ioc.Registry) *CommandManager { command.GetDefaultDecoders().Clear() result := &CommandManager{ - Managers: managers, + env: env, + registry: registry, Decoders: command.GetDefaultDecoders(), } return result } type CommandManager struct { - *Managers + env Env + registry ioc.Registry Decoders command.Decoders } @@ -41,7 +47,6 @@ func (self *CommandManager) registerGenericCommands() { self.Decoders.RegisterF(int32(cmd_pb.CommandType_CreateEntityType), self.decodeCreateEntityCommand) self.Decoders.RegisterF(int32(cmd_pb.CommandType_UpdateEntityType), self.decodeUpdateEntityCommand) self.Decoders.RegisterF(int32(cmd_pb.CommandType_DeleteEntityType), self.decodeDeleteEntityCommand) - self.Decoders.RegisterF(int32(cmd_pb.CommandType_SyncSnapshot), self.decodeSyncSnapshotCommand) } func (self *CommandManager) decodeCreateEntityCommand(_ int32, data []byte) (command.Command, error) { @@ -50,7 +55,7 @@ func (self *CommandManager) decodeCreateEntityCommand(_ int32, data []byte) (com return nil, err } - decoder, err := ioc.Get[createDecoderF](self.Registry, msg.EntityType+CreateDecoder) + decoder, err := ioc.Get[createDecoderF](self.registry, msg.EntityType+CreateDecoder) if err != nil { return nil, err } @@ -64,7 +69,7 @@ func (self *CommandManager) decodeUpdateEntityCommand(_ int32, data []byte) (com return nil, err } - decoder, err := ioc.Get[updateDecoderF](self.Registry, msg.EntityType+UpdateDecoder) + decoder, err := ioc.Get[updateDecoderF](self.registry, msg.EntityType+UpdateDecoder) if err != nil { return nil, err } @@ -78,7 +83,7 @@ func (self *CommandManager) decodeDeleteEntityCommand(_ int32, data []byte) (com return nil, err } - decoder, err := ioc.Get[deleteDecoderF](self.Registry, msg.EntityType+DeleteDecoder) + decoder, err := ioc.Get[deleteDecoderF](self.registry, msg.EntityType+DeleteDecoder) if err != nil { return nil, err } @@ -86,21 +91,6 @@ func (self *CommandManager) decodeDeleteEntityCommand(_ int32, data []byte) (com return decoder(msg) } -func (self *CommandManager) decodeSyncSnapshotCommand(_ int32, data []byte) (command.Command, error) { - msg := &cmd_pb.SyncSnapshotCommand{} - if err := proto.Unmarshal(data, msg); err != nil { - return nil, err - } - - cmd := &command.SyncSnapshotCommand{ - SnapshotId: msg.SnapshotId, - Snapshot: msg.Snapshot, - SnapshotSink: self.network.RestoreSnapshot, - } - - return cmd, nil -} - // CommandMsg is a TypedMessage which is also a pointer type. // // T is message type. We want to enforce that the TypeMessage implementation is a pointer type @@ -110,39 +100,38 @@ type CommandMsg[T any] interface { *T } -// decodableCommand is a Command which knows how to decode itself from the given message type -// -// T is the type of the command. We want to enforce that the command is a pointer type so we can -// use new(T) to create new instances of it -// M is the message type that the command can use to set its internals -type decodableCommand[T any, M any] interface { - command.Command - Decode(n *Network, msg M) error - *T +type creator[T models.Entity] interface { + command.EntityCreator[T] + Dispatch(cmd command.Command) error } -// RegisterCommand register a decoder for the given command and message pair -// MT is the message type (ex: cmd_pb.CreateServiceCommand) -// CT is the command type (ex: CreateServiceCommand) -// M is the CommandMsg/command.TypedMessage implementation (ex: *cmd_pb.CreateServiceCommand) -// C is the decodableCommand/command.Command implementation (ex: *CreateServiceCommand) -// -// We only have both types specified so that we can enforce that each is a pointer type. If didn't -// enforce that the instances were pointer types, we couldn't use new to instantiate new instances. -func RegisterCommand[MT any, CT any, M CommandMsg[MT], C decodableCommand[CT, M]](managers *Managers, _ C, _ M) { - decoder := func(commandType int32, data []byte) (command.Command, error) { - var msg M = new(MT) - if err := proto.Unmarshal(data, msg); err != nil { - return nil, err - } - - cmd := C(new(CT)) - if err := cmd.Decode(managers.network, msg); err != nil { - return nil, err - } - return cmd, nil +type updater[T models.Entity] interface { + command.EntityUpdater[T] + Dispatch(cmd command.Command) error +} + +func DispatchCreate[T models.Entity](c creator[T], entity T, ctx *change.Context) error { + if entity.GetId() == "" { + id := idgen.NewUUIDString() + entity.SetId(id) + } + + cmd := &command.CreateEntityCommand[T]{ + Context: ctx, + Creator: c, + Entity: entity, + } + + return c.Dispatch(cmd) +} + +func DispatchUpdate[T models.Entity](u updater[T], entity T, updatedFields fields.UpdatedFields, ctx *change.Context) error { + cmd := &command.UpdateEntityCommand[T]{ + Context: ctx, + Updater: u, + Entity: entity, + UpdatedFields: updatedFields, } - var msg M = new(MT) - managers.Command.Decoders.RegisterF(msg.GetCommandType(), decoder) + return u.Dispatch(cmd) } diff --git a/controller/network/command_test.go b/controller/model/command_test.go similarity index 70% rename from controller/network/command_test.go rename to controller/model/command_test.go index a91be620f..ef2400565 100644 --- a/controller/network/command_test.go +++ b/controller/model/command_test.go @@ -1,27 +1,20 @@ -package network +package model import ( "testing" + "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/openziti/ziti/controller/command" - "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/stretchr/testify/require" ) func TestProtobufFactory(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := NewTestContext(t) defer ctx.Cleanup() req := require.New(t) - config := newTestConfig(ctx) - defer close(config.closeNotify) - - n, err := NewNetwork(config) - req.NoError(err) - service := &Service{ BaseEntity: models.BaseEntity{ Id: "one", @@ -31,14 +24,14 @@ func TestProtobufFactory(t *testing.T) { } createCmd := &command.CreateEntityCommand[*Service]{ - Creator: n.Managers.Services, + Creator: ctx.GetManagers().Service, Entity: service, } b, err := createCmd.Encode() req.NoError(err) - val, err := n.Managers.Command.Decoders.Decode(b) + val, err := ctx.GetManagers().Command.Decoders.Decode(b) req.NoError(err) msg, ok := val.(*command.CreateEntityCommand[*Service]) req.True(ok) @@ -48,17 +41,11 @@ func TestProtobufFactory(t *testing.T) { } func BenchmarkRegisterCommand(t *testing.B) { - ctx := db.NewTestContext(t) + ctx := NewTestContext(t) defer ctx.Cleanup() req := require.New(t) - config := newTestConfig(ctx) - defer close(config.closeNotify) - - n, err := NewNetwork(config) - req.NoError(err) - service := &Service{ BaseEntity: models.BaseEntity{ Id: "one", @@ -68,7 +55,7 @@ func BenchmarkRegisterCommand(t *testing.B) { } createCmd := &command.CreateEntityCommand[*Service]{ - Creator: n.Managers.Services, + Creator: ctx.GetManagers().Service, Entity: service, } @@ -76,7 +63,7 @@ func BenchmarkRegisterCommand(t *testing.B) { req.NoError(err) cmdType := int32(cmd_pb.CommandType_CreateEntityType) - decoder := n.Managers.Command.Decoders.GetDecoder(cmdType) + decoder := ctx.GetManagers().Command.Decoders.GetDecoder(cmdType) for i := 0; i < t.N; i++ { _, err = decoder.Decode(cmdType, b) diff --git a/controller/model/config_manager.go b/controller/model/config_manager.go index 277e93b5e..2745d4653 100644 --- a/controller/model/config_manager.go +++ b/controller/model/config_manager.go @@ -25,7 +25,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "strings" @@ -37,7 +36,7 @@ func NewConfigManager(env Env) *ConfigManager { } manager.impl = manager - network.RegisterManagerDecoder[*Config](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*Config](env, manager) return manager } @@ -51,7 +50,7 @@ func (self *ConfigManager) newModelEntity() *Config { } func (self *ConfigManager) Create(entity *Config, ctx *change.Context) error { - return network.DispatchCreate[*Config](self, entity, ctx) + return DispatchCreate[*Config](self, entity, ctx) } func (self *ConfigManager) ApplyCreate(cmd *command.CreateEntityCommand[*Config], ctx boltz.MutateContext) error { @@ -60,7 +59,7 @@ func (self *ConfigManager) ApplyCreate(cmd *command.CreateEntityCommand[*Config] } func (self *ConfigManager) Update(entity *Config, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*Config](self, entity, checker, ctx) + return DispatchUpdate[*Config](self, entity, checker, ctx) } func (self *ConfigManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Config], ctx boltz.MutateContext) error { diff --git a/controller/model/config_type_manager.go b/controller/model/config_type_manager.go index 78a2599c9..b565c43ea 100644 --- a/controller/model/config_type_manager.go +++ b/controller/model/config_type_manager.go @@ -27,7 +27,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" ) @@ -42,7 +41,7 @@ func NewConfigTypeManager(env Env) *ConfigTypeManager { } manager.impl = manager - network.RegisterManagerDecoder[*ConfigType](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*ConfigType](env, manager) return manager } @@ -56,7 +55,7 @@ func (self *ConfigTypeManager) newModelEntity() *ConfigType { } func (self *ConfigTypeManager) Create(entity *ConfigType, ctx *change.Context) error { - return network.DispatchCreate[*ConfigType](self, entity, ctx) + return DispatchCreate[*ConfigType](self, entity, ctx) } func (self *ConfigTypeManager) ApplyCreate(cmd *command.CreateEntityCommand[*ConfigType], ctx boltz.MutateContext) error { @@ -65,7 +64,7 @@ func (self *ConfigTypeManager) ApplyCreate(cmd *command.CreateEntityCommand[*Con } func (self *ConfigTypeManager) Update(entity *ConfigType, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*ConfigType](self, entity, checker, ctx) + return DispatchUpdate[*ConfigType](self, entity, checker, ctx) } func (self *ConfigTypeManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*ConfigType], ctx boltz.MutateContext) error { diff --git a/controller/model/controller_manager.go b/controller/model/controller_manager.go index 56d5e274e..294568a77 100644 --- a/controller/model/controller_manager.go +++ b/controller/model/controller_manager.go @@ -29,7 +29,6 @@ import ( "github.com/openziti/ziti/controller/event" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "time" @@ -41,7 +40,7 @@ func NewControllerManager(env Env) *ControllerManager { } manager.impl = manager - network.RegisterManagerDecoder[*Controller](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*Controller](env, manager) return manager } @@ -55,7 +54,7 @@ func (self *ControllerManager) newModelEntity() *Controller { } func (self *ControllerManager) Create(entity *Controller, ctx *change.Context) error { - return network.DispatchCreate[*Controller](self, entity, ctx) + return DispatchCreate[*Controller](self, entity, ctx) } func (self *ControllerManager) ApplyCreate(cmd *command.CreateEntityCommand[*Controller], ctx boltz.MutateContext) error { @@ -64,7 +63,7 @@ func (self *ControllerManager) ApplyCreate(cmd *command.CreateEntityCommand[*Con } func (self *ControllerManager) Update(entity *Controller, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*Controller](self, entity, checker, ctx) + return DispatchUpdate[*Controller](self, entity, checker, ctx) } func (self *ControllerManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Controller], ctx boltz.MutateContext) error { @@ -176,7 +175,7 @@ func (self *ControllerManager) Unmarshall(bytes []byte) (*Controller, error) { } func (self *ControllerManager) getCurrentAsClusterPeer() *event.ClusterPeer { - addr, id, version := self.env.GetHostController().GetRaftInfo() + addr, id, version := self.env.GetRaftInfo() tlsConfig, _, _ := self.env.GetServerCert() var leaderCerts []*x509.Certificate @@ -188,7 +187,7 @@ func (self *ControllerManager) getCurrentAsClusterPeer() *event.ClusterPeer { } } - apiAddresses, _ := self.env.GetHostController().GetApiAddresses() + apiAddresses, _ := self.env.GetApiAddresses() return &event.ClusterPeer{ Id: id, diff --git a/controller/model/create_terminator_cmd.go b/controller/model/create_terminator_cmd.go index cb6415def..b2c884fcc 100644 --- a/controller/model/create_terminator_cmd.go +++ b/controller/model/create_terminator_cmd.go @@ -9,7 +9,6 @@ import ( "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/db" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "github.com/sirupsen/logrus" "go.etcd.io/bbolt" @@ -18,12 +17,12 @@ import ( type CreateEdgeTerminatorCmd struct { Env Env - Entity *network.Terminator + Entity *Terminator Context *change.Context } func (self *CreateEdgeTerminatorCmd) Apply(ctx boltz.MutateContext) error { - createCmd := &command.CreateEntityCommand[*network.Terminator]{ + createCmd := &command.CreateEntityCommand[*Terminator]{ Creator: self.Env.GetManagers().Terminator, Entity: self.Entity, PostCreateHook: self.validateTerminatorIdentity, @@ -32,7 +31,7 @@ func (self *CreateEdgeTerminatorCmd) Apply(ctx boltz.MutateContext) error { return self.Env.GetManagers().Terminator.ApplyCreate(createCmd, ctx) } -func (self *CreateEdgeTerminatorCmd) validateTerminatorIdentity(ctx boltz.MutateContext, terminator *network.Terminator) error { +func (self *CreateEdgeTerminatorCmd) validateTerminatorIdentity(ctx boltz.MutateContext, terminator *Terminator) error { tx := ctx.Tx() if terminator.GetInstanceId() == "" { diff --git a/controller/model/edge_router_manager.go b/controller/model/edge_router_manager.go index 57e035b1b..cc49c3457 100644 --- a/controller/model/edge_router_manager.go +++ b/controller/model/edge_router_manager.go @@ -26,7 +26,6 @@ import ( "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/fields" - "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" "strconv" @@ -56,8 +55,8 @@ func NewEdgeRouterManager(env Env) *EdgeRouterManager { manager.impl = manager RegisterCommand(env, &CreateEdgeRouterCmd{}, &edge_cmd_pb.CreateEdgeRouterCmd{}) - network.RegisterUpdateDecoder[*EdgeRouter](env.GetHostController().GetNetwork().Managers, manager) - network.RegisterDeleteDecoder(env.GetHostController().GetNetwork().Managers, manager) + RegisterUpdateDecoder[*EdgeRouter](env, manager) + RegisterDeleteDecoder(env, manager) return manager } @@ -192,7 +191,7 @@ func (self *EdgeRouterManager) Query(query string) (*EdgeRouterListResult, error func (self *EdgeRouterManager) ListForIdentityAndService(identityId, serviceId string, limit *int) (*EdgeRouterListResult, error) { var list *EdgeRouterListResult var err error - if txErr := self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + if txErr := self.env.GetDb().View(func(tx *bbolt.Tx) error { list, err = self.ListForIdentityAndServiceWithTx(tx, identityId, serviceId, limit) return nil }); txErr != nil { diff --git a/controller/model/edge_router_manager_test.go b/controller/model/edge_router_manager_test.go index fc24e11d9..4a63319a6 100644 --- a/controller/model/edge_router_manager_test.go +++ b/controller/model/edge_router_manager_test.go @@ -49,7 +49,7 @@ func (ctx *TestContext) testGetEdgeRoutersForServiceAndIdentity(*testing.T) { func (ctx *TestContext) isEdgeRouterAccessible(edgeRouterId, identityId, serviceId string) bool { found := false - err := ctx.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := ctx.GetDb().View(func(tx *bbolt.Tx) error { result, err := ctx.managers.EdgeRouter.ListForIdentityAndServiceWithTx(tx, identityId, serviceId, nil) if err != nil { return err diff --git a/controller/model/edge_router_policy_manager.go b/controller/model/edge_router_policy_manager.go index 412562d00..a0599d501 100644 --- a/controller/model/edge_router_policy_manager.go +++ b/controller/model/edge_router_policy_manager.go @@ -24,7 +24,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) @@ -34,7 +33,7 @@ func NewEdgeRouterPolicyManager(env Env) *EdgeRouterPolicyManager { } manager.impl = manager - network.RegisterManagerDecoder[*EdgeRouterPolicy](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*EdgeRouterPolicy](env, manager) return manager } @@ -48,7 +47,7 @@ func (self *EdgeRouterPolicyManager) newModelEntity() *EdgeRouterPolicy { } func (self *EdgeRouterPolicyManager) Create(entity *EdgeRouterPolicy, ctx *change.Context) error { - return network.DispatchCreate[*EdgeRouterPolicy](self, entity, ctx) + return DispatchCreate[*EdgeRouterPolicy](self, entity, ctx) } func (self *EdgeRouterPolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[*EdgeRouterPolicy], ctx boltz.MutateContext) error { @@ -57,7 +56,7 @@ func (self *EdgeRouterPolicyManager) ApplyCreate(cmd *command.CreateEntityComman } func (self *EdgeRouterPolicyManager) Update(entity *EdgeRouterPolicy, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*EdgeRouterPolicy](self, entity, checker, ctx) + return DispatchUpdate[*EdgeRouterPolicy](self, entity, checker, ctx) } func (self *EdgeRouterPolicyManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*EdgeRouterPolicy], ctx boltz.MutateContext) error { diff --git a/controller/model/edge_service_manager.go b/controller/model/edge_service_manager.go index 2acd96590..2ac2e93c7 100644 --- a/controller/model/edge_service_manager.go +++ b/controller/model/edge_service_manager.go @@ -26,7 +26,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "time" @@ -34,19 +33,19 @@ import ( func NewEdgeServiceManager(env Env) *EdgeServiceManager { manager := &EdgeServiceManager{ - baseEntityManager: newBaseEntityManager[*Service, *db.EdgeService](env, env.GetStores().EdgeService), + baseEntityManager: newBaseEntityManager[*EdgeService, *db.EdgeService](env, env.GetStores().EdgeService), detailLister: &ServiceDetailLister{}, } manager.impl = manager manager.detailLister.manager = manager - network.RegisterManagerDecoder[*Service](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*EdgeService](env, manager) return manager } type EdgeServiceManager struct { - baseEntityManager[*Service, *db.EdgeService] + baseEntityManager[*EdgeService, *db.EdgeService] detailLister *ServiceDetailLister } @@ -58,32 +57,32 @@ func (self *EdgeServiceManager) GetEntityTypeId() string { return "edgeServices" } -func (self *EdgeServiceManager) newModelEntity() *Service { - return &Service{} +func (self *EdgeServiceManager) newModelEntity() *EdgeService { + return &EdgeService{} } -func (self *EdgeServiceManager) Create(entity *Service, ctx *change.Context) error { - return network.DispatchCreate[*Service](self, entity, ctx) +func (self *EdgeServiceManager) Create(entity *EdgeService, ctx *change.Context) error { + return DispatchCreate[*EdgeService](self, entity, ctx) } -func (self *EdgeServiceManager) ApplyCreate(cmd *command.CreateEntityCommand[*Service], ctx boltz.MutateContext) error { +func (self *EdgeServiceManager) ApplyCreate(cmd *command.CreateEntityCommand[*EdgeService], ctx boltz.MutateContext) error { _, err := self.createEntity(cmd.Entity, ctx) return err } -func (self *EdgeServiceManager) Update(entity *Service, checker fields.UpdatedFields, ctx *change.Context) error { +func (self *EdgeServiceManager) Update(entity *EdgeService, checker fields.UpdatedFields, ctx *change.Context) error { if checker != nil { checker = checker.RemoveFields("encryptionRequired") } - return network.DispatchUpdate[*Service](self, entity, checker, ctx) + return DispatchUpdate[*EdgeService](self, entity, checker, ctx) } -func (self *EdgeServiceManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Service], ctx boltz.MutateContext) error { +func (self *EdgeServiceManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*EdgeService], ctx boltz.MutateContext) error { return self.updateEntity(cmd.Entity, cmd.UpdatedFields, ctx) } -func (self *EdgeServiceManager) ReadByName(name string) (*Service, error) { - entity := &Service{} +func (self *EdgeServiceManager) ReadByName(name string) (*EdgeService, error) { + entity := &EdgeService{} nameIndex := self.env.GetStores().EdgeService.GetNameIndex() if err := self.readEntityWithIndex("name", []byte(name), nameIndex, entity); err != nil { return nil, err @@ -193,7 +192,7 @@ func (self *EdgeServiceManager) QueryRoleAttributes(queryString string) ([]strin return self.queryRoleAttributes(index, queryString) } -func (self *EdgeServiceManager) Marshall(entity *Service) ([]byte, error) { +func (self *EdgeServiceManager) Marshall(entity *EdgeService) ([]byte, error) { tags, err := edge_cmd_pb.EncodeTags(entity.Tags) if err != nil { return nil, err @@ -213,13 +212,13 @@ func (self *EdgeServiceManager) Marshall(entity *Service) ([]byte, error) { return proto.Marshal(msg) } -func (self *EdgeServiceManager) Unmarshall(bytes []byte) (*Service, error) { +func (self *EdgeServiceManager) Unmarshall(bytes []byte) (*EdgeService, error) { msg := &edge_cmd_pb.Service{} if err := proto.Unmarshal(bytes, msg); err != nil { return nil, err } - return &Service{ + return &EdgeService{ BaseEntity: models.BaseEntity{ Id: msg.Id, Tags: edge_cmd_pb.DecodeTags(msg.Tags), diff --git a/controller/model/edge_service_model.go b/controller/model/edge_service_model.go index b3e411689..1ae89a370 100644 --- a/controller/model/edge_service_model.go +++ b/controller/model/edge_service_model.go @@ -26,7 +26,7 @@ import ( "time" ) -type Service struct { +type EdgeService struct { models.BaseEntity Name string `json:"name"` MaxIdleTime time.Duration `json:"maxIdleTime"` @@ -36,7 +36,7 @@ type Service struct { EncryptionRequired bool `json:"encryptionRequired"` } -func (entity *Service) toBoltEntity(tx *bbolt.Tx, env Env) (*db.EdgeService, error) { +func (entity *EdgeService) toBoltEntity(tx *bbolt.Tx, env Env) (*db.EdgeService, error) { if err := entity.validateConfigs(tx, env); err != nil { return nil, err } @@ -55,11 +55,11 @@ func (entity *Service) toBoltEntity(tx *bbolt.Tx, env Env) (*db.EdgeService, err return edgeService, nil } -func (entity *Service) toBoltEntityForCreate(tx *bbolt.Tx, env Env) (*db.EdgeService, error) { +func (entity *EdgeService) toBoltEntityForCreate(tx *bbolt.Tx, env Env) (*db.EdgeService, error) { return entity.toBoltEntity(tx, env) } -func (entity *Service) validateConfigs(tx *bbolt.Tx, env Env) error { +func (entity *EdgeService) validateConfigs(tx *bbolt.Tx, env Env) error { typeMap := map[string]*db.Config{} configStore := env.GetStores().Config for _, id := range entity.Configs { @@ -82,11 +82,11 @@ func (entity *Service) validateConfigs(tx *bbolt.Tx, env Env) error { return nil } -func (entity *Service) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.EdgeService, error) { +func (entity *EdgeService) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.EdgeService, error) { return entity.toBoltEntity(tx, env) } -func (entity *Service) fillFrom(_ Env, _ *bbolt.Tx, boltService *db.EdgeService) error { +func (entity *EdgeService) fillFrom(_ Env, _ *bbolt.Tx, boltService *db.EdgeService) error { entity.FillCommon(boltService) entity.Name = boltService.Name entity.TerminatorStrategy = boltService.TerminatorStrategy diff --git a/controller/model/enrollment_manager.go b/controller/model/enrollment_manager.go index 68a9b97dc..61a817ab1 100644 --- a/controller/model/enrollment_manager.go +++ b/controller/model/enrollment_manager.go @@ -33,7 +33,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" @@ -53,7 +52,7 @@ func NewEnrollmentManager(env Env) *EnrollmentManager { manager.impl = manager - network.RegisterManagerDecoder[*Enrollment](env.GetHostController().GetNetwork().GetManagers(), manager) + RegisterManagerDecoder[*Enrollment](env, manager) RegisterCommand(env, &ReplaceEnrollmentWithAuthenticatorCmd{}, &edge_cmd_pb.ReplaceEnrollmentWithAuthenticatorCmd{}) RegisterCommand(env, &ReEnrollEdgeRouterCmd{}, &edge_cmd_pb.ReEnrollEdgeRouterCmd{}) @@ -61,7 +60,7 @@ func NewEnrollmentManager(env Env) *EnrollmentManager { } func (self *EnrollmentManager) Create(entity *Enrollment, ctx *change.Context) error { - return network.DispatchCreate[*Enrollment](self, entity, ctx) + return DispatchCreate[*Enrollment](self, entity, ctx) } func (self *EnrollmentManager) ApplyCreate(cmd *command.CreateEntityCommand[*Enrollment], ctx boltz.MutateContext) error { @@ -130,7 +129,7 @@ func (self *EnrollmentManager) ApplyCreate(cmd *command.CreateEntityCommand[*Enr } func (self *EnrollmentManager) Update(entity *Enrollment, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*Enrollment](self, entity, checker, ctx) + return DispatchUpdate[*Enrollment](self, entity, checker, ctx) } func (self *EnrollmentManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Enrollment], ctx boltz.MutateContext) error { @@ -185,7 +184,7 @@ func (self *EnrollmentManager) Enroll(ctx EnrollmentContext) (*EnrollmentResult, func (self *EnrollmentManager) ReadByToken(token string) (*Enrollment, error) { enrollment := &Enrollment{} - err := self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := self.env.GetDb().View(func(tx *bbolt.Tx) error { boltEntity, err := self.env.GetStores().Enrollment.LoadOneByToken(tx, token) if err != nil { @@ -227,7 +226,7 @@ func (self *EnrollmentManager) GetCertChainPem(certRaw []byte) (string, error) { var targetChainPem []byte - pool := identity.NewCaPool(self.env.GetConfig().CaCerts()) + pool := identity.NewCaPool(self.env.GetConfig().Edge.CaCerts()) targetChain := pool.GetChainMinusRoot(targetCert) for _, c := range targetChain { @@ -242,7 +241,7 @@ func (self *EnrollmentManager) GetCertChainPem(certRaw []byte) (string, error) { } func (self *EnrollmentManager) ApplyReplaceEncoderWithAuthenticatorCommand(cmd *ReplaceEnrollmentWithAuthenticatorCmd, ctx boltz.MutateContext) error { - return self.env.GetDbProvider().GetDb().Update(ctx, func(ctx boltz.MutateContext) error { + return self.env.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { err := self.env.GetStores().Enrollment.DeleteById(ctx, cmd.enrollmentId) if err != nil { return err diff --git a/controller/model/enrollment_mod_erott.go b/controller/model/enrollment_mod_erott.go index 98ebb097b..09ed6c860 100644 --- a/controller/model/enrollment_mod_erott.go +++ b/controller/model/enrollment_mod_erott.go @@ -140,7 +140,7 @@ func (module *EnrollModuleEr) Process(context EnrollmentContext) (*EnrollmentRes } content := &rest_model.EnrollmentCerts{ - Ca: string(module.env.GetConfig().CaPems()), + Ca: string(module.env.GetConfig().Edge.CaPems()), Cert: clientChainPem, ServerCert: string(serverCertPem), } diff --git a/controller/model/enrollment_mod_trott.go b/controller/model/enrollment_mod_trott.go index d6a1f8017..39e88cb90 100644 --- a/controller/model/enrollment_mod_trott.go +++ b/controller/model/enrollment_mod_trott.go @@ -148,7 +148,7 @@ func (module *EnrollModuleRouterOtt) Process(context EnrollmentContext) (*Enroll } content := &rest_model.EnrollmentCerts{ - Ca: string(module.env.GetConfig().CaPems()), + Ca: string(module.env.GetConfig().Edge.CaPems()), Cert: clientChainPem, ServerCert: string(srvPem), } diff --git a/controller/model/enrollment_model.go b/controller/model/enrollment_model.go index b8a6f50cf..dac5395fc 100644 --- a/controller/model/enrollment_model.go +++ b/controller/model/enrollment_model.go @@ -44,7 +44,7 @@ type Enrollment struct { } func (entity *Enrollment) FillJwtInfo(env Env, subject string) error { - expiresAt := time.Now().Add(env.GetConfig().Enrollment.EdgeIdentity.Duration).UTC() + expiresAt := time.Now().Add(env.GetConfig().Edge.Enrollment.EdgeIdentity.Duration).UTC() return entity.FillJwtInfoWithExpiresAt(env, subject, expiresAt) } @@ -72,7 +72,7 @@ func (entity *Enrollment) FillJwtInfoWithExpiresAt(env Env, subject string, expi Audience: []string{""}, ExpiresAt: &jwt.NumericDate{Time: expiresAt}, ID: entity.Token, - Issuer: fmt.Sprintf("https://%s", env.GetConfig().Api.Address), + Issuer: fmt.Sprintf("https://%s", env.GetConfig().Edge.Api.Address), Subject: subject, }, } diff --git a/controller/model/env.go b/controller/model/env.go index 6ec0c3315..c77ad9148 100644 --- a/controller/model/env.go +++ b/controller/model/env.go @@ -20,30 +20,28 @@ import ( "crypto/tls" "crypto/x509" "github.com/golang-jwt/jwt/v5" - "github.com/openziti/identity" "github.com/openziti/metrics" + "github.com/openziti/storage/boltz" "github.com/openziti/ziti/common" "github.com/openziti/ziti/common/cert" - "github.com/openziti/ziti/controller" + "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/event" "github.com/openziti/ziti/controller/jwtsigner" - "github.com/openziti/ziti/controller/network" - "github.com/xeipuuv/gojsonschema" ) type Env interface { + GetCommandDispatcher() command.Dispatcher GetManagers() *Managers GetConfig() *config.Config - GetDbProvider() network.DbProvider + GetDb() boltz.Db GetStores() *db.Stores GetAuthRegistry() AuthRegistry GetEnrollRegistry() EnrollmentRegistry GetApiClientCsrSigner() cert.Signer GetApiServerCsrSigner() cert.Signer GetControlClientCsrSigner() cert.Signer - GetHostController() HostController IsEdgeRouterOnline(id string) bool GetMetricsRegistry() metrics.Registry GetFingerprintGenerator() cert.FingerprintGenerator @@ -59,22 +57,9 @@ type Env interface { OidcIssuer() string RootIssuer() string -} -type HostController interface { - GetNetwork() *network.Network - Shutdown() - GetCloseNotifyChannel() <-chan struct{} - IsRaftEnabled() bool - Identity() identity.Identity - GetPeerSigners() []*x509.Certificate - GetRaftIndex() uint64 GetRaftInfo() (string, string, string) GetApiAddresses() (map[string][]event.ApiAddress, []byte) - GetConfig() *controller.Config -} - -type Schemas interface { - GetEnrollErPost() *gojsonschema.Schema - GetEnrollUpdbPost() *gojsonschema.Schema + GetCloseNotifyChannel() <-chan struct{} + GetPeerSigners() []*x509.Certificate } diff --git a/controller/model/external_jwt_signer_manager.go b/controller/model/external_jwt_signer_manager.go index be9bacf85..8118f5723 100644 --- a/controller/model/external_jwt_signer_manager.go +++ b/controller/model/external_jwt_signer_manager.go @@ -25,7 +25,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -37,7 +36,7 @@ func NewExternalJwtSignerManager(env Env) *ExternalJwtSignerManager { } manager.impl = manager - network.RegisterManagerDecoder[*ExternalJwtSigner](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*ExternalJwtSigner](env, manager) return manager } @@ -51,7 +50,7 @@ func (self *ExternalJwtSignerManager) newModelEntity() *ExternalJwtSigner { } func (self *ExternalJwtSignerManager) Create(entity *ExternalJwtSigner, ctx *change.Context) error { - return network.DispatchCreate[*ExternalJwtSigner](self, entity, ctx) + return DispatchCreate[*ExternalJwtSigner](self, entity, ctx) } func (self *ExternalJwtSignerManager) ApplyCreate(cmd *command.CreateEntityCommand[*ExternalJwtSigner], ctx boltz.MutateContext) error { @@ -60,7 +59,7 @@ func (self *ExternalJwtSignerManager) ApplyCreate(cmd *command.CreateEntityComma } func (self *ExternalJwtSignerManager) Update(entity *ExternalJwtSigner, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*ExternalJwtSigner](self, entity, checker, ctx) + return DispatchUpdate[*ExternalJwtSigner](self, entity, checker, ctx) } func (self *ExternalJwtSignerManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*ExternalJwtSigner], ctx boltz.MutateContext) error { diff --git a/controller/model/identity_manager.go b/controller/model/identity_manager.go index 2064df14c..e13d8dd42 100644 --- a/controller/model/identity_manager.go +++ b/controller/model/identity_manager.go @@ -33,7 +33,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" cmap "github.com/orcaman/concurrent-map/v2" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" @@ -66,7 +65,7 @@ func NewIdentityManager(env Env) *IdentityManager { } manager.impl = manager - network.RegisterManagerDecoder[*Identity](env.GetHostController().GetNetwork().GetManagers(), manager) + RegisterManagerDecoder[*Identity](env, manager) RegisterCommand(env, &CreateIdentityWithEnrollmentsCmd{}, &edge_cmd_pb.CreateIdentityWithEnrollmentsCmd{}) RegisterCommand(env, &UpdateServiceConfigsCmd{}, &edge_cmd_pb.UpdateServiceConfigsCmd{}) @@ -78,7 +77,7 @@ func (self *IdentityManager) newModelEntity() *Identity { } func (self *IdentityManager) Create(entity *Identity, ctx *change.Context) error { - return network.DispatchCreate[*Identity](self, entity, ctx) + return DispatchCreate[*Identity](self, entity, ctx) } func (self *IdentityManager) ApplyCreate(cmd *command.CreateEntityCommand[*Identity], ctx boltz.MutateContext) error { @@ -138,7 +137,7 @@ func (self *IdentityManager) ApplyCreateWithEnrollments(cmd *CreateIdentityWithE } func (self *IdentityManager) Update(entity *Identity, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*Identity](self, entity, checker, ctx) + return DispatchUpdate[*Identity](self, entity, checker, ctx) } func (self *IdentityManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Identity], ctx boltz.MutateContext) error { @@ -344,7 +343,7 @@ func (self *IdentityManager) CreateWithAuthenticator(identity *Identity, authent return "", "", apiErr } - err = self.env.GetDbProvider().GetDb().Update(ctx.NewMutateContext(), func(ctx boltz.MutateContext) error { + err = self.env.GetDb().Update(ctx.NewMutateContext(), func(ctx boltz.MutateContext) error { boltIdentity, err := identity.toBoltEntityForCreate(ctx.Tx(), self.env) if err != nil { diff --git a/controller/network/link_controller.go b/controller/model/link_manager.go similarity index 56% rename from controller/network/link_controller.go rename to controller/model/link_manager.go index 4d52d76ba..547586357 100644 --- a/controller/network/link_controller.go +++ b/controller/model/link_manager.go @@ -14,27 +14,22 @@ limitations under the License. */ -package network +package model import ( - "encoding/json" - "errors" "github.com/michaelquigley/pfxlog" - "github.com/openziti/channel/v2/protobufs" "github.com/openziti/foundation/v2/info" "github.com/openziti/storage/objectz" - "github.com/openziti/ziti/common/inspect" - "github.com/openziti/ziti/common/pb/ctrl_pb" - "github.com/openziti/ziti/common/pb/mgmt_pb" + "github.com/openziti/ziti/common/datastructures" + "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/idgen" "github.com/orcaman/concurrent-map/v2" "math" - "strings" "sync" "time" ) -type linkController struct { +type LinkManager struct { linkTable *linkTable idGenerator idgen.Generator lock sync.Mutex @@ -42,20 +37,20 @@ type linkController struct { store *objectz.ObjectStore[*Link] } -func newLinkController(options *Options) *linkController { - initialLatency := DefaultOptionsInitialLinkLatency - if options != nil { - initialLatency = options.InitialLinkLatency +func NewLinkManager(env Env) *LinkManager { + initialLatency := config.DefaultOptionsInitialLinkLatency + if env != nil { + initialLatency = env.GetConfig().Network.InitialLinkLatency } - result := &linkController{ + result := &LinkManager{ linkTable: newLinkTable(), idGenerator: idgen.NewGenerator(), initialLatency: initialLatency, } result.store = objectz.NewObjectStore[*Link](func() objectz.ObjectIterator[*Link] { - return IterateCMap[*Link](result.linkTable.links) + return datastructures.IterateCMap[*Link](result.linkTable.links) }) result.store.AddStringSymbol("id", func(entity *Link) *string { @@ -101,8 +96,12 @@ func newLinkController(options *Options) *linkController { return result } -func (linkController *linkController) buildRouterLinks(router *Router) { - linkController.linkTable.links.IterCb(func(_ string, link *Link) { +func (self *LinkManager) GetStore() *objectz.ObjectStore[*Link] { + return self.store +} + +func (self *LinkManager) BuildRouterLinks(router *Router) { + self.linkTable.links.IterCb(func(_ string, link *Link) { if link.DstId == router.Id { router.routerLinks.Add(link, link.Src.Id) link.Dst.Store(router) @@ -110,36 +109,36 @@ func (linkController *linkController) buildRouterLinks(router *Router) { }) } -func (linkController *linkController) add(link *Link) { - linkController.linkTable.add(link) +func (self *LinkManager) Add(link *Link) { + self.linkTable.add(link) link.Src.routerLinks.Add(link, link.DstId) if dest := link.GetDest(); dest != nil { dest.routerLinks.Add(link, link.Src.Id) } } -func (linkController *linkController) has(link *Link) bool { - return linkController.linkTable.has(link) +func (self *LinkManager) has(link *Link) bool { + return self.linkTable.has(link) } -func (linkController *linkController) scanForDeadLinks() { +func (self *LinkManager) ScanForDeadLinks() { var toRemove []*Link - linkController.linkTable.links.IterCb(func(_ string, link *Link) { + self.linkTable.links.IterCb(func(_ string, link *Link) { if !link.Src.Connected.Load() { toRemove = append(toRemove, link) } }) for _, link := range toRemove { - linkController.remove(link) + self.Remove(link) } } -func (linkController *linkController) routerReportedLink(linkId string, iteration uint32, linkProtocol, dialAddress string, src, dst *Router, dstId string) (*Link, bool) { - linkController.lock.Lock() - defer linkController.lock.Unlock() +func (self *LinkManager) RouterReportedLink(linkId string, iteration uint32, linkProtocol, dialAddress string, src, dst *Router, dstId string) (*Link, bool) { + self.lock.Lock() + defer self.lock.Unlock() - link, _ := linkController.get(linkId) + link, _ := self.Get(linkId) if link != nil && link.Iteration >= iteration { return link, false } @@ -152,30 +151,38 @@ func (linkController *linkController) routerReportedLink(linkId string, iteratio WithField("destRouterId", dstId). WithField("iteration", iteration) - linkController.remove(link) + self.Remove(link) log.Infof("replaced link with newer iteration %v => %v", link.Iteration, iteration) } - link = newLink(linkId, linkProtocol, dialAddress, linkController.initialLatency) + link = newLink(linkId, linkProtocol, dialAddress, self.initialLatency) link.Iteration = iteration link.Src = src link.Dst.Store(dst) link.DstId = dstId link.SetState(Connected) - linkController.add(link) + self.Add(link) return link, true } -func (linkController *linkController) get(linkId string) (*Link, bool) { - return linkController.linkTable.get(linkId) +func (self *LinkManager) Get(linkId string) (*Link, bool) { + return self.linkTable.get(linkId) } -func (linkController *linkController) all() []*Link { - return linkController.linkTable.all() +func (self *LinkManager) All() []*Link { + return self.linkTable.all() +} + +func (self *LinkManager) GetLinkMap() map[string]*Link { + linkMap := make(map[string]*Link) + self.linkTable.links.IterCb(func(key string, link *Link) { + linkMap[key] = link + }) + return linkMap } -func (linkController *linkController) remove(link *Link) { - if linkController.linkTable.remove(link) { +func (self *LinkManager) Remove(link *Link) { + if self.linkTable.remove(link) { link.Src.routerLinks.Remove(link, link.DstId) if dest := link.GetDest(); dest != nil { dest.routerLinks.Remove(link, link.Src.Id) @@ -183,7 +190,7 @@ func (linkController *linkController) remove(link *Link) { } } -func (linkController *linkController) connectedNeighborsOfRouter(router *Router) []*Router { +func (self *LinkManager) ConnectedNeighborsOfRouter(router *Router) []*Router { neighborMap := make(map[string]*Router) links := router.routerLinks.GetLinks() @@ -206,7 +213,7 @@ func (linkController *linkController) connectedNeighborsOfRouter(router *Router) return neighbors } -func (linkController *linkController) leastExpensiveLink(a, b *Router) (*Link, bool) { +func (self *LinkManager) LeastExpensiveLink(a, b *Router) (*Link, bool) { var selected *Link var cost int64 = math.MaxInt64 @@ -236,7 +243,7 @@ func (linkController *linkController) leastExpensiveLink(a, b *Router) (*Link, b return nil, false } -func (linkController *linkController) missingLinks(routers []*Router, pendingTimeout time.Duration) ([]*Link, error) { +func (self *LinkManager) MissingLinks(routers []*Router, pendingTimeout time.Duration) ([]*Link, error) { // When there's a flood of router connects at startup we can see the same link // as missing multiple times as the new link will be marked as PENDING until it's // connected. Give ourselves a little window to make the connection before we @@ -252,9 +259,9 @@ func (linkController *linkController) missingLinks(routers []*Router, pendingTim for _, dstR := range routers { if srcR != dstR && len(dstR.Listeners) > 0 { for _, listener := range dstR.Listeners { - if !linkController.hasLink(srcR, dstR, listener.GetProtocol(), pendingLimit) { + if !self.hasLink(srcR, dstR, listener.GetProtocol(), pendingLimit) { id := idgen.NewUUIDString() - link := newLink(id, listener.GetProtocol(), listener.GetAddress(), linkController.initialLatency) + link := newLink(id, listener.GetProtocol(), listener.GetAddress(), self.initialLatency) link.Src = srcR link.Dst.Store(dstR) link.DstId = dstR.Id @@ -268,24 +275,24 @@ func (linkController *linkController) missingLinks(routers []*Router, pendingTim return missingLinks, nil } -func (linkController *linkController) clearExpiredPending(pendingTimeout time.Duration) { +func (self *LinkManager) ClearExpiredPending(pendingTimeout time.Duration) { pendingLimit := info.NowInMilliseconds() - pendingTimeout.Milliseconds() - toRemove := linkController.linkTable.matching(func(link *Link) bool { + toRemove := self.linkTable.matching(func(link *Link) bool { state := link.CurrentState() return state.Mode == Pending && state.Timestamp < pendingLimit }) for _, link := range toRemove { - linkController.remove(link) + self.Remove(link) } } -func (linkController *linkController) hasLink(a, b *Router, linkProtocol string, pendingLimit int64) bool { - return linkController.hasDirectedLink(a, b, linkProtocol, pendingLimit) || linkController.hasDirectedLink(b, a, linkProtocol, pendingLimit) +func (self *LinkManager) hasLink(a, b *Router, linkProtocol string, pendingLimit int64) bool { + return self.hasDirectedLink(a, b, linkProtocol, pendingLimit) || self.hasDirectedLink(b, a, linkProtocol, pendingLimit) } -func (linkController *linkController) hasDirectedLink(a, b *Router, linkProtocol string, pendingLimit int64) bool { +func (self *LinkManager) hasDirectedLink(a, b *Router, linkProtocol string, pendingLimit int64) bool { links := a.routerLinks.GetLinks() for _, link := range links { state := link.CurrentState() @@ -298,106 +305,8 @@ func (linkController *linkController) hasDirectedLink(a, b *Router, linkProtocol return false } -func (linkController *linkController) linksInMode(mode LinkMode) []*Link { - return linkController.linkTable.allInMode(mode) -} - -func (self *linkController) ValidateRouterLinks(n *Network, router *Router, cb LinkValidationCallback) { - request := &ctrl_pb.InspectRequest{RequestedValues: []string{"links"}} - resp := &ctrl_pb.InspectResponse{} - respMsg, err := protobufs.MarshalTyped(request).WithTimeout(time.Minute).SendForReply(router.Control) - if err = protobufs.TypedResponse(resp).Unmarshall(respMsg, err); err != nil { - self.reportRouterLinksError(router, err, cb) - return - } - - var linkDetails *inspect.LinksInspectResult - for _, val := range resp.Values { - if val.Name == "links" { - if err = json.Unmarshal([]byte(val.Value), &linkDetails); err != nil { - self.reportRouterLinksError(router, err, cb) - return - } - } - } - - if linkDetails == nil { - if len(resp.Errors) > 0 { - err = errors.New(strings.Join(resp.Errors, ",")) - self.reportRouterLinksError(router, err, cb) - return - } - self.reportRouterLinksError(router, errors.New("no link details returned from router"), cb) - return - } - - linkMap := map[string]*Link{} - - self.linkTable.links.IterCb(func(key string, link *Link) { - linkMap[key] = link - }) - - result := &mgmt_pb.RouterLinkDetails{ - RouterId: router.Id, - RouterName: router.Name, - ValidateSuccess: true, - } - - for _, link := range linkDetails.Links { - detail := &mgmt_pb.RouterLinkDetail{ - LinkId: link.Id, - RouterState: mgmt_pb.LinkState_LinkEstablished, - DestRouterId: link.Dest, - Dialed: link.Dialed, - } - detail.DestConnected = n.ConnectedRouter(link.Dest) - if _, found := linkMap[link.Id]; found { - detail.CtrlState = mgmt_pb.LinkState_LinkEstablished - detail.IsValid = detail.DestConnected - } else { - detail.CtrlState = mgmt_pb.LinkState_LinkUnknown - detail.IsValid = !detail.DestConnected - } - delete(linkMap, link.Id) - result.LinkDetails = append(result.LinkDetails, detail) - } - - for _, link := range linkMap { - related := false - dest := "" - if link.Src.Id == router.Id { - related = true - dest = link.DstId - } else if link.DstId == router.Id { - related = true - dest = link.Src.Id - } - - if related { - detail := &mgmt_pb.RouterLinkDetail{ - LinkId: link.Id, - CtrlState: mgmt_pb.LinkState_LinkEstablished, - DestConnected: n.ConnectedRouter(dest), - RouterState: mgmt_pb.LinkState_LinkUnknown, - IsValid: false, - DestRouterId: dest, - Dialed: link.Src.Id == router.Id, - } - result.LinkDetails = append(result.LinkDetails, detail) - } - } - - cb(result) -} - -func (self *linkController) reportRouterLinksError(router *Router, err error, cb LinkValidationCallback) { - result := &mgmt_pb.RouterLinkDetails{ - RouterId: router.Id, - RouterName: router.Name, - ValidateSuccess: false, - Message: err.Error(), - } - cb(result) +func (self *LinkManager) LinksInMode(mode LinkMode) []*Link { + return self.linkTable.allInMode(mode) } /* diff --git a/controller/network/link_controller_test.go b/controller/model/link_manager_test.go similarity index 75% rename from controller/network/link_controller_test.go rename to controller/model/link_manager_test.go index b257af6b6..a79852781 100644 --- a/controller/network/link_controller_test.go +++ b/controller/model/link_manager_test.go @@ -14,7 +14,7 @@ limitations under the License. */ -package network +package model import ( "sync/atomic" @@ -39,7 +39,7 @@ func Test64BitAlignment(t *testing.T) { } func TestLifecycle(t *testing.T) { - linkController := newLinkController(nil) + linkController := NewLinkManager(nil) r0 := NewRouter("r0", "", "", 0, true) r1 := NewRouter("r1", "", "", 0, true) @@ -50,7 +50,7 @@ func TestLifecycle(t *testing.T) { } l0.Dst.Store(r1) - linkController.add(l0) + linkController.Add(l0) assert.True(t, linkController.has(l0)) links := r0.routerLinks.GetLinks() @@ -61,7 +61,7 @@ func TestLifecycle(t *testing.T) { assert.Equal(t, 1, len(links)) assert.Equal(t, l0, links[0]) - linkController.remove(l0) + linkController.Remove(l0) assert.False(t, linkController.has(l0)) links = r0.routerLinks.GetLinks() @@ -72,25 +72,15 @@ func TestLifecycle(t *testing.T) { } func TestNeighbors(t *testing.T) { - linkController := newLinkController(nil) + linkController := NewLinkManager(nil) - r0 := newRouterForTest("r0", "", nil, nil, 0, true) - r1 := newRouterForTest("r1", "", nil, nil, 0, true) - l0 := newTestLink("l0", r0, r1) + r0 := NewRouterForTest("r0", "", nil, nil, 0, true) + r1 := NewRouterForTest("r1", "", nil, nil, 0, true) + l0 := NewTestLink("l0", r0, r1) l0.SetState(Connected) - linkController.add(l0) + linkController.Add(l0) - neighbors := linkController.connectedNeighborsOfRouter(r0) + neighbors := linkController.ConnectedNeighborsOfRouter(r0) assert.Equal(t, 1, len(neighbors)) assert.Equal(t, r1, neighbors[0]) } - -func newTestLink(id string, src, dst *Router) *Link { - l := newLink(id, "tls", "tcp:localhost:1234", 0) - l.Src = src - l.DstId = dst.Id - l.Dst.Store(dst) - src.Connected.Store(true) - dst.Connected.Store(true) - return l -} diff --git a/controller/network/link.go b/controller/model/link_model.go similarity index 95% rename from controller/network/link.go rename to controller/model/link_model.go index b6c5becb9..5f23891d9 100644 --- a/controller/network/link.go +++ b/controller/model/link_model.go @@ -14,7 +14,7 @@ limitations under the License. */ -package network +package model import ( "github.com/openziti/foundation/v2/concurrenz" @@ -55,7 +55,7 @@ func newLink(id string, linkProtocol string, dialAddress string, initialLatency SrcLatency: initialLatency.Nanoseconds(), DstLatency: initialLatency.Nanoseconds(), } - l.recalculateCost() + l.RecalculateCost() l.recalculateUsable() return l } @@ -116,7 +116,7 @@ func (link *Link) GetStaticCost() int32 { func (link *Link) SetStaticCost(cost int32) { atomic.StoreInt32(&link.StaticCost, cost) - link.recalculateCost() + link.RecalculateCost() } func (link *Link) GetSrcLatency() int64 { @@ -125,7 +125,7 @@ func (link *Link) GetSrcLatency() int64 { func (link *Link) SetSrcLatency(latency int64) { atomic.StoreInt64(&link.SrcLatency, latency) - link.recalculateCost() + link.RecalculateCost() } func (link *Link) GetDstLatency() int64 { @@ -134,10 +134,10 @@ func (link *Link) GetDstLatency() int64 { func (link *Link) SetDstLatency(latency int64) { atomic.StoreInt64(&link.DstLatency, latency) - link.recalculateCost() + link.RecalculateCost() } -func (link *Link) recalculateCost() { +func (link *Link) RecalculateCost() { cost := int64(link.GetStaticCost()) + link.GetSrcLatency()/1_000_000 + link.GetDstLatency()/1_000_000 atomic.StoreInt64(&link.Cost, cost) } diff --git a/controller/model/managers.go b/controller/model/managers.go index 32789dd8c..b3abc952f 100644 --- a/controller/model/managers.go +++ b/controller/model/managers.go @@ -17,18 +17,34 @@ package model import ( + "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/openziti/ziti/common/pb/edge_cmd_pb" + "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" - "github.com/openziti/ziti/controller/network" + "github.com/openziti/ziti/controller/fields" + "github.com/openziti/ziti/controller/ioc" + "github.com/openziti/ziti/controller/models" "google.golang.org/protobuf/proto" ) +const ( + CreateDecoder = "CreateDecoder" + UpdateDecoder = "UpdateDecoder" + DeleteDecoder = "DeleteDecoder" +) + type Managers struct { + // command + Registry ioc.Registry + Dispatcher command.Dispatcher + // fabric - Router *network.RouterManager - Service *network.ServiceManager - Terminator *network.TerminatorManager - Command *network.CommandManager + Circuit *CircuitManager + Command *CommandManager + Link *LinkManager + Router *RouterManager + Service *ServiceManager + Terminator *TerminatorManager // edge ApiSession *ApiSessionManager @@ -58,13 +74,20 @@ type Managers struct { AuthPolicy *AuthPolicyManager } -func InitEntityManagers(env Env) *Managers { - managers := &Managers{} +func NewManagers() *Managers { + return &Managers{ + Registry: ioc.NewRegistry(), + } +} - managers.Command = env.GetDbProvider().GetManagers().Command - managers.Router = env.GetDbProvider().GetManagers().Routers - managers.Service = env.GetDbProvider().GetManagers().Services - managers.Terminator = env.GetDbProvider().GetManagers().Terminators +func (managers *Managers) Init(env Env) *Managers { + managers.Dispatcher = env.GetCommandDispatcher() + managers.Circuit = NewCircuitController() + managers.Command = newCommandManager(env, managers.Registry) + managers.Link = NewLinkManager(env) + managers.Router = newRouterManager(env) + managers.Service = newServiceManager(env) + managers.Terminator = newTerminatorManager(env) managers.ApiSession = NewApiSessionManager(env) managers.ApiSessionCertificate = NewApiSessionCertificateManager(env) @@ -93,6 +116,7 @@ func InitEntityManagers(env Env) *Managers { managers.Mfa = NewMfaManager(env) RegisterCommand(env, &CreateEdgeTerminatorCmd{}, &edge_cmd_pb.CreateEdgeTerminatorCommand{}) + managers.Command.registerGenericCommands() return managers } @@ -116,7 +140,7 @@ type decodableCommand[T any, M any] interface { // // We only have both types specified so that we can enforce that each is a pointer type. If didn't // enforce that the instances were pointer types, we couldn't use new to instantiate new instances. -func RegisterCommand[MT any, CT any, M network.CommandMsg[MT], C decodableCommand[CT, M]](env Env, _ C, _ M) { +func RegisterCommand[MT any, CT any, M CommandMsg[MT], C decodableCommand[CT, M]](env Env, _ C, _ M) { decoder := func(commandType int32, data []byte) (command.Command, error) { var msg M = new(MT) if err := proto.Unmarshal(data, msg); err != nil { @@ -131,5 +155,61 @@ func RegisterCommand[MT any, CT any, M network.CommandMsg[MT], C decodableComman } var msg M = new(MT) - env.GetHostController().GetNetwork().Managers.Command.Decoders.RegisterF(msg.GetCommandType(), decoder) + env.GetManagers().Command.Decoders.RegisterF(msg.GetCommandType(), decoder) +} + +type createDecoderF func(cmd *cmd_pb.CreateEntityCommand) (command.Command, error) + +func RegisterCreateDecoder[T models.Entity](env Env, creator command.EntityCreator[T]) { + entityType := creator.GetEntityTypeId() + env.GetManagers().Registry.RegisterSingleton(entityType+CreateDecoder, createDecoderF(func(cmd *cmd_pb.CreateEntityCommand) (command.Command, error) { + entity, err := creator.Unmarshall(cmd.EntityData) + if err != nil { + return nil, err + } + return &command.CreateEntityCommand[T]{ + Context: change.FromProtoBuf(cmd.Ctx), + Entity: entity, + Creator: creator, + Flags: cmd.Flags, + }, nil + })) +} + +type updateDecoderF func(cmd *cmd_pb.UpdateEntityCommand) (command.Command, error) + +func RegisterUpdateDecoder[T models.Entity](env Env, updater command.EntityUpdater[T]) { + entityType := updater.GetEntityTypeId() + env.GetManagers().Registry.RegisterSingleton(entityType+UpdateDecoder, updateDecoderF(func(cmd *cmd_pb.UpdateEntityCommand) (command.Command, error) { + entity, err := updater.Unmarshall(cmd.EntityData) + if err != nil { + return nil, err + } + return &command.UpdateEntityCommand[T]{ + Context: change.FromProtoBuf(cmd.Ctx), + Entity: entity, + Updater: updater, + UpdatedFields: fields.SliceToUpdatedFields(cmd.UpdatedFields), + Flags: cmd.Flags, + }, nil + })) +} + +type deleteDecoderF func(cmd *cmd_pb.DeleteEntityCommand) (command.Command, error) + +func RegisterDeleteDecoder(env Env, deleter command.EntityDeleter) { + entityType := deleter.GetEntityTypeId() + env.GetManagers().Registry.RegisterSingleton(entityType+DeleteDecoder, deleteDecoderF(func(cmd *cmd_pb.DeleteEntityCommand) (command.Command, error) { + return &command.DeleteEntityCommand{ + Context: change.FromProtoBuf(cmd.Ctx), + Deleter: deleter, + Id: cmd.EntityId, + }, nil + })) +} + +func RegisterManagerDecoder[T models.Entity](env Env, ctrl command.EntityManager[T]) { + RegisterCreateDecoder[T](env, ctrl) + RegisterUpdateDecoder[T](env, ctrl) + RegisterDeleteDecoder(env, ctrl) } diff --git a/controller/model/mfa_manager.go b/controller/model/mfa_manager.go index 121ff3500..d0d0183bb 100644 --- a/controller/model/mfa_manager.go +++ b/controller/model/mfa_manager.go @@ -31,7 +31,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "github.com/skip2/go-qrcode" "go.etcd.io/bbolt" @@ -49,7 +48,7 @@ func NewMfaManager(env Env) *MfaManager { } manager.impl = manager - network.RegisterManagerDecoder[*Mfa](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*Mfa](env, manager) return manager } @@ -99,7 +98,7 @@ func (self *MfaManager) CreateForIdentity(identity *Identity, ctx *change.Contex } func (self *MfaManager) Create(entity *Mfa, ctx *change.Context) error { - return network.DispatchCreate[*Mfa](self, entity, ctx) + return DispatchCreate[*Mfa](self, entity, ctx) } func (self *MfaManager) ApplyCreate(cmd *command.CreateEntityCommand[*Mfa], ctx boltz.MutateContext) error { @@ -122,7 +121,7 @@ func (self *MfaManager) ApplyCreate(cmd *command.CreateEntityCommand[*Mfa], ctx } func (self *MfaManager) Update(entity *Mfa, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*Mfa](self, entity, checker, ctx) + return DispatchUpdate[*Mfa](self, entity, checker, ctx) } func (self *MfaManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Mfa], ctx boltz.MutateContext) error { @@ -235,7 +234,7 @@ func (self *MfaManager) GetProvisioningUrl(mfa *Mfa) string { WindowSize: WindowSizeTOTP, UTC: true, } - return otcConfig.ProvisionURIWithIssuer(mfa.Identity.Name, self.env.GetConfig().Totp.Hostname) + return otcConfig.ProvisionURIWithIssuer(mfa.Identity.Name, self.env.GetConfig().Edge.Totp.Hostname) } func (self *MfaManager) RecreateRecoveryCodes(mfa *Mfa, ctx *change.Context) error { diff --git a/controller/model/path.go b/controller/model/path.go new file mode 100644 index 000000000..6e67c3e9c --- /dev/null +++ b/controller/model/path.go @@ -0,0 +1,96 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package model + +import ( + "fmt" +) + +type Path struct { + Nodes []*Router + Links []*Link + IngressId string + EgressId string + InitiatorLocalAddr string + InitiatorRemoteAddr string + TerminatorLocalAddr string + TerminatorRemoteAddr string +} + +func (self *Path) Cost(minRouterCost uint16) int64 { + var cost int64 + for _, l := range self.Links { + cost += l.GetCost() + } + for _, r := range self.Nodes { + cost += int64(max(r.Cost, minRouterCost)) + } + return cost +} + +func (self *Path) String() string { + if len(self.Nodes) < 1 { + return "{}" + } + if len(self.Links) != len(self.Nodes)-1 { + return "{malformed}" + } + out := fmt.Sprintf("[r/%s]", self.Nodes[0].Id) + for i := 0; i < len(self.Links); i++ { + out += fmt.Sprintf("->[l/%s]", self.Links[i].Id) + out += fmt.Sprintf("->[r/%s]", self.Nodes[i+1].Id) + } + return out +} + +func (self *Path) EqualPath(other *Path) bool { + if len(self.Nodes) != len(other.Nodes) { + return false + } + if len(self.Links) != len(other.Links) { + return false + } + for i := 0; i < len(self.Nodes); i++ { + if self.Nodes[i] != other.Nodes[i] { + return false + } + } + for i := 0; i < len(self.Links); i++ { + if self.Links[i] != other.Links[i] { + return false + } + } + return true +} + +func (self *Path) EgressRouter() *Router { + if len(self.Nodes) > 0 { + return self.Nodes[len(self.Nodes)-1] + } + return nil +} + +func (self *Path) UsesLink(l *Link) bool { + if self.Links != nil { + for _, o := range self.Links { + if o == l { + return true + } + } + } + return false +} diff --git a/controller/model/policy_advisor.go b/controller/model/policy_advisor.go index 456790639..84ef490dd 100644 --- a/controller/model/policy_advisor.go +++ b/controller/model/policy_advisor.go @@ -24,7 +24,7 @@ type AdvisorEdgeRouter struct { type AdvisorServiceReachability struct { Identity *Identity - Service *Service + Service *EdgeService IsBindAllowed bool IsDialAllowed bool IdentityRouterCount int @@ -204,7 +204,7 @@ func (advisor *PolicyAdvisor) getEdgeRouterPolicies(identityId, edgeRouterId str type AdvisorIdentityServiceLinks struct { Identity *Identity - Service *Service + Service *EdgeService Policies []*ServicePolicy } @@ -256,7 +256,7 @@ func (advisor *PolicyAdvisor) getServicePolicies(identityId, serviceId string) ( } type AdvisorServiceEdgeRouterLinks struct { - Service *Service + Service *EdgeService EdgeRouter *EdgeRouter Policies []*ServiceEdgeRouterPolicy } diff --git a/controller/model/posture_check_manager.go b/controller/model/posture_check_manager.go index f398fcea2..5ee9c6ec3 100644 --- a/controller/model/posture_check_manager.go +++ b/controller/model/posture_check_manager.go @@ -27,7 +27,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "strings" @@ -47,7 +46,7 @@ func NewPostureCheckManager(env Env) *PostureCheckManager { cache: cache, } manager.impl = manager - network.RegisterManagerDecoder[*PostureCheck](env.GetHostController().GetNetwork().GetManagers(), manager) + RegisterManagerDecoder[*PostureCheck](env, manager) evictF := func(postureCheckId string) { manager.cache.Remove(postureCheckId) @@ -68,7 +67,7 @@ func (self *PostureCheckManager) newModelEntity() *PostureCheck { } func (self *PostureCheckManager) Create(entity *PostureCheck, ctx *change.Context) error { - return network.DispatchCreate[*PostureCheck](self, entity, ctx) + return DispatchCreate[*PostureCheck](self, entity, ctx) } func (self *PostureCheckManager) ApplyCreate(cmd *command.CreateEntityCommand[*PostureCheck], ctx boltz.MutateContext) error { @@ -77,7 +76,7 @@ func (self *PostureCheckManager) ApplyCreate(cmd *command.CreateEntityCommand[*P } func (self *PostureCheckManager) Update(entity *PostureCheck, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*PostureCheck](self, entity, checker, ctx) + return DispatchUpdate[*PostureCheck](self, entity, checker, ctx) } func (self *PostureCheckManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*PostureCheck], ctx boltz.MutateContext) error { diff --git a/controller/model/posture_response_manager.go b/controller/model/posture_response_manager.go index ddfcf8c02..dc1fc65fe 100644 --- a/controller/model/posture_response_manager.go +++ b/controller/model/posture_response_manager.go @@ -139,7 +139,7 @@ func (self *PostureResponseManager) postureDataUpdated(env Env, identityId strin // Only an issue when timeouts are being used - which aren't right now. env.HandleServiceUpdatedEventForIdentityId(identityId) - err := self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := self.env.GetDb().View(func(tx *bbolt.Tx) error { apiSessionIds, _, err := self.env.GetStores().ApiSession.QueryIds(tx, fmt.Sprintf(`identity = "%v"`, identityId)) if err != nil { @@ -255,7 +255,7 @@ func (self *PostureResponseManager) SetSdkInfo(identityId, apiSessionId string, } type ServiceWithTimeout struct { - Service *Service + Service *EdgeService Timeout int64 } @@ -285,7 +285,7 @@ func (self *PostureResponseManager) GetEndpointStateChangeAffectedServices(timeS if err != nil { pfxlog.Logger().Errorf("error querying for onWake/onUnlock posture checks: %v", err) } else { - err = self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err = self.env.GetDb().View(func(tx *bbolt.Tx) error { cursor := self.env.GetStores().PostureCheck.IterateIds(tx, query) for cursor.IsValid() { @@ -312,7 +312,7 @@ func (self *PostureResponseManager) GetEndpointStateChangeAffectedServices(timeS services := map[string]*ServiceWithTimeout{} if len(affectedChecks) > 0 { - _ = self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + _ = self.env.GetDb().View(func(tx *bbolt.Tx) error { for checkId, timeout := range affectedChecks { policyCursor := self.env.GetStores().PostureCheck.GetRelatedEntitiesCursor(tx, checkId, db.EntityTypeServicePolicies, true) @@ -323,7 +323,7 @@ func (self *PostureResponseManager) GetEndpointStateChangeAffectedServices(timeS if _, ok := services[string(serviceCursor.Current())]; !ok { service, err := self.env.GetStores().EdgeService.LoadById(tx, string(serviceCursor.Current())) if err == nil { - modelService := &Service{} + modelService := &EdgeService{} if err := modelService.fillFrom(self.env, tx, service); err == nil { //use the lowest configured timeout (which is some timeout or no timeout) if existingService, ok := services[service.Id]; !ok || timeout < existingService.Timeout { diff --git a/controller/model/posture_response_model.go b/controller/model/posture_response_model.go index 97bc533d2..84257bab2 100644 --- a/controller/model/posture_response_model.go +++ b/controller/model/posture_response_model.go @@ -55,7 +55,7 @@ func newPostureCache(env Env) *PostureCache { env: env, } - pc.run(env.GetHostController().GetCloseNotifyChannel()) + pc.run(env.GetCloseNotifyChannel()) env.GetStores().ApiSession.AddEntityEventListenerF(pc.ApiSessionCreated, boltz.EntityCreatedAsync) env.GetStores().ApiSession.AddEntityEventListenerF(pc.ApiSessionDeleted, boltz.EntityDeletedAsync) @@ -108,7 +108,7 @@ func (pc *PostureCache) evaluate() { // Requires tracking of which session was last evaluated, kept in lastId. for !done { var sessions []*db.Session - _ = pc.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + _ = pc.env.GetDb().View(func(tx *bbolt.Tx) error { cursor := pc.env.GetStores().Session.IterateIds(tx, ast.BoolNodeTrue) if len(lastId) != 0 { @@ -313,7 +313,7 @@ func (pc *PostureCache) PostureCheckChanged(entity boltz.Entity) { identitiesToNotify := map[string]struct{}{} - _ = pc.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + _ = pc.env.GetDb().View(func(tx *bbolt.Tx) error { servicePolicyCursor := servicePolicyLinks.IterateLinks(tx, []byte(entity.GetId())) for servicePolicyCursor.IsValid() { diff --git a/controller/model/revocation_manager.go b/controller/model/revocation_manager.go index f8fbf4637..9c3b56a83 100644 --- a/controller/model/revocation_manager.go +++ b/controller/model/revocation_manager.go @@ -23,7 +23,6 @@ import ( "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "google.golang.org/protobuf/proto" ) @@ -34,7 +33,7 @@ func NewRevocationManager(env Env) *RevocationManager { } manager.impl = manager - network.RegisterManagerDecoder[*Revocation](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*Revocation](env, manager) return manager } @@ -48,7 +47,7 @@ func (self *RevocationManager) ApplyUpdate(_ *command.UpdateEntityCommand[*Revoc } func (self *RevocationManager) Create(entity *Revocation, ctx *change.Context) error { - return network.DispatchCreate[*Revocation](self, entity, ctx) + return DispatchCreate[*Revocation](self, entity, ctx) } func (self *RevocationManager) ApplyCreate(cmd *command.CreateEntityCommand[*Revocation], ctx boltz.MutateContext) error { diff --git a/controller/model/revocation_model.go b/controller/model/revocation_model.go index 51404e3ae..9d34e6b41 100644 --- a/controller/model/revocation_model.go +++ b/controller/model/revocation_model.go @@ -29,7 +29,7 @@ type Revocation struct { ExpiresAt time.Time } -func (entity *Revocation) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, checker boltz.FieldChecker) (*db.Revocation, error) { +func (entity *Revocation) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.Revocation, error) { return entity.toBoltEntityForCreate(tx, env) } diff --git a/controller/network/router.go b/controller/model/router_manager.go similarity index 74% rename from controller/network/router.go rename to controller/model/router_manager.go index e0d71695f..ef574942d 100644 --- a/controller/network/router.go +++ b/controller/model/router_manager.go @@ -14,14 +14,12 @@ limitations under the License. */ -package network +package model import ( "encoding/json" "fmt" "github.com/openziti/channel/v2/protobufs" - "github.com/openziti/foundation/v2/genext" - "github.com/openziti/foundation/v2/versions" "github.com/openziti/ziti/common/inspect" "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/openziti/ziti/common/pb/ctrl_pb" @@ -32,14 +30,12 @@ import ( "github.com/openziti/ziti/controller/xt" "google.golang.org/protobuf/proto" "maps" - "reflect" "strings" "sync" "sync/atomic" "time" "github.com/michaelquigley/pfxlog" - "github.com/openziti/channel/v2" "github.com/openziti/storage/boltz" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/models" @@ -53,68 +49,6 @@ const ( RouterDequiesceFlag uint32 = 2 ) -type Listener interface { - AdvertiseAddress() string - Protocol() string - Groups() []string -} - -type Router struct { - models.BaseEntity - Name string - Fingerprint *string - Listeners []*ctrl_pb.Listener - Control channel.Channel - Connected atomic.Bool - ConnectTime time.Time - VersionInfo *versions.VersionInfo - routerLinks RouterLinks - Cost uint16 - NoTraversal bool - Disabled bool - Metadata *ctrl_pb.RouterMetadata -} - -func (entity *Router) toBolt() *db.Router { - return &db.Router{ - BaseExtEntity: *boltz.NewExtEntity(entity.Id, entity.Tags), - Name: entity.Name, - Fingerprint: entity.Fingerprint, - Cost: entity.Cost, - NoTraversal: entity.NoTraversal, - Disabled: entity.Disabled, - } -} - -func (entity *Router) AddLinkListener(addr, linkProtocol string, linkCostTags []string, groups []string) { - entity.Listeners = append(entity.Listeners, &ctrl_pb.Listener{ - Address: addr, - Protocol: linkProtocol, - CostTags: linkCostTags, - Groups: groups, - }) -} - -func (entity *Router) SetLinkListeners(listeners []*ctrl_pb.Listener) { - entity.Listeners = listeners -} - -func (entity *Router) SetMetadata(metadata *ctrl_pb.RouterMetadata) { - entity.Metadata = metadata -} - -func (entity *Router) HasCapability(capability ctrl_pb.RouterCapability) bool { - return entity.Metadata != nil && genext.Contains(entity.Metadata.Capabilities, capability) -} - -func (entity *Router) SupportsRouterLinkMgmt() bool { - if entity.VersionInfo == nil { - return true - } - supportsLinkMgmt, err := entity.VersionInfo.HasMinimumVersion("0.32.1") - return err != nil || supportsLinkMgmt -} - func NewRouter(id, name, fingerprint string, cost uint16, noTraversal bool) *Router { if name == "" { name = id @@ -135,27 +69,30 @@ type RouterManager struct { baseEntityManager[*Router, *db.Router] cache cmap.ConcurrentMap[string, *Router] connected cmap.ConcurrentMap[string, *Router] - store db.RouterStore } -func newRouterManager(managers *Managers) *RouterManager { +func newRouterManager(env Env) *RouterManager { + routerStore := env.GetStores().Router result := &RouterManager{ - baseEntityManager: newBaseEntityManager[*Router, *db.Router](managers, managers.stores.Router, func() *Router { - return &Router{} - }), - cache: cmap.New[*Router](), - connected: cmap.New[*Router](), - store: managers.stores.Router, + baseEntityManager: newBaseEntityManager[*Router, *db.Router](env, routerStore), + cache: cmap.New[*Router](), + connected: cmap.New[*Router](), } - result.populateEntity = result.populateRouter + result.impl = result + + routerStore.AddEntityIdListener(result.UpdateCachedRouter, boltz.EntityUpdated) + routerStore.AddEntityIdListener(result.HandleRouterDelete, boltz.EntityDeleted) - managers.stores.Router.AddEntityIdListener(result.UpdateCachedRouter, boltz.EntityUpdated) - managers.stores.Router.AddEntityIdListener(result.HandleRouterDelete, boltz.EntityDeleted) + RegisterManagerDecoder[*Router](env, result) return result } -func (self *RouterManager) markConnected(r *Router) { +func (self *RouterManager) newModelEntity() *Router { + return &Router{} +} + +func (self *RouterManager) MarkConnected(r *Router) { if router, _ := self.connected.Get(r.Id); router != nil { if ch := router.Control; ch != nil { if err := ch.Close(); err != nil { @@ -168,7 +105,7 @@ func (self *RouterManager) markConnected(r *Router) { self.connected.Set(r.Id, r) } -func (self *RouterManager) markDisconnected(r *Router) { +func (self *RouterManager) MarkDisconnected(r *Router) { r.Connected.Store(false) self.connected.RemoveCb(r.Id, func(key string, v *Router, exists bool) bool { if exists && v != r { @@ -184,14 +121,14 @@ func (self *RouterManager) IsConnected(id string) bool { return self.connected.Has(id) } -func (self *RouterManager) getConnected(id string) *Router { +func (self *RouterManager) GetConnected(id string) *Router { if router, found := self.connected.Get(id); found { return router } return nil } -func (self *RouterManager) allConnected() []*Router { +func (self *RouterManager) AllConnected() []*Router { var routers []*Router self.connected.IterCb(func(_ string, router *Router) { routers = append(routers, router) @@ -199,7 +136,7 @@ func (self *RouterManager) allConnected() []*Router { return routers } -func (self *RouterManager) connectedCount() int { +func (self *RouterManager) ConnectedCount() int { return self.connected.Count() } @@ -209,17 +146,15 @@ func (self *RouterManager) Create(entity *Router, ctx *change.Context) error { func (self *RouterManager) ApplyCreate(cmd *command.CreateEntityCommand[*Router], ctx boltz.MutateContext) error { router := cmd.Entity - err := self.db.Update(ctx, func(ctx boltz.MutateContext) error { - return self.store.Create(ctx, router.toBolt()) - }) + routerId, err := self.createEntity(router, ctx) if err == nil { - self.cache.Set(router.Id, router) + self.cache.Set(routerId, router) } return err } func (self *RouterManager) Read(id string) (entity *Router, err error) { - err = self.db.View(func(tx *bbolt.Tx) error { + err = self.GetDb().View(func(tx *bbolt.Tx) error { entity, err = self.readInTx(tx, id) return err }) @@ -231,8 +166,8 @@ func (self *RouterManager) Read(id string) (entity *Router, err error) { func (self *RouterManager) Exists(id string) (bool, error) { exists := false - err := self.db.View(func(tx *bbolt.Tx) error { - exists = self.store.IsEntityPresent(tx, id) + err := self.GetDb().View(func(tx *bbolt.Tx) error { + exists = self.Store.IsEntityPresent(tx, id) return nil }) return exists, err @@ -240,7 +175,7 @@ func (self *RouterManager) Exists(id string) (bool, error) { func (self *RouterManager) readUncached(id string) (*Router, error) { entity := &Router{} - err := self.db.View(func(tx *bbolt.Tx) error { + err := self.GetDb().View(func(tx *bbolt.Tx) error { return self.readEntityInTx(tx, id, entity) }) if err != nil { @@ -263,20 +198,6 @@ func (self *RouterManager) readInTx(tx *bbolt.Tx, id string) (*Router, error) { return entity, nil } -func (self *RouterManager) populateRouter(entity *Router, _ *bbolt.Tx, boltEntity boltz.Entity) error { - boltRouter, ok := boltEntity.(*db.Router) - if !ok { - return errors.Errorf("unexpected type %v when filling model router", reflect.TypeOf(boltEntity)) - } - entity.Name = boltRouter.Name - entity.Fingerprint = boltRouter.Fingerprint - entity.Cost = boltRouter.Cost - entity.NoTraversal = boltRouter.NoTraversal - entity.Disabled = boltRouter.Disabled - entity.FillCommon(boltRouter) - return nil -} - func (self *RouterManager) Update(entity *Router, updatedFields fields.UpdatedFields, ctx *change.Context) error { return DispatchUpdate[*Router](self, entity, updatedFields, ctx) } @@ -288,7 +209,7 @@ func (self *RouterManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Router] return self.ApplyDequiesce(cmd, ctx) } - return self.updateGeneral(ctx, cmd.Entity, cmd.UpdatedFields) + return self.updateEntity(cmd.Entity, cmd.UpdatedFields, ctx) } // QuiesceRouter marks all terminators on the router as failed, so that new traffic will avoid this router, if there's @@ -328,7 +249,7 @@ func (self *RouterManager) ApplyQuiesce(cmd *command.UpdateEntityCommand[*Router terminator.SavedPrecedence = ¤tPrecedence terminator.Precedence = xt.Precedences.Failed.String() - return self.Terminators.store.Update(ctx.GetSystemContext(), terminator, boltz.MapFieldChecker{ + return self.env.GetStores().Terminator.Update(ctx.GetSystemContext(), terminator, boltz.MapFieldChecker{ db.FieldTerminatorPrecedence: struct{}{}, db.FieldTerminatorSavedPrecedence: struct{}{}, }) @@ -344,7 +265,7 @@ func (self *RouterManager) ApplyDequiesce(cmd *command.UpdateEntityCommand[*Rout terminator.Precedence = *terminator.SavedPrecedence terminator.SavedPrecedence = nil - return self.Terminators.store.Update(ctx.GetSystemContext(), terminator, boltz.MapFieldChecker{ + return self.env.GetStores().Terminator.Update(ctx.GetSystemContext(), terminator, boltz.MapFieldChecker{ db.FieldTerminatorPrecedence: struct{}{}, db.FieldTerminatorSavedPrecedence: struct{}{}, }) @@ -352,10 +273,10 @@ func (self *RouterManager) ApplyDequiesce(cmd *command.UpdateEntityCommand[*Rout } func (self *RouterManager) UpdateTerminators(router *Router, ctx boltz.MutateContext, f func(terminator *db.Terminator) error) error { - return self.db.Update(ctx, func(ctx boltz.MutateContext) error { - terminatorIds := self.store.GetRelatedEntitiesIdList(ctx.Tx(), router.Id, db.EntityTypeTerminators) + return self.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { + terminatorIds := self.Store.GetRelatedEntitiesIdList(ctx.Tx(), router.Id, db.EntityTypeTerminators) for _, terminatorId := range terminatorIds { - terminator, _, err := self.Terminators.store.FindById(ctx.Tx(), terminatorId) + terminator, _, err := self.env.GetStores().Terminator.FindById(ctx.Tx(), terminatorId) if err != nil { return err } @@ -384,11 +305,6 @@ func (self *RouterManager) HandleRouterDelete(id string) { } else { log.Debug("deleted router not connected, no further action required") } - - go func() { - self.network.routerDeleted(id) - self.Managers.RouterMessaging.RouterDeleted(id) - }() } func (self *RouterManager) UpdateCachedRouter(id string) { @@ -475,12 +391,12 @@ func (self *RouterManager) Unmarshall(bytes []byte) (*Router, error) { }, nil } -func (self *RouterManager) ValidateRouterSdkTerminators(router *Router, cb SdkTerminatorValidationCallback) { +func (self *RouterManager) ValidateRouterSdkTerminators(router *Router, cb func(detail *mgmt_pb.RouterSdkTerminatorsDetails)) { request := &ctrl_pb.InspectRequest{RequestedValues: []string{"sdk-terminators"}} resp := &ctrl_pb.InspectResponse{} respMsg, err := protobufs.MarshalTyped(request).WithTimeout(time.Minute).SendForReply(router.Control) if err = protobufs.TypedResponse(resp).Unmarshall(respMsg, err); err != nil { - self.reportRouterSdkTerminatorsError(router, err, cb) + self.ReportRouterSdkTerminatorsError(router, err, cb) return } @@ -488,7 +404,7 @@ func (self *RouterManager) ValidateRouterSdkTerminators(router *Router, cb SdkTe for _, val := range resp.Values { if val.Name == "sdk-terminators" { if err = json.Unmarshal([]byte(val.Value), &inspectResult); err != nil { - self.reportRouterSdkTerminatorsError(router, err, cb) + self.ReportRouterSdkTerminatorsError(router, err, cb) return } } @@ -497,16 +413,16 @@ func (self *RouterManager) ValidateRouterSdkTerminators(router *Router, cb SdkTe if inspectResult == nil { if len(resp.Errors) > 0 { err = errors.New(strings.Join(resp.Errors, ",")) - self.reportRouterSdkTerminatorsError(router, err, cb) + self.ReportRouterSdkTerminatorsError(router, err, cb) return } - self.reportRouterSdkTerminatorsError(router, errors.New("no terminator details returned from router"), cb) + self.ReportRouterSdkTerminatorsError(router, errors.New("no terminator details returned from router"), cb) return } - listResult, err := self.Terminators.BaseList(fmt.Sprintf(`router="%s" and binding="edge" limit none`, router.Id)) + listResult, err := self.env.GetManagers().Terminator.BaseList(fmt.Sprintf(`router="%s" and binding="edge" limit none`, router.Id)) if err != nil { - self.reportRouterSdkTerminatorsError(router, err, cb) + self.ReportRouterSdkTerminatorsError(router, err, cb) return } @@ -559,7 +475,7 @@ func (self *RouterManager) ValidateRouterSdkTerminators(router *Router, cb SdkTe cb(result) } -func (self *RouterManager) reportRouterSdkTerminatorsError(router *Router, err error, cb SdkTerminatorValidationCallback) { +func (self *RouterManager) ReportRouterSdkTerminatorsError(router *Router, err error, cb func(detail *mgmt_pb.RouterSdkTerminatorsDetails)) { result := &mgmt_pb.RouterSdkTerminatorsDetails{ RouterId: router.Id, RouterName: router.Name, diff --git a/controller/model/router_model.go b/controller/model/router_model.go new file mode 100644 index 000000000..cfa7c5d64 --- /dev/null +++ b/controller/model/router_model.go @@ -0,0 +1,110 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package model + +import ( + "github.com/openziti/channel/v2" + "github.com/openziti/foundation/v2/genext" + "github.com/openziti/foundation/v2/versions" + "github.com/openziti/storage/boltz" + "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/controller/models" + "go.etcd.io/bbolt" + "sync/atomic" + "time" +) + +type Listener interface { + AdvertiseAddress() string + Protocol() string + Groups() []string +} + +type Router struct { + models.BaseEntity + Name string + Fingerprint *string + Listeners []*ctrl_pb.Listener + Control channel.Channel + Connected atomic.Bool + ConnectTime time.Time + VersionInfo *versions.VersionInfo + routerLinks RouterLinks + Cost uint16 + NoTraversal bool + Disabled bool + Metadata *ctrl_pb.RouterMetadata +} + +func (entity *Router) GetLinks() []*Link { + return entity.routerLinks.GetLinks() +} + +func (entity *Router) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.Router, error) { + return entity.toBoltEntityForCreate(tx, env) +} + +func (entity *Router) toBoltEntityForCreate(*bbolt.Tx, Env) (*db.Router, error) { + return &db.Router{ + BaseExtEntity: *boltz.NewExtEntity(entity.Id, entity.Tags), + Name: entity.Name, + Fingerprint: entity.Fingerprint, + Cost: entity.Cost, + NoTraversal: entity.NoTraversal, + Disabled: entity.Disabled, + }, nil +} + +func (entity *Router) fillFrom(_ Env, _ *bbolt.Tx, boltRouter *db.Router) error { + entity.Name = boltRouter.Name + entity.Fingerprint = boltRouter.Fingerprint + entity.Cost = boltRouter.Cost + entity.NoTraversal = boltRouter.NoTraversal + entity.Disabled = boltRouter.Disabled + entity.FillCommon(boltRouter) + return nil +} + +func (entity *Router) AddLinkListener(addr, linkProtocol string, linkCostTags []string, groups []string) { + entity.Listeners = append(entity.Listeners, &ctrl_pb.Listener{ + Address: addr, + Protocol: linkProtocol, + CostTags: linkCostTags, + Groups: groups, + }) +} + +func (entity *Router) SetLinkListeners(listeners []*ctrl_pb.Listener) { + entity.Listeners = listeners +} + +func (entity *Router) SetMetadata(metadata *ctrl_pb.RouterMetadata) { + entity.Metadata = metadata +} + +func (entity *Router) HasCapability(capability ctrl_pb.RouterCapability) bool { + return entity.Metadata != nil && genext.Contains(entity.Metadata.Capabilities, capability) +} + +func (entity *Router) SupportsRouterLinkMgmt() bool { + if entity.VersionInfo == nil { + return true + } + supportsLinkMgmt, err := entity.VersionInfo.HasMinimumVersion("0.32.1") + return err != nil || supportsLinkMgmt +} diff --git a/controller/model/service_edge_router_policy_manager.go b/controller/model/service_edge_router_policy_manager.go index 6e1fea60b..2ddbbd2a2 100644 --- a/controller/model/service_edge_router_policy_manager.go +++ b/controller/model/service_edge_router_policy_manager.go @@ -24,7 +24,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "google.golang.org/protobuf/proto" ) @@ -34,7 +33,7 @@ func NewServiceEdgeRouterPolicyManager(env Env) *ServiceEdgeRouterPolicyManager } manager.impl = manager - network.RegisterManagerDecoder[*ServiceEdgeRouterPolicy](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*ServiceEdgeRouterPolicy](env, manager) return manager } @@ -48,7 +47,7 @@ func (self *ServiceEdgeRouterPolicyManager) newModelEntity() *ServiceEdgeRouterP } func (self *ServiceEdgeRouterPolicyManager) Create(entity *ServiceEdgeRouterPolicy, ctx *change.Context) error { - return network.DispatchCreate[*ServiceEdgeRouterPolicy](self, entity, ctx) + return DispatchCreate[*ServiceEdgeRouterPolicy](self, entity, ctx) } func (self *ServiceEdgeRouterPolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[*ServiceEdgeRouterPolicy], ctx boltz.MutateContext) error { @@ -57,7 +56,7 @@ func (self *ServiceEdgeRouterPolicyManager) ApplyCreate(cmd *command.CreateEntit } func (self *ServiceEdgeRouterPolicyManager) Update(entity *ServiceEdgeRouterPolicy, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*ServiceEdgeRouterPolicy](self, entity, checker, ctx) + return DispatchUpdate[*ServiceEdgeRouterPolicy](self, entity, checker, ctx) } func (self *ServiceEdgeRouterPolicyManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*ServiceEdgeRouterPolicy], ctx boltz.MutateContext) error { diff --git a/controller/network/service.go b/controller/model/service_manager.go similarity index 64% rename from controller/network/service.go rename to controller/model/service_manager.go index c60d9a1b3..56b25e25d 100644 --- a/controller/network/service.go +++ b/controller/model/service_manager.go @@ -14,7 +14,7 @@ limitations under the License. */ -package network +package model import ( "github.com/michaelquigley/pfxlog" @@ -26,45 +26,21 @@ import ( "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" "github.com/orcaman/concurrent-map/v2" - "github.com/pkg/errors" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" - "reflect" "time" ) -type Service struct { - models.BaseEntity - Name string - TerminatorStrategy string - Terminators []*Terminator - MaxIdleTime time.Duration -} - -func (self *Service) GetName() string { - return self.Name -} - -func (entity *Service) toBolt() *db.Service { - return &db.Service{ - BaseExtEntity: *boltz.NewExtEntity(entity.Id, entity.Tags), - Name: entity.Name, - MaxIdleTime: entity.MaxIdleTime, - TerminatorStrategy: entity.TerminatorStrategy, - } -} - -func newServiceManager(managers *Managers) *ServiceManager { +func newServiceManager(env Env) *ServiceManager { result := &ServiceManager{ - baseEntityManager: newBaseEntityManager[*Service, *db.Service](managers, managers.stores.Service, func() *Service { - return &Service{} - }), - cache: cmap.New[*Service](), - store: managers.stores.Service, + baseEntityManager: newBaseEntityManager[*Service, *db.Service](env, env.GetStores().Service), + cache: cmap.New[*Service](), } - result.populateEntity = result.populateService + result.impl = result + + env.GetStores().Service.AddEntityIdListener(result.RemoveFromCache, boltz.EntityUpdated, boltz.EntityDeleted) - managers.stores.Service.AddEntityIdListener(result.RemoveFromCache, boltz.EntityUpdated, boltz.EntityDeleted) + RegisterManagerDecoder[*Service](env, result) return result } @@ -72,15 +48,18 @@ func newServiceManager(managers *Managers) *ServiceManager { type ServiceManager struct { baseEntityManager[*Service, *db.Service] cache cmap.ConcurrentMap[string, *Service] - store db.ServiceStore +} + +func (self *ServiceManager) newModelEntity() *Service { + return &Service{} } func (self *ServiceManager) NotifyTerminatorChanged(terminator *db.Terminator) *db.Terminator { // patched entities may not have all fields, if service is blank, load terminator serviceId := terminator.Service if serviceId == "" { - err := self.db.View(func(tx *bbolt.Tx) error { - t, _, err := self.stores.Terminator.FindById(tx, terminator.Id) + err := self.GetDb().View(func(tx *bbolt.Tx) error { + t, _, err := self.env.GetStores().Terminator.FindById(tx, terminator.Id) if t != nil { terminator = t } @@ -102,21 +81,8 @@ func (self *ServiceManager) Create(entity *Service, ctx *change.Context) error { } func (self *ServiceManager) ApplyCreate(cmd *command.CreateEntityCommand[*Service], ctx boltz.MutateContext) error { - s := cmd.Entity - err := self.db.Update(ctx, func(ctx boltz.MutateContext) error { - if err := self.ValidateNameOnCreate(ctx.Tx(), s); err != nil { - return err - } - if err := self.store.Create(ctx, s.toBolt()); err != nil { - return err - } - return nil - }) - if err != nil { - return err - } - // don't cache, wait for first read. entity may not match data store as data store may have set defaults - return nil + _, err := self.createEntity(cmd.Entity, ctx) + return err } func (self *ServiceManager) Update(entity *Service, updatedFields fields.UpdatedFields, ctx *change.Context) error { @@ -124,7 +90,7 @@ func (self *ServiceManager) Update(entity *Service, updatedFields fields.Updated } func (self *ServiceManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Service], ctx boltz.MutateContext) error { - if err := self.updateGeneral(ctx, cmd.Entity, cmd.UpdatedFields); err != nil { + if err := self.updateEntity(cmd.Entity, cmd.UpdatedFields, ctx); err != nil { return err } self.RemoveFromCache(cmd.Entity.Id) @@ -132,7 +98,7 @@ func (self *ServiceManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Servic } func (self *ServiceManager) Read(id string) (entity *Service, err error) { - err = self.db.View(func(tx *bbolt.Tx) error { + err = self.GetDb().View(func(tx *bbolt.Tx) error { entity, err = self.readInTx(tx, id) return err }) @@ -144,8 +110,8 @@ func (self *ServiceManager) Read(id string) (entity *Service, err error) { func (self *ServiceManager) GetIdForName(id string) (string, error) { var result []byte - err := self.db.View(func(tx *bbolt.Tx) error { - result = self.store.GetNameIndex().Read(tx, []byte(id)) + err := self.GetDb().View(func(tx *bbolt.Tx) error { + result = self.env.GetStores().Service.GetNameIndex().Read(tx, []byte(id)) return nil }) return string(result), err @@ -165,26 +131,6 @@ func (self *ServiceManager) readInTx(tx *bbolt.Tx, id string) (*Service, error) return entity, nil } -func (self *ServiceManager) populateService(entity *Service, tx *bbolt.Tx, boltEntity boltz.Entity) error { - boltService, ok := boltEntity.(*db.Service) - if !ok { - return errors.Errorf("unexpected type %v when filling model service", reflect.TypeOf(boltEntity)) - } - entity.Name = boltService.Name - entity.MaxIdleTime = boltService.MaxIdleTime - entity.TerminatorStrategy = boltService.TerminatorStrategy - entity.FillCommon(boltService) - - terminatorIds := self.store.GetRelatedEntitiesIdList(tx, entity.Id, db.EntityTypeTerminators) - for _, terminatorId := range terminatorIds { - if terminator, _ := self.Terminators.readInTx(tx, terminatorId); terminator != nil { - entity.Terminators = append(entity.Terminators, terminator) - } - } - - return nil -} - func (self *ServiceManager) cacheService(service *Service) { pfxlog.Logger().Tracef("updated service cache: %v", service.Id) self.cache.Set(service.Id, service) diff --git a/controller/model/service_model.go b/controller/model/service_model.go new file mode 100644 index 000000000..ee2b9d32b --- /dev/null +++ b/controller/model/service_model.go @@ -0,0 +1,66 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package model + +import ( + "github.com/openziti/storage/boltz" + "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/controller/models" + "go.etcd.io/bbolt" + "time" +) + +type Service struct { + models.BaseEntity + Name string + TerminatorStrategy string + Terminators []*Terminator + MaxIdleTime time.Duration +} + +func (entity *Service) GetName() string { + return entity.Name +} + +func (entity *Service) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.Service, error) { + return entity.toBoltEntityForCreate(tx, env) +} + +func (entity *Service) toBoltEntityForCreate(*bbolt.Tx, Env) (*db.Service, error) { + return &db.Service{ + BaseExtEntity: *boltz.NewExtEntity(entity.Id, entity.Tags), + Name: entity.Name, + MaxIdleTime: entity.MaxIdleTime, + TerminatorStrategy: entity.TerminatorStrategy, + }, nil +} + +func (entity *Service) fillFrom(env Env, tx *bbolt.Tx, boltService *db.Service) error { + entity.Name = boltService.Name + entity.MaxIdleTime = boltService.MaxIdleTime + entity.TerminatorStrategy = boltService.TerminatorStrategy + entity.FillCommon(boltService) + + terminatorIds := env.GetStores().Service.GetRelatedEntitiesIdList(tx, entity.Id, db.EntityTypeTerminators) + for _, terminatorId := range terminatorIds { + if terminator, _ := env.GetManagers().Terminator.readInTx(tx, terminatorId); terminator != nil { + entity.Terminators = append(entity.Terminators, terminator) + } + } + + return nil +} diff --git a/controller/model/service_policy_manager.go b/controller/model/service_policy_manager.go index 1b9f2a231..1d02492b8 100644 --- a/controller/model/service_policy_manager.go +++ b/controller/model/service_policy_manager.go @@ -24,7 +24,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" ) @@ -35,7 +34,7 @@ func NewServicePolicyManager(env Env) *ServicePolicyManager { } manager.impl = manager - network.RegisterManagerDecoder[*ServicePolicy](env.GetHostController().GetNetwork().Managers, manager) + RegisterManagerDecoder[*ServicePolicy](env, manager) return manager } @@ -49,7 +48,7 @@ func (self *ServicePolicyManager) newModelEntity() *ServicePolicy { } func (self *ServicePolicyManager) Create(entity *ServicePolicy, ctx *change.Context) error { - return network.DispatchCreate[*ServicePolicy](self, entity, ctx) + return DispatchCreate[*ServicePolicy](self, entity, ctx) } func (self *ServicePolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[*ServicePolicy], ctx boltz.MutateContext) error { @@ -58,7 +57,7 @@ func (self *ServicePolicyManager) ApplyCreate(cmd *command.CreateEntityCommand[* } func (self *ServicePolicyManager) Update(entity *ServicePolicy, checker fields.UpdatedFields, ctx *change.Context) error { - return network.DispatchUpdate[*ServicePolicy](self, entity, checker, ctx) + return DispatchUpdate[*ServicePolicy](self, entity, checker, ctx) } func (self *ServicePolicyManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*ServicePolicy], ctx boltz.MutateContext) error { diff --git a/controller/network/terminator.go b/controller/model/terminator_manager.go similarity index 75% rename from controller/network/terminator.go rename to controller/model/terminator_manager.go index af432aff5..6d2f5856b 100644 --- a/controller/network/terminator.go +++ b/controller/model/terminator_manager.go @@ -11,7 +11,7 @@ limitations under the License. */ -package network +package model import ( "context" @@ -30,7 +30,6 @@ import ( "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" "github.com/openziti/ziti/controller/xt" - "github.com/pkg/errors" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" "reflect" @@ -38,106 +37,26 @@ import ( "time" ) -type Terminator struct { - models.BaseEntity - Service string - Router string - Binding string - Address string - InstanceId string - InstanceSecret []byte - Cost uint16 - Precedence xt.Precedence - PeerData map[uint32][]byte - HostId string - SavedPrecedence xt.Precedence -} - -func (entity *Terminator) GetServiceId() string { - return entity.Service -} - -func (entity *Terminator) GetRouterId() string { - return entity.Router -} - -func (entity *Terminator) GetBinding() string { - return entity.Binding -} - -func (entity *Terminator) GetAddress() string { - return entity.Address -} - -func (entity *Terminator) GetInstanceId() string { - return entity.InstanceId -} - -func (entity *Terminator) GetInstanceSecret() []byte { - return entity.InstanceSecret -} - -func (entity *Terminator) GetCost() uint16 { - return entity.Cost -} - -func (entity *Terminator) GetPrecedence() xt.Precedence { - return entity.Precedence -} - -func (entity *Terminator) GetPeerData() xt.PeerData { - return entity.PeerData -} - -func (entity *Terminator) GetHostId() string { - return entity.HostId -} - -func (entity *Terminator) toBolt() *db.Terminator { - precedence := xt.Precedences.Default.String() - if entity.Precedence != nil { - precedence = entity.Precedence.String() - } - - var savedPrecedence *string - if entity.SavedPrecedence != nil { - precedenceStr := entity.SavedPrecedence.String() - savedPrecedence = &precedenceStr - } - - return &db.Terminator{ - BaseExtEntity: *entity.ToBoltBaseExtEntity(), - Service: entity.Service, - Router: entity.Router, - Binding: entity.Binding, - Address: entity.Address, - InstanceId: entity.InstanceId, - InstanceSecret: entity.InstanceSecret, - Cost: entity.Cost, - Precedence: precedence, - PeerData: entity.PeerData, - HostId: entity.HostId, - SavedPrecedence: savedPrecedence, - } -} - -func newTerminatorManager(managers *Managers) *TerminatorManager { +func newTerminatorManager(env Env) *TerminatorManager { result := &TerminatorManager{ - baseEntityManager: newBaseEntityManager[*Terminator, *db.Terminator](managers, managers.stores.Terminator, func() *Terminator { - return &Terminator{} - }), - store: managers.stores.Terminator, + baseEntityManager: newBaseEntityManager[*Terminator, *db.Terminator](env, env.GetStores().Terminator), } - result.populateEntity = result.populateTerminator + result.impl = result - managers.stores.Terminator.AddEntityIdListener(xt.GlobalCosts().ClearCost, boltz.EntityDeleted) + env.GetStores().Terminator.AddEntityIdListener(xt.GlobalCosts().ClearCost, boltz.EntityDeleted) + + RegisterManagerDecoder[*Terminator](env, result) + RegisterCommand(env, &DeleteTerminatorsBatchCommand{}, &cmd_pb.DeleteTerminatorsBatchCommand{}) return result } type TerminatorManager struct { baseEntityManager[*Terminator, *db.Terminator] - store db.TerminatorStore +} + +func (self *TerminatorManager) newModelEntity() *Terminator { + return &Terminator{} } func (self *TerminatorManager) Create(entity *Terminator, ctx *change.Context) error { @@ -145,16 +64,18 @@ func (self *TerminatorManager) Create(entity *Terminator, ctx *change.Context) e } func (self *TerminatorManager) ApplyCreate(cmd *command.CreateEntityCommand[*Terminator], ctx boltz.MutateContext) error { - return self.db.Update(ctx, func(ctx boltz.MutateContext) error { + return self.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { if cmd.Entity.IsSystemEntity() { ctx = ctx.GetSystemContext() } self.checkBinding(cmd.Entity) - boltTerminator := cmd.Entity.toBolt() - err := self.GetStore().Create(ctx, boltTerminator) + boltTerminator, err := cmd.Entity.toBoltEntityForCreate(ctx.Tx(), self.env) if err != nil { return err } + if err = self.GetStore().Create(ctx, boltTerminator); err != nil { + return err + } if cmd.PostCreateHook != nil { return cmd.PostCreateHook(ctx, cmd.Entity) } @@ -168,12 +89,12 @@ func (self *TerminatorManager) DeleteBatch(ids []string, ctx *change.Context) er Manager: self, Ids: ids, } - return self.Managers.Dispatch(cmd) + return self.Dispatch(cmd) } func (self *TerminatorManager) ApplyDeleteBatch(cmd *DeleteTerminatorsBatchCommand, ctx boltz.MutateContext) error { var errorList errorz.MultipleErrors - err := self.db.Update(ctx, func(ctx boltz.MutateContext) error { + err := self.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { for _, id := range cmd.Ids { if self.Store.IsEntityPresent(ctx.Tx(), id) { if err := self.Store.DeleteById(ctx, id); err != nil { @@ -199,7 +120,7 @@ func (self *TerminatorManager) checkBinding(terminator *Terminator) { } } -func (self *TerminatorManager) handlePrecedenceChange(terminatorId string, precedence xt.Precedence) { +func (self *TerminatorManager) HandlePrecedenceChange(terminatorId string, precedence xt.Precedence) { terminator, err := self.Read(terminatorId) if err != nil { pfxlog.Logger().Errorf("unable to update precedence for terminator %v to %v (%v)", @@ -223,34 +144,18 @@ func (self *TerminatorManager) Update(entity *Terminator, updatedFields fields.U func (self *TerminatorManager) ApplyUpdate(cmd *command.UpdateEntityCommand[*Terminator], ctx boltz.MutateContext) error { terminator := cmd.Entity - return self.db.Update(ctx, func(ctx boltz.MutateContext) error { + return self.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { if cmd.Entity.IsSystemEntity() { ctx = ctx.GetSystemContext() } self.checkBinding(terminator) - return self.GetStore().Update(ctx, terminator.toBolt(), cmd.UpdatedFields) - }) -} - -func (self *TerminatorManager) Read(id string) (entity *Terminator, err error) { - err = self.db.View(func(tx *bbolt.Tx) error { - entity, err = self.readInTx(tx, id) - return err + boltTerminator, err := terminator.toBoltEntityForUpdate(ctx.Tx(), self.env, cmd.UpdatedFields) + if err != nil { + return err + } + return self.GetStore().Update(ctx, boltTerminator, cmd.UpdatedFields) }) - if err != nil { - return nil, err - } - return entity, err -} - -func (self *TerminatorManager) readInTx(tx *bbolt.Tx, id string) (*Terminator, error) { - entity := &Terminator{} - err := self.readEntityInTx(tx, id, entity) - if err != nil { - return nil, err - } - return entity, nil } func (self *TerminatorManager) Query(query string) (*TerminatorListResult, error) { @@ -261,30 +166,6 @@ func (self *TerminatorManager) Query(query string) (*TerminatorListResult, error return result, nil } -func (self *TerminatorManager) populateTerminator(entity *Terminator, _ *bbolt.Tx, boltEntity boltz.Entity) error { - boltTerminator, ok := boltEntity.(*db.Terminator) - if !ok { - return errors.Errorf("unexpected type %v when filling model terminator", reflect.TypeOf(boltEntity)) - } - entity.Service = boltTerminator.Service - entity.Router = boltTerminator.Router - entity.Binding = boltTerminator.Binding - entity.Address = boltTerminator.Address - entity.InstanceId = boltTerminator.InstanceId - entity.InstanceSecret = boltTerminator.InstanceSecret - entity.PeerData = boltTerminator.PeerData - entity.Cost = boltTerminator.Cost - entity.Precedence = xt.GetPrecedenceForName(boltTerminator.Precedence) - entity.HostId = boltTerminator.HostId - entity.FillCommon(boltTerminator) - - if boltTerminator.SavedPrecedence != nil { - entity.SavedPrecedence = xt.GetPrecedenceForName(*boltTerminator.SavedPrecedence) - } - - return nil -} - func (self *TerminatorManager) Marshall(entity *Terminator) ([]byte, error) { tags, err := cmd_pb.EncodeTags(entity.Tags) if err != nil { @@ -408,7 +289,7 @@ func (self *TerminatorManager) ValidateTerminators(filter string, fixInvalid boo } func (self *TerminatorManager) validateTerminatorBatch(fixInvalid bool, routerId string, batch []*Terminator, cb TerminatorValidationCallback) { - router := self.Managers.Routers.getConnected(routerId) + router := self.env.GetManagers().Router.GetConnected(routerId) if router == nil { self.reportError(router, batch, cb, "router off-line") return @@ -470,13 +351,13 @@ func (self *TerminatorManager) newTerminatorDetail(router *Router, terminator *T CreateDate: terminator.CreatedAt.Format(time.RFC3339), } - service, _ := self.Services.Read(terminator.Service) + service, _ := self.env.GetManagers().Service.Read(terminator.Service) if service != nil { detail.ServiceName = service.Name } if router == nil { - router, _ = self.Routers.Read(terminator.Router) + router, _ = self.env.GetManagers().Router.Read(terminator.Router) } if router != nil { @@ -529,8 +410,8 @@ func (self *DeleteTerminatorsBatchCommand) Encode() ([]byte, error) { }) } -func (self *DeleteTerminatorsBatchCommand) Decode(n *Network, msg *cmd_pb.DeleteTerminatorsBatchCommand) error { - self.Manager = n.Terminators +func (self *DeleteTerminatorsBatchCommand) Decode(env Env, msg *cmd_pb.DeleteTerminatorsBatchCommand) error { + self.Manager = env.GetManagers().Terminator self.Ids = msg.EntityIds return nil } diff --git a/controller/model/terminator_model.go b/controller/model/terminator_model.go new file mode 100644 index 000000000..7cb710384 --- /dev/null +++ b/controller/model/terminator_model.go @@ -0,0 +1,116 @@ +package model + +import ( + "github.com/openziti/storage/boltz" + "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/controller/models" + "github.com/openziti/ziti/controller/xt" + "go.etcd.io/bbolt" +) + +type Terminator struct { + models.BaseEntity + Service string + Router string + Binding string + Address string + InstanceId string + InstanceSecret []byte + Cost uint16 + Precedence xt.Precedence + PeerData map[uint32][]byte + HostId string + SavedPrecedence xt.Precedence +} + +func (entity *Terminator) GetServiceId() string { + return entity.Service +} + +func (entity *Terminator) GetRouterId() string { + return entity.Router +} + +func (entity *Terminator) GetBinding() string { + return entity.Binding +} + +func (entity *Terminator) GetAddress() string { + return entity.Address +} + +func (entity *Terminator) GetInstanceId() string { + return entity.InstanceId +} + +func (entity *Terminator) GetInstanceSecret() []byte { + return entity.InstanceSecret +} + +func (entity *Terminator) GetCost() uint16 { + return entity.Cost +} + +func (entity *Terminator) GetPrecedence() xt.Precedence { + return entity.Precedence +} + +func (entity *Terminator) GetPeerData() xt.PeerData { + return entity.PeerData +} + +func (entity *Terminator) GetHostId() string { + return entity.HostId +} + +func (entity *Terminator) toBoltEntityForUpdate(tx *bbolt.Tx, env Env, _ boltz.FieldChecker) (*db.Terminator, error) { + return entity.toBoltEntityForCreate(tx, env) +} + +func (entity *Terminator) toBoltEntityForCreate(*bbolt.Tx, Env) (*db.Terminator, error) { + precedence := xt.Precedences.Default.String() + if entity.Precedence != nil { + precedence = entity.Precedence.String() + } + + var savedPrecedence *string + if entity.SavedPrecedence != nil { + precedenceStr := entity.SavedPrecedence.String() + savedPrecedence = &precedenceStr + } + + return &db.Terminator{ + BaseExtEntity: *entity.ToBoltBaseExtEntity(), + Service: entity.Service, + Router: entity.Router, + Binding: entity.Binding, + Address: entity.Address, + InstanceId: entity.InstanceId, + InstanceSecret: entity.InstanceSecret, + Cost: entity.Cost, + Precedence: precedence, + PeerData: entity.PeerData, + HostId: entity.HostId, + SavedPrecedence: savedPrecedence, + }, nil +} + +func (entity *Terminator) fillFrom(_ Env, _ *bbolt.Tx, boltTerminator *db.Terminator) error { + entity.Service = boltTerminator.Service + entity.Router = boltTerminator.Router + entity.Binding = boltTerminator.Binding + entity.Address = boltTerminator.Address + entity.InstanceId = boltTerminator.InstanceId + entity.InstanceSecret = boltTerminator.InstanceSecret + entity.PeerData = boltTerminator.PeerData + entity.Cost = boltTerminator.Cost + entity.Precedence = xt.GetPrecedenceForName(boltTerminator.Precedence) + entity.HostId = boltTerminator.HostId + entity.FillCommon(boltTerminator) + + if boltTerminator.SavedPrecedence != nil { + entity.SavedPrecedence = xt.GetPrecedenceForName(*boltTerminator.SavedPrecedence) + } + + return nil +} diff --git a/controller/model/testing.go b/controller/model/testing.go index 687797245..5e293b6ad 100644 --- a/controller/model/testing.go +++ b/controller/model/testing.go @@ -19,7 +19,9 @@ package model import ( "crypto/tls" "crypto/x509" - "github.com/openziti/ziti/controller" + "github.com/openziti/channel/v2" + "github.com/openziti/transport/v2" + "github.com/openziti/ziti/controller/models" "testing" "time" @@ -35,77 +37,24 @@ import ( "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/config" - edgeconfig "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/event" "github.com/openziti/ziti/controller/jwtsigner" - "github.com/openziti/ziti/controller/network" ) var _ Env = &TestContext{} -var _ HostController = &testHostController{} - -type testHostController struct { - closeNotify chan struct{} - ctx *TestContext -} - -func (self *testHostController) GetConfig() *controller.Config { - return nil -} - -func (self *testHostController) GetApiAddresses() (map[string][]event.ApiAddress, []byte) { - return nil, nil -} - -func (self *testHostController) GetRaftInfo() (string, string, string) { - return "testaddr", "testid", "testversion" -} - -func (self *testHostController) GetRaftIndex() uint64 { - return 0 -} - -func (self *testHostController) GetPeerSigners() []*x509.Certificate { - return nil -} - -func (self *testHostController) Identity() identity.Identity { - return &identity.TokenId{Token: "test"} -} - -func (self *testHostController) GetNetwork() *network.Network { - return self.ctx.n -} - -func (self *testHostController) Shutdown() { - close(self.closeNotify) -} - -func (self *testHostController) GetCloseNotifyChannel() <-chan struct{} { - return self.closeNotify -} - -func (self *testHostController) Stop() { - close(self.closeNotify) -} - -func (ctx *testHostController) IsRaftEnabled() bool { - return false -} - type TestContext struct { *db.TestContext - n *network.Network managers *Managers - config *edgeconfig.Config + config *config.Config metricsRegistry metrics.Registry - hostController *testHostController + closeNotify chan struct{} + dispatcher command.Dispatcher } -func (ctx *TestContext) GetDbProvider() network.DbProvider { - return ctx.n +func (self *TestContext) GetCloseNotifyChannel() <-chan struct{} { + return self.closeNotify } func (ctx *TestContext) ValidateAccessToken(token string) (*common.AccessClaims, error) { @@ -159,7 +108,7 @@ func (ctx *TestContext) GetManagers() *Managers { return ctx.managers } -func (ctx *TestContext) GetConfig() *edgeconfig.Config { +func (ctx *TestContext) GetConfig() *config.Config { return ctx.config } @@ -187,14 +136,6 @@ func (ctx *TestContext) GetControlClientCsrSigner() cert.Signer { return nil } -func (ctx *TestContext) GetHostController() HostController { - return ctx.hostController -} - -func (ctx *TestContext) GetSchemas() Schemas { - panic("implement me") -} - func (ctx *TestContext) IsEdgeRouterOnline(string) bool { panic("implement me") } @@ -207,41 +148,66 @@ func (ctx *TestContext) GetFingerprintGenerator() cert.FingerprintGenerator { return nil } -func NewTestContext(t *testing.T) *TestContext { +func (self *TestContext) GetApiAddresses() (map[string][]event.ApiAddress, []byte) { + return nil, nil +} + +func (self *TestContext) GetRaftInfo() (string, string, string) { + return "testaddr", "testid", "testversion" +} + +func (self *TestContext) GetPeerSigners() []*x509.Certificate { + return nil +} + +func (self *TestContext) Identity() identity.Identity { + return &identity.TokenId{Token: "test"} +} + +func (self *TestContext) Shutdown() { + close(self.closeNotify) +} + +func (self *TestContext) Stop() { + close(self.closeNotify) +} + +func (self *TestContext) GetCommandDispatcher() command.Dispatcher { + return self.dispatcher +} + +func NewTestContext(t testing.TB) *TestContext { fabricTestContext := db.NewTestContext(t) - context := &TestContext{ + ctx := &TestContext{ TestContext: fabricTestContext, metricsRegistry: metrics.NewRegistry("test", nil), + closeNotify: make(chan struct{}), + dispatcher: &command.LocalDispatcher{ + EncodeDecodeCommands: true, + Limiter: command.NoOpRateLimiter{}, + }, } - context.hostController = &testHostController{ - ctx: context, - closeNotify: make(chan struct{}), - } - - return context -} -func (ctx *TestContext) Init() { ctx.TestContext.Init() - cfg := newTestConfig(ctx.TestContext) - n, err := network.NewNetwork(cfg) - ctx.NoError(err) - ctx.n = n ctx.config = &config.Config{ - Enrollment: config.Enrollment{ - EdgeRouter: config.EnrollmentOption{ - Duration: 60 * time.Second, + Network: config.DefaultNetworkConfig(), + Edge: &config.EdgeConfig{ + Enrollment: config.Enrollment{ + EdgeRouter: config.EnrollmentOption{ + Duration: 60 * time.Second, + }, }, }, } - ctx.managers = InitEntityManagers(ctx) + ctx.managers = NewManagers() + ctx.managers.Init(ctx) + + return ctx } func (ctx *TestContext) Cleanup() { - if ctx.hostController != nil { - ctx.hostController.Stop() - } + ctx.Stop() ctx.TestContext.Cleanup() } @@ -255,8 +221,8 @@ func (ctx *TestContext) requireNewIdentity(isAdmin bool) *Identity { return newIdentity } -func (ctx *TestContext) requireNewService() *Service { - service := &Service{ +func (ctx *TestContext) requireNewService() *EdgeService { + service := &EdgeService{ Name: eid.New(), } ctx.NoError(ctx.managers.EdgeService.Create(service, change.New())) @@ -334,23 +300,23 @@ func ss(vals ...string) []string { return vals } -func newTestConfig(ctx *db.TestContext) *testConfig { - options := network.DefaultOptions() - options.MinRouterCost = 0 - - return &testConfig{ - closeNotify: make(chan struct{}), - ctx: ctx, - options: options, - metricsRegistry: metrics.NewRegistry("test", nil), - versionProvider: versions.NewDefaultVersionProvider(), - } -} +//func newTestConfig(ctx *db.TestContext) *testConfig { +// options := network.DefaultOptions() +// options.MinRouterCost = 0 +// +// return &testConfig{ +// closeNotify: make(chan struct{}), +// ctx: ctx, +// options: options, +// metricsRegistry: metrics.NewRegistry("test", nil), +// versionProvider: versions.NewDefaultVersionProvider(), +// } +//} type testConfig struct { - closeNotify chan struct{} - ctx *db.TestContext - options *network.Options + closeNotify chan struct{} + ctx *db.TestContext + // options *network.Options metricsRegistry metrics.Registry versionProvider versions.VersionProvider } @@ -367,9 +333,9 @@ func (self *testConfig) GetMetricsRegistry() metrics.Registry { return self.metricsRegistry } -func (self *testConfig) GetOptions() *network.Options { - return self.options -} +//func (self *testConfig) GetOptions() *network.Options { +// return self.options +//} func (self *testConfig) GetCommandDispatcher() command.Dispatcher { return &command.LocalDispatcher{ @@ -388,3 +354,28 @@ func (self *testConfig) GetVersionProvider() versions.VersionProvider { func (self *testConfig) GetCloseNotify() <-chan struct{} { return self.closeNotify } + +func NewTestLink(id string, src, dst *Router) *Link { + l := newLink(id, "tls", "tcp:localhost:1234", 0) + l.Src = src + l.DstId = dst.Id + l.Dst.Store(dst) + src.Connected.Store(true) + dst.Connected.Store(true) + return l +} + +func NewRouterForTest(id string, fingerprint string, advLstnr transport.Address, ctrl channel.Channel, cost uint16, noTraversal bool) *Router { + r := &Router{ + BaseEntity: models.BaseEntity{Id: id}, + Name: id, + Fingerprint: &fingerprint, + Control: ctrl, + Cost: cost, + NoTraversal: noTraversal, + } + if advLstnr != nil { + r.AddLinkListener(advLstnr.String(), advLstnr.Type(), []string{"Cost Tag"}, []string{"default"}) + } + return r +} diff --git a/controller/model/transit_router_manager.go b/controller/model/transit_router_manager.go index 0934d75f0..f6af4a547 100644 --- a/controller/model/transit_router_manager.go +++ b/controller/model/transit_router_manager.go @@ -29,7 +29,6 @@ import ( "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/fields" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "github.com/pkg/errors" "go.etcd.io/bbolt" "google.golang.org/protobuf/proto" @@ -46,8 +45,8 @@ func NewTransitRouterManager(env Env) *TransitRouterManager { manager.impl = manager RegisterCommand(env, &CreateTransitRouterCmd{}, &edge_cmd_pb.CreateTransitRouterCmd{}) - network.RegisterUpdateDecoder[*TransitRouter](env.GetHostController().GetNetwork().Managers, manager) - network.RegisterDeleteDecoder(env.GetHostController().GetNetwork().Managers, manager) + RegisterUpdateDecoder[*TransitRouter](env, manager) + RegisterDeleteDecoder(env, manager) return manager } diff --git a/controller/network/assembly.go b/controller/network/assembly.go index 6093bfe9e..ac3d93176 100644 --- a/controller/network/assembly.go +++ b/controller/network/assembly.go @@ -22,6 +22,7 @@ import ( "github.com/openziti/foundation/v2/info" "github.com/openziti/ziti/common/pb/ctrl_pb" "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "time" ) @@ -32,13 +33,13 @@ func (network *Network) assemble() { log := pfxlog.Logger() - if network.Routers.connectedCount() > 1 { - log.Tracef("assembling with [%d] routers", network.Routers.connectedCount()) + if network.Router.ConnectedCount() > 1 { + log.Tracef("assembling with [%d] routers", network.Router.ConnectedCount()) - missingLinks, err := network.linkController.missingLinks(network.Routers.allConnected(), network.options.PendingLinkTimeout) + missingLinks, err := network.Link.MissingLinks(network.Router.AllConnected(), network.options.PendingLinkTimeout) if err == nil { for _, missingLink := range missingLinks { - network.linkController.add(missingLink) + network.Link.Add(missingLink) dial := &ctrl_pb.Dial{ LinkId: missingLink.Id, @@ -68,11 +69,11 @@ func (network *Network) assemble() { log.WithField("err", err).Error("missing link enumeration failed") } - network.linkController.clearExpiredPending(network.options.PendingLinkTimeout) + network.Link.ClearExpiredPending(network.options.PendingLinkTimeout) } } -func (network *Network) NotifyLinkEvent(link *Link, eventType event.LinkEventType) { +func (network *Network) NotifyLinkEvent(link *model.Link, eventType event.LinkEventType) { linkEvent := &event.LinkEvent{ Namespace: event.LinkEventsNs, EventType: eventType, @@ -87,7 +88,7 @@ func (network *Network) NotifyLinkEvent(link *Link, eventType event.LinkEventTyp network.eventDispatcher.AcceptLinkEvent(linkEvent) } -func (network *Network) NotifyLinkConnected(link *Link, msg *ctrl_pb.LinkConnected) { +func (network *Network) NotifyLinkConnected(link *model.Link, msg *ctrl_pb.LinkConnected) { linkEvent := &event.LinkEvent{ Namespace: event.LinkEventsNs, EventType: event.LinkConnected, @@ -124,12 +125,12 @@ func (network *Network) NotifyLinkIdEvent(linkId string, eventType event.LinkEve func (network *Network) clean() { log := pfxlog.Logger() - failedLinks := network.linkController.linksInMode(Failed) - duplicateLinks := network.linkController.linksInMode(Duplicate) + failedLinks := network.Link.LinksInMode(model.Failed) + duplicateLinks := network.Link.LinksInMode(model.Duplicate) failedLinks = append(failedLinks, duplicateLinks...) now := info.NowInMilliseconds() - var lRemove []*Link + var lRemove []*model.Link for _, l := range failedLinks { if now-l.CurrentState().Timestamp >= 30000 { lRemove = append(lRemove, l) @@ -137,6 +138,6 @@ func (network *Network) clean() { } for _, lr := range lRemove { log.WithField("linkId", lr.Id).Info("removing failed link") - network.linkController.remove(lr) + network.Link.Remove(lr) } } diff --git a/controller/network/circuit_lifecycle.go b/controller/network/circuit_lifecycle.go index b315b711d..ee6bc12f9 100644 --- a/controller/network/circuit_lifecycle.go +++ b/controller/network/circuit_lifecycle.go @@ -18,12 +18,13 @@ package network import ( "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/xt" "github.com/pkg/errors" "time" ) -func (network *Network) fillCircuitPath(e *event.CircuitEvent, path *Path) { +func (network *Network) fillCircuitPath(e *event.CircuitEvent, path *model.Path) { if path == nil { return } @@ -42,7 +43,7 @@ func (network *Network) fillCircuitPath(e *event.CircuitEvent, path *Path) { e.LinkCount = len(path.Links) } -func (network *Network) CircuitEvent(eventType event.CircuitEventType, circuit *Circuit, creationTimespan *time.Duration) { +func (network *Network) CircuitEvent(eventType event.CircuitEventType, circuit *model.Circuit, creationTimespan *time.Duration) { var cost *uint32 var duration *time.Duration if eventType == event.CircuitCreated { @@ -126,9 +127,9 @@ func newCircuitErrWrap(cause CircuitFailureCause, err error) CircuitError { func (network *Network) CircuitFailedEvent( circuitId string, - params CreateCircuitParams, + params model.CreateCircuitParams, startTime time.Time, - path *Path, + path *model.Path, t xt.CostedTerminator, cause CircuitFailureCause) { var failureCause *string diff --git a/controller/network/db_provider.go b/controller/network/db_provider.go deleted file mode 100644 index 267a52b71..000000000 --- a/controller/network/db_provider.go +++ /dev/null @@ -1,12 +0,0 @@ -package network - -import ( - "github.com/openziti/storage/boltz" - "github.com/openziti/ziti/controller/db" -) - -type DbProvider interface { - GetDb() boltz.Db - GetStores() *db.Stores - GetManagers() *Managers -} diff --git a/controller/network/fault.go b/controller/network/fault.go index 2d281f82b..4bbcdbebb 100644 --- a/controller/network/fault.go +++ b/controller/network/fault.go @@ -18,11 +18,12 @@ package network import ( "github.com/michaelquigley/pfxlog" + "github.com/openziti/ziti/controller/model" "github.com/sirupsen/logrus" ) type ForwardingFaultReport struct { - R *Router + R *model.Router CircuitIds []string UnknownOwner bool } @@ -31,7 +32,7 @@ func (network *Network) fault(ffr *ForwardingFaultReport) { logrus.Infof("network fault processing for [%d] circuits", len(ffr.CircuitIds)) for _, circuitId := range ffr.CircuitIds { log := pfxlog.Logger().WithField("circuitId", circuitId).WithField("routerId", ffr.R.Id) - s, found := network.circuitController.get(circuitId) + s, found := network.Circuit.Get(circuitId) if found { if success := network.rerouteCircuitWithTries(s, 2); success { log.Info("rerouted circuit in response to forwarding fault from router") diff --git a/controller/network/handler.go b/controller/network/handler.go index 56a5b62b1..15c87a69e 100644 --- a/controller/network/handler.go +++ b/controller/network/handler.go @@ -16,7 +16,9 @@ package network +import "github.com/openziti/ziti/controller/model" + type RouterPresenceHandler interface { - RouterConnected(r *Router) - RouterDisconnected(r *Router) + RouterConnected(r *model.Router) + RouterDisconnected(r *model.Router) } diff --git a/controller/network/inspect.go b/controller/network/inspect.go index 2c9dfa1a9..d1f00be67 100644 --- a/controller/network/inspect.go +++ b/controller/network/inspect.go @@ -23,6 +23,7 @@ import ( "github.com/openziti/channel/v2/protobufs" "github.com/openziti/foundation/v2/concurrenz" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "regexp" "sync" "time" @@ -137,7 +138,7 @@ func (ctx *inspectRequestContext) inspectLocal() { } } -func (ctx *inspectRequestContext) inspectRouter(router *Router) { +func (ctx *inspectRequestContext) inspectRouter(router *model.Router) { log := pfxlog.Logger(). WithField("appRegex", ctx.appRegex). WithField("routerId", router.Id). diff --git a/controller/network/managers.go b/controller/network/managers.go deleted file mode 100644 index cc14dc1b5..000000000 --- a/controller/network/managers.go +++ /dev/null @@ -1,328 +0,0 @@ -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package network - -import ( - "github.com/michaelquigley/pfxlog" - "github.com/openziti/foundation/v2/goroutines" - "github.com/openziti/foundation/v2/versions" - "github.com/openziti/storage/ast" - "github.com/openziti/storage/boltz" - "github.com/openziti/ziti/common/pb/cmd_pb" - "github.com/openziti/ziti/controller/change" - "github.com/openziti/ziti/controller/command" - "github.com/openziti/ziti/controller/db" - "github.com/openziti/ziti/controller/fields" - "github.com/openziti/ziti/controller/idgen" - "github.com/openziti/ziti/controller/ioc" - "github.com/openziti/ziti/controller/models" - "go.etcd.io/bbolt" -) - -const ( - CreateDecoder = "CreateDecoder" - UpdateDecoder = "UpdateDecoder" - DeleteDecoder = "DeleteDecoder" -) - -type Managers struct { - network *Network - db boltz.Db - stores *db.Stores - Terminators *TerminatorManager - Routers *RouterManager - Services *ServiceManager - Inspections *InspectionsManager - Command *CommandManager - Dispatcher command.Dispatcher - Registry ioc.Registry - RouterMessaging *RouterMessaging -} - -func (self *Managers) getDb() boltz.Db { - return self.db -} - -func (self *Managers) Dispatch(command command.Command) error { - return self.Dispatcher.Dispatch(command) -} - -type creator[T models.Entity] interface { - command.EntityCreator[T] - Dispatch(cmd command.Command) error -} - -type updater[T models.Entity] interface { - command.EntityUpdater[T] - Dispatch(cmd command.Command) error -} - -func DispatchCreate[T models.Entity](c creator[T], entity T, ctx *change.Context) error { - if entity.GetId() == "" { - id := idgen.NewUUIDString() - entity.SetId(id) - } - - cmd := &command.CreateEntityCommand[T]{ - Context: ctx, - Creator: c, - Entity: entity, - } - - return c.Dispatch(cmd) -} - -func DispatchUpdate[T models.Entity](u updater[T], entity T, updatedFields fields.UpdatedFields, ctx *change.Context) error { - cmd := &command.UpdateEntityCommand[T]{ - Context: ctx, - Updater: u, - Entity: entity, - UpdatedFields: updatedFields, - } - - return u.Dispatch(cmd) -} - -type createDecoderF func(cmd *cmd_pb.CreateEntityCommand) (command.Command, error) - -func RegisterCreateDecoder[T models.Entity](managers *Managers, creator command.EntityCreator[T]) { - entityType := creator.GetEntityTypeId() - managers.Registry.RegisterSingleton(entityType+CreateDecoder, createDecoderF(func(cmd *cmd_pb.CreateEntityCommand) (command.Command, error) { - entity, err := creator.Unmarshall(cmd.EntityData) - if err != nil { - return nil, err - } - return &command.CreateEntityCommand[T]{ - Context: change.FromProtoBuf(cmd.Ctx), - Entity: entity, - Creator: creator, - Flags: cmd.Flags, - }, nil - })) -} - -type updateDecoderF func(cmd *cmd_pb.UpdateEntityCommand) (command.Command, error) - -func RegisterUpdateDecoder[T models.Entity](managers *Managers, updater command.EntityUpdater[T]) { - entityType := updater.GetEntityTypeId() - managers.Registry.RegisterSingleton(entityType+UpdateDecoder, updateDecoderF(func(cmd *cmd_pb.UpdateEntityCommand) (command.Command, error) { - entity, err := updater.Unmarshall(cmd.EntityData) - if err != nil { - return nil, err - } - return &command.UpdateEntityCommand[T]{ - Context: change.FromProtoBuf(cmd.Ctx), - Entity: entity, - Updater: updater, - UpdatedFields: fields.SliceToUpdatedFields(cmd.UpdatedFields), - Flags: cmd.Flags, - }, nil - })) -} - -type deleteDecoderF func(cmd *cmd_pb.DeleteEntityCommand) (command.Command, error) - -func RegisterDeleteDecoder(managers *Managers, deleter command.EntityDeleter) { - entityType := deleter.GetEntityTypeId() - managers.Registry.RegisterSingleton(entityType+DeleteDecoder, deleteDecoderF(func(cmd *cmd_pb.DeleteEntityCommand) (command.Command, error) { - return &command.DeleteEntityCommand{ - Context: change.FromProtoBuf(cmd.Ctx), - Deleter: deleter, - Id: cmd.EntityId, - }, nil - })) -} - -func RegisterManagerDecoder[T models.Entity](managers *Managers, ctrl command.EntityManager[T]) { - RegisterCreateDecoder[T](managers, ctrl) - RegisterUpdateDecoder[T](managers, ctrl) - RegisterDeleteDecoder(managers, ctrl) -} - -func NewManagers(network *Network, dispatcher command.Dispatcher, db boltz.Db, stores *db.Stores, routerCommPool goroutines.Pool) *Managers { - result := &Managers{ - network: network, - db: db, - stores: stores, - Dispatcher: dispatcher, - Registry: ioc.NewRegistry(), - } - result.Command = newCommandManager(result) - result.Terminators = newTerminatorManager(result) - result.Routers = newRouterManager(result) - result.Services = newServiceManager(result) - result.Inspections = NewInspectionsManager(network) - if result.Dispatcher == nil { - devVersion := versions.MustParseSemVer("0.0.0") - version := versions.MustParseSemVer(network.VersionProvider.Version()) - result.Dispatcher = &command.LocalDispatcher{ - EncodeDecodeCommands: devVersion.Equals(version), - } - } - result.Command.registerGenericCommands() - - result.RouterMessaging = NewRouterMessaging(result, routerCommPool) - - RegisterManagerDecoder[*Service](result, result.Services) - RegisterManagerDecoder[*Router](result, result.Routers) - RegisterManagerDecoder[*Terminator](result, result.Terminators) - RegisterCommand(result, &DeleteTerminatorsBatchCommand{}, &cmd_pb.DeleteTerminatorsBatchCommand{}) - - return result -} - -type Controller[T models.Entity] interface { - models.EntityRetriever[T] - getManagers() *Managers -} - -func newBaseEntityManager[ME models.Entity, PE boltz.ExtEntity](managers *Managers, store boltz.EntityStore[PE], newModelEntity func() ME) baseEntityManager[ME, PE] { - return baseEntityManager[ME, PE]{ - BaseEntityManager: models.BaseEntityManager[PE]{ - Store: store, - }, - Managers: managers, - newModelEntity: newModelEntity, - } -} - -type baseEntityManager[T models.Entity, PE boltz.ExtEntity] struct { - models.BaseEntityManager[PE] - *Managers - newModelEntity func() T - populateEntity func(entity T, tx *bbolt.Tx, boltEntity boltz.Entity) error -} - -func (self *baseEntityManager[ME, PE]) GetEntityTypeId() string { - // default this to the store entity type and let individual managers override it where - // needed to avoid collisions (e.g. edge service/router) - return self.GetStore().GetEntityType() -} - -func (self *baseEntityManager[ME, PE]) Delete(id string, ctx *change.Context) error { - cmd := &command.DeleteEntityCommand{ - Context: ctx, - Deleter: self, - Id: id, - } - return self.Managers.Dispatch(cmd) -} - -func (self *baseEntityManager[ME, PE]) ApplyDelete(cmd *command.DeleteEntityCommand, ctx boltz.MutateContext) error { - return self.db.Update(ctx, func(mutateCtx boltz.MutateContext) error { - return self.Store.DeleteById(ctx, cmd.Id) - }) -} - -func (ctrl *baseEntityManager[ME, PE]) BaseLoad(id string) (ME, error) { - entity := ctrl.newModelEntity() - if err := ctrl.readEntity(id, entity); err != nil { - return *new(ME), err - } - return entity, nil -} - -func (ctrl *baseEntityManager[ME, PE]) BaseLoadInTx(tx *bbolt.Tx, id string) (ME, error) { - entity := ctrl.newModelEntity() - if err := ctrl.readEntityInTx(tx, id, entity); err != nil { - return *new(ME), err - } - return entity, nil -} - -func (ctrl *baseEntityManager[ME, PE]) readEntity(id string, modelEntity ME) error { - return ctrl.db.View(func(tx *bbolt.Tx) error { - return ctrl.readEntityInTx(tx, id, modelEntity) - }) -} - -func (ctrl *baseEntityManager[ME, PE]) readEntityInTx(tx *bbolt.Tx, id string, modelEntity ME) error { - boltEntity, found, err := ctrl.GetStore().FindById(tx, id) - if err != nil { - return err - } - if !found { - return boltz.NewNotFoundError(ctrl.GetStore().GetSingularEntityType(), "id", id) - } - - return ctrl.populateEntity(modelEntity, tx, boltEntity) -} - -func (ctrl *baseEntityManager[ME, PE]) BaseList(query string) (*models.EntityListResult[ME], error) { - result := &models.EntityListResult[ME]{Loader: ctrl} - err := ctrl.ListWithHandler(query, result.Collect) - if err != nil { - return nil, err - } - return result, nil -} - -func (ctrl *baseEntityManager[ME, PE]) ListWithHandler(queryString string, resultHandler models.ListResultHandler) error { - return ctrl.db.View(func(tx *bbolt.Tx) error { - return ctrl.ListWithTx(tx, queryString, resultHandler) - }) -} - -func (ctrl *baseEntityManager[ME, PE]) BasePreparedList(query ast.Query) (*models.EntityListResult[ME], error) { - result := &models.EntityListResult[ME]{Loader: ctrl} - err := ctrl.PreparedListWithHandler(query, result.Collect) - if err != nil { - return nil, err - } - return result, nil -} - -func (ctrl *baseEntityManager[ME, PE]) PreparedListWithHandler(query ast.Query, resultHandler models.ListResultHandler) error { - return ctrl.db.View(func(tx *bbolt.Tx) error { - return ctrl.PreparedListWithTx(tx, query, resultHandler) - }) -} - -func (ctrl *baseEntityManager[ME, PE]) PreparedListAssociatedWithHandler(id string, association string, query ast.Query, handler models.ListResultHandler) error { - return ctrl.db.View(func(tx *bbolt.Tx) error { - return ctrl.PreparedListAssociatedWithTx(tx, id, association, query, handler) - }) -} - -type boltEntitySource[PE boltz.ExtEntity] interface { - models.Entity - toBolt() PE -} - -func (ctrl *baseEntityManager[ME, PE]) updateGeneral(ctx boltz.MutateContext, modelEntity boltEntitySource[PE], checker boltz.FieldChecker) error { - return ctrl.db.Update(ctx, func(ctx boltz.MutateContext) error { - existing, found, err := ctrl.GetStore().FindById(ctx.Tx(), modelEntity.GetId()) - if err != nil { - return err - } - if !found { - return boltz.NewNotFoundError(ctrl.GetStore().GetSingularEntityType(), "id", modelEntity.GetId()) - } - - boltEntity := modelEntity.toBolt() - - if err := ctrl.ValidateNameOnUpdate(ctx, boltEntity, existing, checker); err != nil { - return err - } - - if err := ctrl.GetStore().Update(ctx, boltEntity, checker); err != nil { - pfxlog.Logger().WithError(err).Errorf("could not update %v entity", ctrl.GetStore().GetEntityType()) - return err - } - return nil - }) -} diff --git a/controller/network/network.go b/controller/network/network.go index 60e2469d1..8b790e13a 100644 --- a/controller/network/network.go +++ b/controller/network/network.go @@ -24,10 +24,17 @@ import ( "github.com/openziti/foundation/v2/concurrenz" "github.com/openziti/foundation/v2/goroutines" "github.com/openziti/storage/objectz" + "github.com/openziti/ziti/common/inspect" fabricMetrics "github.com/openziti/ziti/common/metrics" + "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/openziti/ziti/common/pb/mgmt_pb" + "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/idgen" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/raft" + "google.golang.org/protobuf/proto" + "math" "os" "path/filepath" "runtime/debug" @@ -65,7 +72,7 @@ const SmartRerouteAttempt = 99969996 type Config interface { GetId() *identity.TokenId GetMetricsRegistry() metrics.Registry - GetOptions() *Options + GetOptions() *config.NetworkConfig GetCommandDispatcher() command.Dispatcher GetDb() boltz.Db GetVersionProvider() versions.VersionProvider @@ -76,13 +83,13 @@ type Config interface { type InspectTarget func(string) (bool, *string, error) type Network struct { - *Managers + *model.Managers + env model.Env nodeId string - options *Options + options *config.NetworkConfig assembleAndCleanC chan struct{} - linkController *linkController forwardingFaults chan struct{} - circuitController *circuitController + circuitIdGenerator idgen.Generator routeSenderController *routeSenderController sequence *sequence.Sequence eventDispatcher event.Dispatcher @@ -110,15 +117,12 @@ type Network struct { config Config + Inspections *InspectionsManager + RouterMessaging *RouterMessaging inspectionTargets concurrenz.CopyOnWriteSlice[InspectTarget] } -func NewNetwork(config Config) (*Network, error) { - stores, err := db.InitStores(config.GetDb(), config.GetCommandDispatcher().GetRateLimiter()) - if err != nil { - return nil, err - } - +func NewNetwork(config Config, env model.Env) (*Network, error) { if config.GetOptions().IntervalAgeThreshold != 0 { metrics.SetIntervalAgeThreshold(config.GetOptions().IntervalAgeThreshold) logrus.Infof("set interval age threshold to '%v'", config.GetOptions().IntervalAgeThreshold) @@ -126,12 +130,13 @@ func NewNetwork(config Config) (*Network, error) { serviceEventMetrics := metrics.NewUsageRegistry(config.GetId().Token, nil, config.GetCloseNotify()) network := &Network{ + env: env, + Managers: env.GetManagers(), nodeId: config.GetId().Token, options: config.GetOptions(), assembleAndCleanC: make(chan struct{}, 1), - linkController: newLinkController(config.GetOptions()), forwardingFaults: make(chan struct{}, 1), - circuitController: newCircuitController(), + circuitIdGenerator: idgen.NewGenerator(), routeSenderController: newRouteSenderController(), sequence: sequence.NewSequence(), eventDispatcher: config.GetEventDispatcher(), @@ -157,22 +162,46 @@ func NewNetwork(config Config) (*Network, error) { config: config, } + env.GetManagers().Command.Decoders.RegisterF(int32(cmd_pb.CommandType_SyncSnapshot), network.decodeSyncSnapshotCommand) + routerCommPool, err := network.createRouterCommPool(config) if err != nil { return nil, err } - network.Managers = NewManagers(network, config.GetCommandDispatcher(), config.GetDb(), stores, routerCommPool) - network.Managers.Inspections.network = network + network.Inspections = NewInspectionsManager(network) + network.RouterMessaging = NewRouterMessaging(env, routerCommPool) + + env.GetManagers().Router.Store.AddEntityIdListener(network.HandleRouterDelete, boltz.EntityDeletedAsync) network.AddCapability("ziti.fabric") network.showOptions() network.relayControllerMetrics() - network.AddRouterPresenceHandler(network.Managers.RouterMessaging) - go network.Managers.RouterMessaging.run() + network.AddRouterPresenceHandler(network.RouterMessaging) + go network.RouterMessaging.run() return network, nil } +func (self *Network) HandleRouterDelete(id string) { + self.routerDeleted(id) + self.RouterMessaging.RouterDeleted(id) +} + +func (self *Network) decodeSyncSnapshotCommand(_ int32, data []byte) (command.Command, error) { + msg := &cmd_pb.SyncSnapshotCommand{} + if err := proto.Unmarshal(data, msg); err != nil { + return nil, err + } + + cmd := &command.SyncSnapshotCommand{ + SnapshotId: msg.SnapshotId, + Snapshot: msg.Snapshot, + SnapshotSink: self.RestoreSnapshot, + } + + return cmd, nil +} + func (network *Network) createRouterCommPool(config Config) (goroutines.Pool, error) { poolConfig := goroutines.PoolConfig{ QueueSize: config.GetOptions().RouterComm.QueueSize, @@ -220,69 +249,65 @@ func (network *Network) GetAppId() string { return network.nodeId } -func (network *Network) GetOptions() *Options { +func (network *Network) GetOptions() *config.NetworkConfig { return network.options } func (network *Network) GetDb() boltz.Db { - return network.db + return network.config.GetDb() } func (network *Network) GetStores() *db.Stores { - return network.stores -} - -func (network *Network) GetManagers() *Managers { - return network.Managers + return network.env.GetStores() } -func (network *Network) GetConnectedRouter(routerId string) *Router { - return network.Routers.getConnected(routerId) +func (network *Network) GetConnectedRouter(routerId string) *model.Router { + return network.Router.GetConnected(routerId) } -func (network *Network) GetReloadedRouter(routerId string) (*Router, error) { - network.Routers.RemoveFromCache(routerId) - return network.Routers.Read(routerId) +func (network *Network) GetReloadedRouter(routerId string) (*model.Router, error) { + network.Router.RemoveFromCache(routerId) + return network.Router.Read(routerId) } -func (network *Network) GetRouter(routerId string) (*Router, error) { - return network.Routers.Read(routerId) +func (network *Network) GetRouter(routerId string) (*model.Router, error) { + return network.Router.Read(routerId) } -func (network *Network) AllConnectedRouters() []*Router { - return network.Routers.allConnected() +func (network *Network) AllConnectedRouters() []*model.Router { + return network.Router.AllConnected() } -func (network *Network) GetLink(linkId string) (*Link, bool) { - return network.linkController.get(linkId) +func (network *Network) GetLink(linkId string) (*model.Link, bool) { + return network.Link.Get(linkId) } -func (network *Network) GetAllLinks() []*Link { - return network.linkController.all() +func (network *Network) GetAllLinks() []*model.Link { + return network.Link.All() } -func (network *Network) GetAllLinksForRouter(routerId string) []*Link { +func (network *Network) GetAllLinksForRouter(routerId string) []*model.Link { r := network.GetConnectedRouter(routerId) if r == nil { return nil } - return r.routerLinks.GetLinks() + return r.GetLinks() } -func (network *Network) GetCircuit(circuitId string) (*Circuit, bool) { - return network.circuitController.get(circuitId) +func (network *Network) GetCircuit(circuitId string) (*model.Circuit, bool) { + return network.Circuit.Get(circuitId) } -func (network *Network) GetAllCircuits() []*Circuit { - return network.circuitController.all() +func (network *Network) GetAllCircuits() []*model.Circuit { + return network.Circuit.All() } -func (network *Network) GetCircuitStore() *objectz.ObjectStore[*Circuit] { - return network.circuitController.store +func (network *Network) GetCircuitStore() *objectz.ObjectStore[*model.Circuit] { + return network.Circuit.GetStore() } -func (network *Network) GetLinkStore() *objectz.ObjectStore[*Link] { - return network.linkController.store +func (network *Network) GetLinkStore() *objectz.ObjectStore[*model.Link] { + return network.Link.GetStore() } func (network *Network) RouteResult(rs *RouteStatus) bool { @@ -290,7 +315,7 @@ func (network *Network) RouteResult(rs *RouteStatus) bool { } func (network *Network) newRouteSender(circuitId string) *routeSender { - rs := newRouteSender(circuitId, network.options.RouteTimeout, network, network.Terminators) + rs := newRouteSender(circuitId, network.options.RouteTimeout, network, network.Terminator) network.routeSenderController.addRouteSender(rs) return rs } @@ -320,12 +345,12 @@ func (network *Network) GetCloseNotify() <-chan struct{} { } func (network *Network) ConnectedRouter(id string) bool { - return network.Routers.IsConnected(id) + return network.Router.IsConnected(id) } -func (network *Network) ConnectRouter(r *Router) { - network.linkController.buildRouterLinks(r) - network.Routers.markConnected(r) +func (network *Network) ConnectRouter(r *model.Router) { + network.Link.BuildRouterLinks(r) + network.Router.MarkConnected(r) time.AfterFunc(250*time.Millisecond, network.notifyAssembleAndClean) @@ -335,9 +360,9 @@ func (network *Network) ConnectRouter(r *Router) { go network.ValidateTerminators(r) } -func (network *Network) ValidateTerminators(r *Router) { +func (network *Network) ValidateTerminators(r *model.Router) { logger := pfxlog.Logger().WithField("routerId", r.Id) - result, err := network.Terminators.Query(fmt.Sprintf(`router.id = "%v" limit none`, r.Id)) + result, err := network.Terminator.Query(fmt.Sprintf(`router.id = "%v" limit none`, r.Id)) if err != nil { logger.WithError(err).Error("failed to get terminators for router") return @@ -348,13 +373,13 @@ func (network *Network) ValidateTerminators(r *Router) { return } - network.Managers.RouterMessaging.ValidateRouterTerminators(result.Entities) + network.RouterMessaging.ValidateRouterTerminators(result.Entities) } type LinkValidationCallback func(detail *mgmt_pb.RouterLinkDetails) func (n *Network) ValidateLinks(filter string, cb LinkValidationCallback) (int64, func(), error) { - result, err := n.Routers.BaseList(filter) + result, err := n.Router.BaseList(filter) if err != nil { return 0, nil, err } @@ -368,10 +393,10 @@ func (n *Network) ValidateLinks(filter string, cb LinkValidationCallback) (int64 sem.Acquire() go func() { defer sem.Release() - n.linkController.ValidateRouterLinks(n, connectedRouter, cb) + n.ValidateRouterLinks(connectedRouter, cb) }() } else { - n.linkController.reportRouterLinksError(router, errors.New("router not connected"), cb) + n.reportRouterLinksError(router, errors.New("router not connected"), cb) } } } @@ -382,7 +407,7 @@ func (n *Network) ValidateLinks(filter string, cb LinkValidationCallback) (int64 type SdkTerminatorValidationCallback func(detail *mgmt_pb.RouterSdkTerminatorsDetails) func (n *Network) ValidateRouterSdkTerminators(filter string, cb SdkTerminatorValidationCallback) (int64, func(), error) { - result, err := n.Routers.BaseList(filter) + result, err := n.Router.BaseList(filter) if err != nil { return 0, nil, err } @@ -396,10 +421,10 @@ func (n *Network) ValidateRouterSdkTerminators(filter string, cb SdkTerminatorVa sem.Acquire() go func() { defer sem.Release() - n.Routers.ValidateRouterSdkTerminators(connectedRouter, cb) + n.Router.ValidateRouterSdkTerminators(connectedRouter, cb) }() } else { - n.Routers.reportRouterSdkTerminatorsError(router, errors.New("router not connected"), cb) + n.Router.ReportRouterSdkTerminatorsError(router, errors.New("router not connected"), cb) } } } @@ -407,19 +432,19 @@ func (n *Network) ValidateRouterSdkTerminators(filter string, cb SdkTerminatorVa return int64(len(result.Entities)), evalF, nil } -func (network *Network) DisconnectRouter(r *Router) { +func (network *Network) DisconnectRouter(r *model.Router) { // 1: remove Links for Router - for _, l := range r.routerLinks.GetLinks() { - wasConnected := l.CurrentState().Mode == Connected + for _, l := range r.GetLinks() { + wasConnected := l.CurrentState().Mode == model.Connected if l.Src.Id == r.Id { - network.linkController.remove(l) + network.Link.Remove(l) } if wasConnected { network.RerouteLink(l) } } // 2: remove Router - network.Routers.markDisconnected(r) + network.Router.MarkDisconnected(r) for _, h := range network.routerPresenceHandlers { h.RouterDisconnected(r) @@ -435,14 +460,14 @@ func (network *Network) notifyAssembleAndClean() { } } -func (network *Network) NotifyExistingLink(id string, iteration uint32, linkProtocol, dialAddress string, srcRouter *Router, dstRouterId string) { +func (network *Network) NotifyExistingLink(id string, iteration uint32, linkProtocol, dialAddress string, srcRouter *model.Router, dstRouterId string) { log := pfxlog.Logger(). WithField("routerId", srcRouter.Id). WithField("linkId", id). WithField("destRouterId", dstRouterId). WithField("iteration", iteration) - src := network.Routers.getConnected(srcRouter.Id) + src := network.Router.GetConnected(srcRouter.Id) if src == nil { log.Info("ignoring links message processed after router disconnected") return @@ -453,12 +478,12 @@ func (network *Network) NotifyExistingLink(id string, iteration uint32, linkProt return } - dst := network.Routers.getConnected(dstRouterId) + dst := network.Router.GetConnected(dstRouterId) if dst == nil { network.NotifyLinkIdEvent(id, event.LinkFromRouterDisconnectedDest) } - link, created := network.linkController.routerReportedLink(id, iteration, linkProtocol, dialAddress, srcRouter, dst, dstRouterId) + link, created := network.Link.RouterReportedLink(id, iteration, linkProtocol, dialAddress, srcRouter, dst, dstRouterId) if created { network.NotifyLinkEvent(link, event.LinkFromRouterNew) log.Info("router reported link added") @@ -469,27 +494,27 @@ func (network *Network) NotifyExistingLink(id string, iteration uint32, linkProt } func (network *Network) LinkConnected(msg *ctrl_pb.LinkConnected) error { - if l, found := network.linkController.get(msg.Id); found { - if state := l.CurrentState(); state.Mode != Pending { + if l, found := network.Link.Get(msg.Id); found { + if state := l.CurrentState(); state.Mode != model.Pending { return errors.Errorf("link [l/%v] state is %v, not pending, cannot mark connected", msg.Id, state.Mode) } - l.SetState(Connected) + l.SetState(model.Connected) network.NotifyLinkConnected(l, msg) return nil } return errors.Errorf("no such link [l/%s]", msg.Id) } -func (network *Network) LinkFaulted(l *Link, dupe bool) error { - l.SetState(Failed) +func (network *Network) LinkFaulted(l *model.Link, dupe bool) error { + l.SetState(model.Failed) if dupe { network.NotifyLinkEvent(l, event.LinkDuplicate) } else { network.NotifyLinkEvent(l, event.LinkFault) } pfxlog.Logger().WithField("linkId", l.Id).Info("removing failed link") - network.linkController.remove(l) + network.Link.Remove(l) return nil } @@ -513,14 +538,14 @@ func (network *Network) VerifyRouter(routerId string, fingerprints []string) err return errors.Errorf("could not verify fingerprint for router %v", routerId) } -func (network *Network) RerouteLink(l *Link) { +func (network *Network) RerouteLink(l *model.Link) { // This is called from Channel.rxer() and thus may not block go func() { network.handleRerouteLink(l) }() } -func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, error) { +func (network *Network) CreateCircuit(params model.CreateCircuitParams) (*model.Circuit, error) { clientId := params.GetClientId() service := params.GetServiceId() ctx := params.GetLogContext() @@ -531,7 +556,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err instanceId, serviceId := parseInstanceIdAndService(service) // 1: Allocate Circuit Identifier - circuitId, err := network.circuitController.nextCircuitId() + circuitId, err := network.circuitIdGenerator.NextAlphaNumericPrefixedId() if err != nil { network.CircuitFailedEvent(circuitId, params, startTime, nil, nil, CircuitFailureInvalidService) return nil, err @@ -549,7 +574,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err defer func() { network.removeRouteSender(rs) }() for { // 2: Find Service - svc, err := network.Services.Read(serviceId) + svc, err := network.Service.Read(serviceId) if err != nil { network.CircuitFailedEvent(circuitId, params, startTime, nil, nil, CircuitFailureInvalidService) network.ServiceDialOtherError(serviceId) @@ -577,7 +602,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err tags := params.GetCircuitTags(terminator) // 4a: Create Route Messages - rms := path.CreateRouteMessages(attempt, circuitId, terminator, deadline) + rms := network.CreateRouteMessages(path, attempt, circuitId, terminator, deadline) rms[len(rms)-1].Egress.PeerData = clientId.Data for _, msg := range rms { msg.Context = &ctrl_pb.Context{ @@ -658,7 +683,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err now := time.Now() // 6: Create Circuit Object - circuit := &Circuit{ + circuit := &model.Circuit{ Id: circuitId, ClientId: clientId.Token, ServiceId: svc.Id, @@ -669,7 +694,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err UpdatedAt: now, Tags: tags, } - network.circuitController.add(circuit) + network.Circuit.Add(circuit) creationTimespan := time.Since(startTime) network.CircuitEvent(event.CircuitCreated, circuit, &creationTimespan) @@ -700,7 +725,7 @@ func parseInstanceIdAndService(service string) (string, string) { return identityId, serviceId } -func (network *Network) selectPath(params CreateCircuitParams, svc *Service, instanceId string, ctx logcontext.Context) (xt.Strategy, xt.CostedTerminator, []*Router, xt.PeerData, CircuitError) { +func (network *Network) selectPath(params model.CreateCircuitParams, svc *model.Service, instanceId string, ctx logcontext.Context) (xt.Strategy, xt.CostedTerminator, []*model.Router, xt.PeerData, CircuitError) { paths := map[string]*PathAndCost{} var weightedTerminators []xt.CostedTerminator var errList []error @@ -717,7 +742,7 @@ func (network *Network) selectPath(params CreateCircuitParams, svc *Service, ins pathAndCost, found := paths[terminator.Router] if !found { - dstR := network.Routers.getConnected(terminator.GetRouterId()) + dstR := network.Router.GetConnected(terminator.GetRouterId()) if dstR == nil { err := errors.Errorf("router with id=%v on terminator with id=%v for service name=%v is not online", terminator.GetRouterId(), terminator.GetId(), svc.Name) @@ -743,7 +768,7 @@ func (network *Network) selectPath(params CreateCircuitParams, svc *Service, ins dynamicCost := xt.GlobalCosts().GetDynamicCost(terminator.Id) unbiasedCost := uint32(terminator.Cost) + uint32(dynamicCost) + pathAndCost.cost biasedCost := terminator.Precedence.GetBiasedCost(unbiasedCost) - costedTerminator := &RoutingTerminator{ + costedTerminator := &model.RoutingTerminator{ Terminator: terminator, RouteCost: biasedCost, } @@ -812,7 +837,7 @@ func (network *Network) selectPath(params CreateCircuitParams, svc *Service, ins func (network *Network) RemoveCircuit(circuitId string, now bool) error { log := pfxlog.Logger().WithField("circuitId", circuitId) - if circuit, found := network.circuitController.get(circuitId); found { + if circuit, found := network.Circuit.Get(circuitId); found { for _, r := range circuit.Path.Nodes { err := sendUnroute(r, circuit.Id, now) if err != nil { @@ -820,10 +845,10 @@ func (network *Network) RemoveCircuit(circuitId string, now bool) error { } } - network.circuitController.remove(circuit) + network.Circuit.Remove(circuit) network.CircuitEvent(event.CircuitDeleted, circuit, nil) - if svc, err := network.Services.Read(circuit.ServiceId); err == nil { + if svc, err := network.Service.Read(circuit.ServiceId); err == nil { if strategy, err := network.strategyRegistry.GetStrategy(svc.TerminatorStrategy); strategy != nil { strategy.NotifyEvent(xt.NewCircuitRemoved(circuit.Terminator)) } else if err != nil { @@ -840,7 +865,7 @@ func (network *Network) RemoveCircuit(circuitId string, now bool) error { return InvalidCircuitError{circuitId: circuitId} } -func (network *Network) CreatePath(srcR, dstR *Router) (*Path, error) { +func (network *Network) CreatePath(srcR, dstR *model.Router) (*model.Path, error) { ingressId, err := network.sequence.NextHash() if err != nil { return nil, err @@ -851,11 +876,11 @@ func (network *Network) CreatePath(srcR, dstR *Router) (*Path, error) { return nil, err } - path := &Path{ - Links: make([]*Link, 0), + path := &model.Path{ + Links: make([]*model.Link, 0), IngressId: ingressId, EgressId: egressId, - Nodes: make([]*Router, 0), + Nodes: make([]*model.Router, 0), } path.Nodes = append(path.Nodes, srcR) path.Nodes = append(path.Nodes, dstR) @@ -863,55 +888,10 @@ func (network *Network) CreatePath(srcR, dstR *Router) (*Path, error) { return network.UpdatePath(path) } -func (network *Network) CreatePathWithNodes(nodes []*Router) (*Path, CircuitError) { - ingressId, err := network.sequence.NextHash() - if err != nil { - return nil, newCircuitErrWrap(CircuitFailureIdGenerationError, err) - } - - egressId, err := network.sequence.NextHash() - if err != nil { - return nil, newCircuitErrWrap(CircuitFailureIdGenerationError, err) - } - - path := &Path{ - Nodes: nodes, - IngressId: ingressId, - EgressId: egressId, - } - if err := network.setLinks(path); err != nil { - return nil, newCircuitErrWrap(CircuitFailurePathMissingLink, err) - } - return path, nil -} - -func (network *Network) UpdatePath(path *Path) (*Path, error) { - srcR := path.Nodes[0] - dstR := path.Nodes[len(path.Nodes)-1] - nodes, _, err := network.shortestPath(srcR, dstR) - if err != nil { - return nil, err - } - - path2 := &Path{ - Nodes: nodes, - IngressId: path.IngressId, - EgressId: path.EgressId, - InitiatorLocalAddr: path.InitiatorLocalAddr, - InitiatorRemoteAddr: path.InitiatorRemoteAddr, - TerminatorLocalAddr: path.TerminatorLocalAddr, - TerminatorRemoteAddr: path.TerminatorRemoteAddr, - } - if err := network.setLinks(path2); err != nil { - return nil, err - } - return path2, nil -} - -func (network *Network) setLinks(path *Path) error { +func (network *Network) setLinks(path *model.Path) error { if len(path.Nodes) > 1 { for i := 0; i < len(path.Nodes)-1; i++ { - if link, found := network.linkController.leastExpensiveLink(path.Nodes[i], path.Nodes[i+1]); found { + if link, found := network.Link.LeastExpensiveLink(path.Nodes[i], path.Nodes[i+1]); found { path.Links = append(path.Links, link) } else { return errors.Errorf("no link from r/%v to r/%v", path.Nodes[i].Id, path.Nodes[i+1].Id) @@ -947,7 +927,7 @@ func (network *Network) Run() { network.assemble() network.clean() network.smart() - network.linkController.scanForDeadLinks() + network.Link.ScanForDeadLinks() case <-network.closeNotify: network.eventDispatcher.RemoveMetricsMessageHandler(network) @@ -990,10 +970,10 @@ func (network *Network) watchdog() { } } -func (network *Network) handleRerouteLink(l *Link) { +func (network *Network) handleRerouteLink(l *model.Link) { log := logrus.WithField("linkId", l.Id) log.Info("changed link") - if err := network.rerouteLink(l, time.Now().Add(DefaultOptionsRouteTimeout)); err != nil { + if err := network.rerouteLink(l, time.Now().Add(config.DefaultOptionsRouteTimeout)); err != nil { log.WithError(err).Error("unexpected error rerouting link") } } @@ -1017,13 +997,13 @@ func (network *Network) GetCapabilities() []string { func (network *Network) RemoveLink(linkId string) { log := pfxlog.Logger().WithField("linkId", linkId) - link, _ := network.linkController.get(linkId) + link, _ := network.Link.Get(linkId) var iteration uint32 - var routerList []*Router + var routerList []*model.Router if link != nil { iteration = link.Iteration - routerList = []*Router{link.Src} + routerList = []*model.Router{link.Src} if dst := link.GetDest(); dst != nil { routerList = append(routerList, dst) } @@ -1055,15 +1035,15 @@ func (network *Network) RemoveLink(linkId string) { } if link != nil { - network.linkController.remove(link) + network.Link.Remove(link) network.RerouteLink(link) } } -func (network *Network) rerouteLink(l *Link, deadline time.Time) error { - circuits := network.circuitController.all() +func (network *Network) rerouteLink(l *model.Link, deadline time.Time) error { + circuits := network.Circuit.All() for _, circuit := range circuits { - if circuit.Path.usesLink(l) { + if circuit.Path.UsesLink(l) { log := logrus.WithField("linkId", l.Id). WithField("circuitId", circuit.Id) log.Info("circuit uses link") @@ -1079,11 +1059,11 @@ func (network *Network) rerouteLink(l *Link, deadline time.Time) error { return nil } -func (network *Network) rerouteCircuitWithTries(circuit *Circuit, retries int) bool { +func (network *Network) rerouteCircuitWithTries(circuit *model.Circuit, retries int) bool { log := pfxlog.Logger().WithField("circuitId", circuit.Id) for i := 0; i < retries; i++ { - deadline := time.Now().Add(DefaultOptionsRouteTimeout) + deadline := time.Now().Add(config.DefaultOptionsRouteTimeout) err := network.rerouteCircuit(circuit, deadline) if err == nil { return true @@ -1098,7 +1078,7 @@ func (network *Network) rerouteCircuitWithTries(circuit *Circuit, retries int) b return false } -func (network *Network) rerouteCircuit(circuit *Circuit, deadline time.Time) error { +func (network *Network) rerouteCircuit(circuit *model.Circuit, deadline time.Time) error { log := pfxlog.Logger().WithField("circuitId", circuit.Id) if circuit.Rerouting.CompareAndSwap(false, true) { defer circuit.Rerouting.Store(false) @@ -1109,7 +1089,7 @@ func (network *Network) rerouteCircuit(circuit *Circuit, deadline time.Time) err circuit.Path = cq circuit.UpdatedAt = time.Now() - rms := cq.CreateRouteMessages(SmartRerouteAttempt, circuit.Id, circuit.Terminator, deadline) + rms := network.CreateRouteMessages(cq, SmartRerouteAttempt, circuit.Id, circuit.Terminator, deadline) for i := 0; i < len(cq.Nodes); i++ { if _, err := sendRoute(cq.Nodes[i], rms[i], network.options.RouteTimeout); err != nil { @@ -1130,7 +1110,7 @@ func (network *Network) rerouteCircuit(circuit *Circuit, deadline time.Time) err } } -func (network *Network) smartReroute(circuit *Circuit, cq *Path, deadline time.Time) bool { +func (network *Network) smartReroute(circuit *model.Circuit, cq *model.Path, deadline time.Time) bool { retry := false log := pfxlog.Logger().WithField("circuitId", circuit.Id) if circuit.Rerouting.CompareAndSwap(false, true) { @@ -1139,7 +1119,7 @@ func (network *Network) smartReroute(circuit *Circuit, cq *Path, deadline time.T circuit.Path = cq circuit.UpdatedAt = time.Now() - rms := cq.CreateRouteMessages(SmartRerouteAttempt, circuit.Id, circuit.Terminator, deadline) + rms := network.CreateRouteMessages(cq, SmartRerouteAttempt, circuit.Id, circuit.Terminator, deadline) for i := 0; i < len(cq.Nodes); i++ { if _, err := sendRoute(cq.Nodes[i], rms[i], network.options.RouteTimeout); err != nil { @@ -1164,7 +1144,7 @@ func (network *Network) AcceptMetricsMsg(metrics *metrics_pb.MetricsMessage) { log := pfxlog.Logger() - router, err := network.Routers.Read(metrics.SourceId) + router, err := network.Router.Read(metrics.SourceId) if err != nil { log.Debugf("could not find router [r/%s] while processing metrics", metrics.SourceId) return @@ -1196,7 +1176,7 @@ func (network *Network) AcceptMetricsMsg(metrics *metrics_pb.MetricsMessage) { } } -func sendRoute(r *Router, createMsg *ctrl_pb.Route, timeout time.Duration) (xt.PeerData, error) { +func sendRoute(r *model.Router, createMsg *ctrl_pb.Route, timeout time.Duration) (xt.PeerData, error) { log := pfxlog.Logger().WithField("routerId", r.Id). WithField("circuitId", createMsg.CircuitId) @@ -1230,7 +1210,7 @@ func sendRoute(r *Router, createMsg *ctrl_pb.Route, timeout time.Duration) (xt.P return nil, fmt.Errorf("unexpected response type %v received in reply to route request", msg.ContentType) } -func sendUnroute(r *Router, circuitId string, now bool) error { +func sendUnroute(r *model.Router, circuitId string, now bool) error { unroute := &ctrl_pb.Unroute{ CircuitId: circuitId, Now: now, @@ -1279,7 +1259,7 @@ func (network *Network) Inspect(name string) (*string, error) { } } else if lc == "connected-routers" { var result []map[string]any - for _, r := range network.Routers.allConnected() { + for _, r := range network.Router.AllConnected() { status := map[string]any{} status["Id"] = r.Id status["Name"] = r.Name @@ -1304,7 +1284,7 @@ func (network *Network) Inspect(name string) (*string, error) { return &resultStr, nil } } else if lc == "router-messaging" { - routerMessagingState, err := network.Managers.RouterMessaging.Inspect() + routerMessagingState, err := network.RouterMessaging.Inspect() if err != nil { return nil, err } @@ -1428,7 +1408,7 @@ func (network *Network) SnapshotDatabaseToFile(path string) (string, error) { func (network *Network) RestoreSnapshot(cmd *command.SyncSnapshotCommand) error { log := pfxlog.Logger() - currentSnapshotId, err := network.getDb().GetSnapshotId() + currentSnapshotId, err := network.GetDb().GetSnapshotId() if err != nil { log.WithError(err).Error("unable to get current snapshot id") } @@ -1443,14 +1423,14 @@ func (network *Network) RestoreSnapshot(cmd *command.SyncSnapshotCommand) error return errors.Wrapf(err, "unable to create gz reader for reading migration snapshot during restore") } - network.getDb().RestoreFromReader(reader) + network.GetDb().RestoreFromReader(reader) return nil } func (network *Network) SnapshotToRaft() error { buf := &bytes.Buffer{} gzWriter := gzip.NewWriter(buf) - snapshotId, err := network.db.SnapshotToWriter(gzWriter) + snapshotId, err := network.GetDb().SnapshotToWriter(gzWriter) if err != nil { return err } @@ -1465,18 +1445,129 @@ func (network *Network) SnapshotToRaft() error { SnapshotSink: network.RestoreSnapshot, } - return network.Dispatch(cmd) + return network.Managers.Dispatcher.Dispatch(cmd) } func (network *Network) AddInspectTarget(target InspectTarget) { network.inspectionTargets.Append(target) } +func (network *Network) ValidateRouterLinks(router *model.Router, cb LinkValidationCallback) { + request := &ctrl_pb.InspectRequest{RequestedValues: []string{"links"}} + resp := &ctrl_pb.InspectResponse{} + respMsg, err := protobufs.MarshalTyped(request).WithTimeout(time.Minute).SendForReply(router.Control) + if err = protobufs.TypedResponse(resp).Unmarshall(respMsg, err); err != nil { + network.reportRouterLinksError(router, err, cb) + return + } + + var linkDetails *inspect.LinksInspectResult + for _, val := range resp.Values { + if val.Name == "links" { + if err = json.Unmarshal([]byte(val.Value), &linkDetails); err != nil { + network.reportRouterLinksError(router, err, cb) + return + } + } + } + + if linkDetails == nil { + if len(resp.Errors) > 0 { + err = errors.New(strings.Join(resp.Errors, ",")) + network.reportRouterLinksError(router, err, cb) + return + } + network.reportRouterLinksError(router, errors.New("no link details returned from router"), cb) + return + } + + linkMap := network.Link.GetLinkMap() + + result := &mgmt_pb.RouterLinkDetails{ + RouterId: router.Id, + RouterName: router.Name, + ValidateSuccess: true, + } + + for _, link := range linkDetails.Links { + detail := &mgmt_pb.RouterLinkDetail{ + LinkId: link.Id, + RouterState: mgmt_pb.LinkState_LinkEstablished, + DestRouterId: link.Dest, + Dialed: link.Dialed, + } + detail.DestConnected = network.ConnectedRouter(link.Dest) + if _, found := linkMap[link.Id]; found { + detail.CtrlState = mgmt_pb.LinkState_LinkEstablished + detail.IsValid = detail.DestConnected + } else { + detail.CtrlState = mgmt_pb.LinkState_LinkUnknown + detail.IsValid = !detail.DestConnected + } + delete(linkMap, link.Id) + result.LinkDetails = append(result.LinkDetails, detail) + } + + for _, link := range linkMap { + related := false + dest := "" + if link.Src.Id == router.Id { + related = true + dest = link.DstId + } else if link.DstId == router.Id { + related = true + dest = link.Src.Id + } + + if related { + detail := &mgmt_pb.RouterLinkDetail{ + LinkId: link.Id, + CtrlState: mgmt_pb.LinkState_LinkEstablished, + DestConnected: network.ConnectedRouter(dest), + RouterState: mgmt_pb.LinkState_LinkUnknown, + IsValid: false, + DestRouterId: dest, + Dialed: link.Src.Id == router.Id, + } + result.LinkDetails = append(result.LinkDetails, detail) + } + } + + cb(result) +} + +func (network *Network) reportRouterLinksError(router *model.Router, err error, cb LinkValidationCallback) { + result := &mgmt_pb.RouterLinkDetails{ + RouterId: router.Id, + RouterName: router.Name, + ValidateSuccess: false, + Message: err.Error(), + } + cb(result) +} + +func minCost(q map[*model.Router]bool, dist map[*model.Router]int64) *model.Router { + if dist == nil || len(dist) < 1 { + return nil + } + + min := int64(math.MaxInt64) + var selected *model.Router + for r := range q { + d := dist[r] + if d <= min { + selected = r + min = d + } + } + return selected +} + type Cache interface { RemoveFromCache(id string) } -func newPathAndCost(path []*Router, cost int64) *PathAndCost { +func newPathAndCost(path []*model.Router, cost int64) *PathAndCost { if cost > (1 << 20) { cost = 1 << 20 } @@ -1487,7 +1578,7 @@ func newPathAndCost(path []*Router, cost int64) *PathAndCost { } type PathAndCost struct { - path []*Router + path []*model.Router cost uint32 } diff --git a/controller/network/path.go b/controller/network/network_path.go similarity index 50% rename from controller/network/path.go rename to controller/network/network_path.go index 52444337f..b5f87be4f 100644 --- a/controller/network/path.go +++ b/controller/network/network_path.go @@ -1,138 +1,58 @@ -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - package network import ( "fmt" - "math" - "time" - "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/xt" "github.com/pkg/errors" + "math" + "time" ) -type Path struct { - Nodes []*Router - Links []*Link - IngressId string - EgressId string - InitiatorLocalAddr string - InitiatorRemoteAddr string - TerminatorLocalAddr string - TerminatorRemoteAddr string -} - -func (self *Path) cost(minRouterCost uint16) int64 { - var cost int64 - for _, l := range self.Links { - cost += l.GetCost() - } - for _, r := range self.Nodes { - cost += int64(maxUint16(r.Cost, minRouterCost)) - } - return cost -} - -func (self *Path) String() string { - if len(self.Nodes) < 1 { - return "{}" - } - if len(self.Links) != len(self.Nodes)-1 { - return "{malformed}" - } - out := fmt.Sprintf("[r/%s]", self.Nodes[0].Id) - for i := 0; i < len(self.Links); i++ { - out += fmt.Sprintf("->[l/%s]", self.Links[i].Id) - out += fmt.Sprintf("->[r/%s]", self.Nodes[i+1].Id) - } - return out -} - -func (self *Path) EqualPath(other *Path) bool { - if len(self.Nodes) != len(other.Nodes) { - return false - } - if len(self.Links) != len(other.Links) { - return false - } - for i := 0; i < len(self.Nodes); i++ { - if self.Nodes[i] != other.Nodes[i] { - return false - } - } - for i := 0; i < len(self.Links); i++ { - if self.Links[i] != other.Links[i] { - return false - } - } - return true -} - -func (self *Path) EgressRouter() *Router { - if len(self.Nodes) > 0 { - return self.Nodes[len(self.Nodes)-1] - } - return nil -} - -func (self *Path) CreateRouteMessages(attempt uint32, circuitId string, terminator xt.Terminator, deadline time.Time) []*ctrl_pb.Route { +func (network *Network) CreateRouteMessages(path *model.Path, attempt uint32, circuitId string, terminator xt.Terminator, deadline time.Time) []*ctrl_pb.Route { var routeMessages []*ctrl_pb.Route remainingTime := time.Until(deadline) - if len(self.Links) == 0 { + if len(path.Links) == 0 { // single router path routeMessage := &ctrl_pb.Route{CircuitId: circuitId, Attempt: attempt, Timeout: uint64(remainingTime)} routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ - SrcAddress: self.IngressId, - DstAddress: self.EgressId, + SrcAddress: path.IngressId, + DstAddress: path.EgressId, DstType: ctrl_pb.DestType_End, }) routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ - SrcAddress: self.EgressId, - DstAddress: self.IngressId, + SrcAddress: path.EgressId, + DstAddress: path.IngressId, DstType: ctrl_pb.DestType_Start, }) routeMessage.Egress = &ctrl_pb.Route_Egress{ Binding: terminator.GetBinding(), - Address: self.EgressId, + Address: path.EgressId, Destination: terminator.GetAddress(), } routeMessages = append(routeMessages, routeMessage) } - for i, link := range self.Links { + for i, link := range path.Links { if i == 0 { // ingress routeMessage := &ctrl_pb.Route{CircuitId: circuitId, Attempt: attempt, Timeout: uint64(remainingTime)} routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ - SrcAddress: self.IngressId, + SrcAddress: path.IngressId, DstAddress: link.Id, DstType: ctrl_pb.DestType_Link, }) routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ SrcAddress: link.Id, - DstAddress: self.IngressId, + DstAddress: path.IngressId, DstType: ctrl_pb.DestType_Start, }) routeMessages = append(routeMessages, routeMessage) } - if i >= 0 && i < len(self.Links)-1 { + if i >= 0 && i < len(path.Links)-1 { // transit - nextLink := self.Links[i+1] + nextLink := path.Links[i+1] routeMessage := &ctrl_pb.Route{CircuitId: circuitId, Attempt: attempt, Timeout: uint64(remainingTime)} routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ SrcAddress: link.Id, @@ -146,24 +66,24 @@ func (self *Path) CreateRouteMessages(attempt uint32, circuitId string, terminat }) routeMessages = append(routeMessages, routeMessage) } - if i == len(self.Links)-1 { + if i == len(path.Links)-1 { // egress routeMessage := &ctrl_pb.Route{CircuitId: circuitId, Attempt: attempt, Timeout: uint64(remainingTime)} if attempt != SmartRerouteAttempt { routeMessage.Egress = &ctrl_pb.Route_Egress{ Binding: terminator.GetBinding(), - Address: self.EgressId, + Address: path.EgressId, Destination: terminator.GetAddress(), } } routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ - SrcAddress: self.EgressId, + SrcAddress: path.EgressId, DstAddress: link.Id, DstType: ctrl_pb.DestType_Link, }) routeMessage.Forwards = append(routeMessage.Forwards, &ctrl_pb.Route_Forward{ SrcAddress: link.Id, - DstAddress: self.EgressId, + DstAddress: path.EgressId, DstType: ctrl_pb.DestType_End, }) routeMessages = append(routeMessages, routeMessage) @@ -172,31 +92,65 @@ func (self *Path) CreateRouteMessages(attempt uint32, circuitId string, terminat return routeMessages } -func (self *Path) usesLink(l *Link) bool { - if self.Links != nil { - for _, o := range self.Links { - if o == l { - return true - } - } +func (network *Network) CreatePathWithNodes(nodes []*model.Router) (*model.Path, CircuitError) { + ingressId, err := network.sequence.NextHash() + if err != nil { + return nil, newCircuitErrWrap(CircuitFailureIdGenerationError, err) + } + + egressId, err := network.sequence.NextHash() + if err != nil { + return nil, newCircuitErrWrap(CircuitFailureIdGenerationError, err) + } + + path := &model.Path{ + Nodes: nodes, + IngressId: ingressId, + EgressId: egressId, + } + if err := network.setLinks(path); err != nil { + return nil, newCircuitErrWrap(CircuitFailurePathMissingLink, err) + } + return path, nil +} + +func (network *Network) UpdatePath(path *model.Path) (*model.Path, error) { + srcR := path.Nodes[0] + dstR := path.Nodes[len(path.Nodes)-1] + nodes, _, err := network.shortestPath(srcR, dstR) + if err != nil { + return nil, err + } + + path2 := &model.Path{ + Nodes: nodes, + IngressId: path.IngressId, + EgressId: path.EgressId, + InitiatorLocalAddr: path.InitiatorLocalAddr, + InitiatorRemoteAddr: path.InitiatorRemoteAddr, + TerminatorLocalAddr: path.TerminatorLocalAddr, + TerminatorRemoteAddr: path.TerminatorRemoteAddr, + } + if err := network.setLinks(path2); err != nil { + return nil, err } - return false + return path2, nil } -func (network *Network) shortestPath(srcR *Router, dstR *Router) ([]*Router, int64, error) { +func (network *Network) shortestPath(srcR *model.Router, dstR *model.Router) ([]*model.Router, int64, error) { if srcR == nil || dstR == nil { return nil, 0, errors.New("not routable (!srcR||!dstR)") } if srcR == dstR { - return []*Router{srcR}, 0, nil + return []*model.Router{srcR}, 0, nil } - dist := make(map[*Router]int64) - prev := make(map[*Router]*Router) - unvisited := make(map[*Router]bool) + dist := make(map[*model.Router]int64) + prev := make(map[*model.Router]*model.Router) + unvisited := make(map[*model.Router]bool) - for _, r := range network.Routers.allConnected() { + for _, r := range network.Router.AllConnected() { dist[r] = math.MaxInt32 unvisited[r] = true } @@ -211,13 +165,13 @@ func (network *Network) shortestPath(srcR *Router, dstR *Router) ([]*Router, int } delete(unvisited, u) - neighbors := network.linkController.connectedNeighborsOfRouter(u) + neighbors := network.Link.ConnectedNeighborsOfRouter(u) for _, r := range neighbors { if _, found := unvisited[r]; found { var cost int64 = math.MaxInt32 + 1 - if l, found := network.linkController.leastExpensiveLink(r, u); found { + if l, found := network.Link.LeastExpensiveLink(r, u); found { if !r.NoTraversal || r == srcR || r == dstR { - cost = l.GetCost() + int64(maxUint16(r.Cost, minRouterCost)) + cost = l.GetCost() + int64(max(r.Cost, minRouterCost)) } } @@ -237,10 +191,10 @@ func (network *Network) shortestPath(srcR *Router, dstR *Router) ([]*Router, int * r2 = 0 <- nil */ - routerPath := make([]*Router, 0) + routerPath := make([]*model.Router, 0) p := prev[dstR] for p != nil { - routerPath = append([]*Router{p}, routerPath...) + routerPath = append([]*model.Router{p}, routerPath...) p = prev[p] } routerPath = append(routerPath, dstR) @@ -254,27 +208,3 @@ func (network *Network) shortestPath(srcR *Router, dstR *Router) ([]*Router, int return routerPath, dist[dstR], nil } - -func minCost(q map[*Router]bool, dist map[*Router]int64) *Router { - if dist == nil || len(dist) < 1 { - return nil - } - - min := int64(math.MaxInt64) - var selected *Router - for r := range q { - d := dist[r] - if d <= min { - selected = r - min = d - } - } - return selected -} - -func maxUint16(v1, v2 uint16) uint16 { - if v1 > v2 { - return v1 - } - return v2 -} diff --git a/controller/network/network_test.go b/controller/network/network_test.go index d17c714f0..3502dc47e 100644 --- a/controller/network/network_test.go +++ b/controller/network/network_test.go @@ -1,7 +1,9 @@ package network import ( + "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/event" + "github.com/openziti/ziti/controller/model" "github.com/pkg/errors" "runtime" "testing" @@ -23,7 +25,7 @@ import ( type testConfig struct { ctx *db.TestContext - options *Options + options *config.NetworkConfig metricsRegistry metrics.Registry versionProvider versions.VersionProvider closeNotify chan struct{} @@ -33,13 +35,13 @@ func (self *testConfig) RenderJsonConfig() (string, error) { panic(errors.New("not implemented")) } -func newTestConfig(ctx *db.TestContext) *testConfig { - options := DefaultOptions() +func newTestConfig(ctx *model.TestContext) *testConfig { + options := config.DefaultNetworkConfig() options.MinRouterCost = 0 closeNotify := make(chan struct{}) return &testConfig{ - ctx: ctx, + ctx: ctx.TestContext, options: options, metricsRegistry: metrics.NewRegistry("test", nil), versionProvider: NewVersionProviderTest(), @@ -59,7 +61,7 @@ func (self *testConfig) GetMetricsRegistry() metrics.Registry { return self.metricsRegistry } -func (self *testConfig) GetOptions() *Options { +func (self *testConfig) GetOptions() *config.NetworkConfig { return self.options } @@ -113,22 +115,22 @@ func TestNetwork_parseServiceAndIdentity(t *testing.T) { } func TestCreateCircuit(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) assert.Nil(t, err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) assert.Nil(t, err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 0, false) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 0, false) - svc := &Service{ + svc := &model.Service{ BaseEntity: models.BaseEntity{Id: "svc"}, Name: "svc", TerminatorStrategy: "smartrouting", @@ -145,7 +147,7 @@ func TestCreateCircuit(t *testing.T) { assert.Error(t, cerr) assert.Equal(t, CircuitFailureNoTerminators, cerr.Cause()) - svc.Terminators = []*Terminator{ + svc.Terminators = []*model.Terminator{ { BaseEntity: models.BaseEntity{Id: "t0"}, Service: "svc", @@ -161,7 +163,7 @@ func TestCreateCircuit(t *testing.T) { assert.Error(t, cerr) assert.Equal(t, CircuitFailureNoOnlineTerminators, cerr.Cause()) - network.Routers.markConnected(r0) + network.Router.MarkConnected(r0) _, _, _, _, cerr = network.selectPath(params, svc, "", lc) assert.NoError(t, cerr) @@ -207,7 +209,7 @@ func NewVersionProviderTest() versions.VersionProvider { return &VersionProviderTest{} } -func newCircuitParams(service *Service, router *Router) CreateCircuitParams { +func newCircuitParams(service *model.Service, router *model.Router) model.CreateCircuitParams { return testCreateCircuitParams{ svc: service, router: router, @@ -215,15 +217,15 @@ func newCircuitParams(service *Service, router *Router) CreateCircuitParams { } type testCreateCircuitParams struct { - svc *Service - router *Router + svc *model.Service + router *model.Router } func (t testCreateCircuitParams) GetServiceId() string { return t.svc.Id } -func (t testCreateCircuitParams) GetSourceRouter() *Router { +func (t testCreateCircuitParams) GetSourceRouter() *model.Router { return t.router } diff --git a/controller/network/path_test.go b/controller/network/path_test.go index 144e1b419..4451672a2 100644 --- a/controller/network/path_test.go +++ b/controller/network/path_test.go @@ -17,42 +17,40 @@ package network import ( + config2 "github.com/openziti/ziti/controller/config" + "github.com/openziti/ziti/controller/model" "testing" "time" "github.com/stretchr/testify/require" - "github.com/openziti/channel/v2" - "github.com/openziti/transport/v2" "github.com/openziti/transport/v2/tcp" - "github.com/openziti/ziti/controller/db" - "github.com/openziti/ziti/controller/models" "github.com/stretchr/testify/assert" ) func TestSimplePath2(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) assert.Nil(t, err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) assert.Nil(t, err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r1) - l0 := newTestLink("l0", r0, r1) - l0.SetState(Connected) - network.linkController.add(l0) + l0 := model.NewTestLink("l0", r0, r1) + l0.SetState(model.Connected) + network.Link.Add(l0) path, err := network.CreatePath(r0, r1) assert.NotNil(t, path) @@ -64,8 +62,8 @@ func TestSimplePath2(t *testing.T) { assert.Equal(t, l0, path.Links[0]) assert.Equal(t, r1, path.EgressRouter()) - terminator := &Terminator{Address: addr, Binding: "transport"} - routeMessages := path.CreateRouteMessages(0, "s0", terminator, time.Now().Add(DefaultOptionsRouteTimeout)) + terminator := &model.Terminator{Address: addr, Binding: "transport"} + routeMessages := network.CreateRouteMessages(path, 0, "s0", terminator, time.Now().Add(config2.DefaultOptionsRouteTimeout)) assert.NotNil(t, routeMessages) assert.Equal(t, 2, len(routeMessages)) @@ -92,35 +90,35 @@ func TestSimplePath2(t *testing.T) { } func TestTransitPath2(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) assert.Nil(t, err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) assert.Nil(t, err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r2) - l0 := newTestLink("l0", r0, r1) - l0.SetState(Connected) - network.linkController.add(l0) + l0 := model.NewTestLink("l0", r0, r1) + l0.SetState(model.Connected) + network.Link.Add(l0) - l1 := newTestLink("l1", r1, r2) - l1.SetState(Connected) - network.linkController.add(l1) + l1 := model.NewTestLink("l1", r1, r2) + l1.SetState(model.Connected) + network.Link.Add(l1) path, err := network.CreatePath(r0, r2) assert.NotNil(t, path) @@ -134,8 +132,8 @@ func TestTransitPath2(t *testing.T) { assert.Equal(t, l1, path.Links[1]) assert.Equal(t, r2, path.EgressRouter()) - terminator := &Terminator{Address: addr, Binding: "transport"} - routeMessages := path.CreateRouteMessages(0, "s0", terminator, time.Now().Add(DefaultOptionsRouteTimeout)) + terminator := &model.Terminator{Address: addr, Binding: "transport"} + routeMessages := network.CreateRouteMessages(path, 0, "s0", terminator, time.Now().Add(config2.DefaultOptionsRouteTimeout)) assert.NotNil(t, routeMessages) assert.Equal(t, 3, len(routeMessages)) @@ -171,23 +169,8 @@ func TestTransitPath2(t *testing.T) { assert.Equal(t, path.EgressId, rm2.Forwards[1].DstAddress) } -func newRouterForTest(id string, fingerprint string, advLstnr transport.Address, ctrl channel.Channel, cost uint16, noTraversal bool) *Router { - r := &Router{ - BaseEntity: models.BaseEntity{Id: id}, - Name: id, - Fingerprint: &fingerprint, - Control: ctrl, - Cost: cost, - NoTraversal: noTraversal, - } - if advLstnr != nil { - r.AddLinkListener(advLstnr.String(), advLstnr.Type(), []string{"Cost Tag"}, []string{"default"}) - } - return r -} - func TestShortestPath(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -195,52 +178,52 @@ func TestShortestPath(t *testing.T) { config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, false) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, false) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 3, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 3, false) + network.Router.MarkConnected(r2) - r3 := newRouterForTest("r3", "", transportAddr, nil, 4, false) - network.Routers.markConnected(r3) + r3 := model.NewRouterForTest("r3", "", transportAddr, nil, 4, false) + network.Router.MarkConnected(r3) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l1", r0, r2) + link = model.NewTestLink("l1", r0, r2) link.SetStaticCost(5) link.SetDstLatency(15 * 1_000_000) link.SetSrcLatency(16 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l2", r1, r3) + link = model.NewTestLink("l2", r1, r3) link.SetStaticCost(9) link.SetDstLatency(20 * 1_000_000) link.SetSrcLatency(21 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l3", r2, r3) + link = model.NewTestLink("l3", r2, r3) link.SetStaticCost(13) link.SetDstLatency(25 * 1_000_000) link.SetSrcLatency(26 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r3) req.NoError(err) @@ -255,7 +238,7 @@ func TestShortestPath(t *testing.T) { } func TestShortestPathWithUntraversableRouter(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -263,52 +246,52 @@ func TestShortestPathWithUntraversableRouter(t *testing.T) { config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, false) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, false) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, true) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, true) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 3, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 3, false) + network.Router.MarkConnected(r2) - r3 := newRouterForTest("r3", "", transportAddr, nil, 4, false) - network.Routers.markConnected(r3) + r3 := model.NewRouterForTest("r3", "", transportAddr, nil, 4, false) + network.Router.MarkConnected(r3) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l1", r0, r2) + link = model.NewTestLink("l1", r0, r2) link.SetStaticCost(5) link.SetDstLatency(15 * 1_000_000) link.SetSrcLatency(16 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l2", r1, r3) + link = model.NewTestLink("l2", r1, r3) link.SetStaticCost(9) link.SetDstLatency(20 * 1_000_000) link.SetSrcLatency(21 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l3", r2, r3) + link = model.NewTestLink("l3", r2, r3) link.SetStaticCost(13) link.SetDstLatency(25 * 1_000_000) link.SetSrcLatency(26 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r3) req.NoError(err) @@ -323,7 +306,7 @@ func TestShortestPathWithUntraversableRouter(t *testing.T) { } func TestShortestPathWithOnlyUntraversableRouter(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -331,25 +314,25 @@ func TestShortestPathWithOnlyUntraversableRouter(t *testing.T) { config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, false) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, false) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, true) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, true) + network.Router.MarkConnected(r1) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r1) req.NoError(err) @@ -363,7 +346,7 @@ func TestShortestPathWithOnlyUntraversableRouter(t *testing.T) { } func TestShortestPathWithUntraversableEdgeRouters(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -371,25 +354,25 @@ func TestShortestPathWithUntraversableEdgeRouters(t *testing.T) { config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, true) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, true) + network.Router.MarkConnected(r1) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(3) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r1) req.NoError(err) @@ -403,7 +386,7 @@ func TestShortestPathWithUntraversableEdgeRouters(t *testing.T) { } func TestShortestPathWithUntraversableEdgeRoutersAndTraversableMiddle(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -411,35 +394,35 @@ func TestShortestPathWithUntraversableEdgeRoutersAndTraversableMiddle(t *testing config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 3, true) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 3, true) + network.Router.MarkConnected(r2) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l1", r1, r2) + link = model.NewTestLink("l1", r1, r2) link.SetStaticCost(3) link.SetDstLatency(12 * 1_000_000) link.SetSrcLatency(15 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r2) req.NoError(err) @@ -455,7 +438,7 @@ func TestShortestPathWithUntraversableEdgeRoutersAndTraversableMiddle(t *testing } func TestShortestPathWithUntraversableEdgeRoutersAndUntraversableMiddle(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := assert.New(t) @@ -463,35 +446,35 @@ func TestShortestPathWithUntraversableEdgeRoutersAndUntraversableMiddle(t *testi config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 1, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 1, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 2, true) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 2, true) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 2, true) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 2, true) + network.Router.MarkConnected(r2) - link := newTestLink("l0", r0, r1) + link := model.NewTestLink("l0", r0, r1) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) - link = newTestLink("l2", r1, r2) + link = model.NewTestLink("l2", r1, r2) link.SetStaticCost(2) link.SetDstLatency(10 * 1_000_000) link.SetSrcLatency(11 * 1_000_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) path, cost, err := network.shortestPath(r0, r2) req.Error(err) @@ -502,7 +485,7 @@ func TestShortestPathWithUntraversableEdgeRoutersAndUntraversableMiddle(t *testi } func TestRouterCost(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := require.New(t) @@ -510,24 +493,24 @@ func TestRouterCost(t *testing.T) { config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 10, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 10, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 100, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 100, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 200, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 200, false) + network.Router.MarkConnected(r2) - r3 := newRouterForTest("r3", "", transportAddr, nil, 20, true) - network.Routers.markConnected(r3) + r3 := model.NewRouterForTest("r3", "", transportAddr, nil, 20, true) + network.Router.MarkConnected(r3) newPathTestLink(network, "l0", r0, r1) newPathTestLink(network, "l1", r0, r2) @@ -558,7 +541,7 @@ func TestRouterCost(t *testing.T) { } func TestMinRouterCost(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() req := require.New(t) @@ -567,24 +550,24 @@ func TestMinRouterCost(t *testing.T) { defer close(config.closeNotify) config.options.MinRouterCost = 10 - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) req.NoError(err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) req.NoError(err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 0, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 0, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 7, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 7, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 200, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 200, false) + network.Router.MarkConnected(r2) - r3 := newRouterForTest("r3", "", transportAddr, nil, 20, true) - network.Routers.markConnected(r3) + r3 := model.NewRouterForTest("r3", "", transportAddr, nil, 20, true) + network.Router.MarkConnected(r3) newPathTestLink(network, "l0", r0, r1) newPathTestLink(network, "l1", r0, r2) @@ -614,12 +597,12 @@ func TestMinRouterCost(t *testing.T) { req.Equal(int64(222), cost) } -func newPathTestLink(network *Network, id string, srcR, destR *Router) *Link { - l := newTestLink(id, srcR, destR) +func newPathTestLink(network *Network, id string, srcR, destR *model.Router) *model.Link { + l := model.NewTestLink(id, srcR, destR) l.SrcLatency = 0 l.DstLatency = 0 - l.recalculateCost() - l.SetState(Connected) - network.linkController.add(l) + l.RecalculateCost() + l.SetState(model.Connected) + network.Link.Add(l) return l } diff --git a/controller/network/route_perf_test.go b/controller/network/route_perf_test.go index 7f24b2eeb..b97b710e4 100644 --- a/controller/network/route_perf_test.go +++ b/controller/network/route_perf_test.go @@ -18,29 +18,29 @@ package network import ( "fmt" + "github.com/openziti/ziti/controller/model" "math/rand" "testing" "github.com/michaelquigley/pfxlog" - "github.com/openziti/ziti/controller/db" "github.com/sirupsen/logrus" ) func TestShortestPathAgainstEstablished(t *testing.T) { pfxlog.GlobalInit(logrus.WarnLevel, pfxlog.DefaultOptions()) - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) - var routers []*Router + var routers []*model.Router for i := 0; i < 50; i++ { router := entityHelper.addTestRouter() @@ -56,14 +56,14 @@ func TestShortestPathAgainstEstablished(t *testing.T) { return int64(v % 1000) } - addLink := func(srcRouter, dstRouter *Router) { + addLink := func(srcRouter, dstRouter *model.Router) { if srcRouter != dstRouter { - link := newTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) + link := model.NewTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) link.SetStaticCost(int32(nextCost())) link.SetDstLatency(nextCost() * 100_000) link.SetSrcLatency(nextCost() * 100_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) linkIdx++ } } @@ -153,18 +153,18 @@ func BenchmarkShortestPathPerfWithRouterChanges(b *testing.B) { b.StopTimer() pfxlog.GlobalInit(logrus.WarnLevel, pfxlog.DefaultOptions()) - ctx := db.NewTestContext(b) + ctx := model.NewTestContext(b) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) - var routers []*Router + var routers []*model.Router for i := 0; i < 50; i++ { router := entityHelper.addTestRouter() @@ -180,14 +180,14 @@ func BenchmarkShortestPathPerfWithRouterChanges(b *testing.B) { return int64(v % 1000) } - addLink := func(srcRouter, dstRouter *Router) { + addLink := func(srcRouter, dstRouter *model.Router) { if srcRouter != dstRouter { - link := newTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) + link := model.NewTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) link.SetStaticCost(int32(nextCost())) link.SetDstLatency(nextCost() * 100_000) link.SetSrcLatency(nextCost() * 100_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) linkIdx++ } } @@ -245,18 +245,18 @@ func BenchmarkShortestPathPerf(b *testing.B) { b.StopTimer() pfxlog.GlobalInit(logrus.WarnLevel, pfxlog.DefaultOptions()) - ctx := db.NewTestContext(b) + ctx := model.NewTestContext(b) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) - var routers []*Router + var routers []*model.Router for i := 0; i < 400; i++ { router := entityHelper.addTestRouter() @@ -272,14 +272,14 @@ func BenchmarkShortestPathPerf(b *testing.B) { return int64(v % 1000) } - addLink := func(srcRouter, dstRouter *Router) { + addLink := func(srcRouter, dstRouter *model.Router) { if srcRouter != dstRouter { - link := newTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) + link := model.NewTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) link.SetStaticCost(int32(nextCost())) link.SetDstLatency(nextCost() * 100_000) link.SetSrcLatency(nextCost() * 100_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) linkIdx++ } } @@ -317,18 +317,18 @@ func BenchmarkMoreRealisticShortestPathPerf(b *testing.B) { //b.StopTimer() pfxlog.GlobalInit(logrus.WarnLevel, pfxlog.DefaultOptions()) - ctx := db.NewTestContext(b) + ctx := model.NewTestContext(b) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) - var routers []*Router + var routers []*model.Router for i := 0; i < 200; i++ { router := entityHelper.addTestRouter() @@ -344,21 +344,21 @@ func BenchmarkMoreRealisticShortestPathPerf(b *testing.B) { return int64(v % 1000) } - addLink := func(srcRouter, dstRouter *Router) { + addLink := func(srcRouter, dstRouter *model.Router) { if srcRouter != dstRouter { - link := newTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) + link := model.NewTestLink(fmt.Sprintf("link-%04d", linkIdx), srcRouter, dstRouter) link.SetStaticCost(int32(nextCost())) link.SetDstLatency(nextCost() * 100_000) link.SetSrcLatency(nextCost() * 100_000) - link.SetState(Connected) - network.linkController.add(link) + link.SetState(model.Connected) + network.Link.Add(link) linkIdx++ } } // make half the routers private routers - var privateRouters []*Router - var publicRouters []*Router + var privateRouters []*model.Router + var publicRouters []*model.Router for idx, router := range routers { if idx <= len(routers)/2 { privateRouters = append(privateRouters, router) diff --git a/controller/network/router_messaging.go b/controller/network/router_messaging.go index f6c362940..88e2f40fb 100644 --- a/controller/network/router_messaging.go +++ b/controller/network/router_messaging.go @@ -27,6 +27,7 @@ import ( "github.com/openziti/ziti/common/pb/ctrl_pb" "github.com/openziti/ziti/controller/change" "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/xt" "sync/atomic" "time" @@ -57,22 +58,24 @@ type routerEvent interface { handle(c *RouterMessaging) } -func NewRouterMessaging(managers *Managers, routerCommPool goroutines.Pool) *RouterMessaging { +func NewRouterMessaging(env model.Env, routerCommPool goroutines.Pool) *RouterMessaging { result := &RouterMessaging{ - managers: managers, + env: env, + managers: env.GetManagers(), eventsC: make(chan routerEvent, 16), routerUpdates: map[string]*routerUpdates{}, terminatorValidations: map[string]*terminatorValidations{}, routerCommPool: routerCommPool, } - managers.stores.Terminator.AddEntityEventListenerF(result.TerminatorCreated, boltz.EntityCreated) + env.GetManagers().Terminator.GetStore().AddEntityEventListenerF(result.TerminatorCreated, boltz.EntityCreated) return result } type RouterMessaging struct { - managers *Managers + env model.Env + managers *model.Managers eventsC chan routerEvent routerUpdates map[string]*routerUpdates terminatorValidations map[string]*terminatorValidations @@ -88,11 +91,11 @@ func (self *RouterMessaging) getNextMarker() uint64 { return result } -func (self *RouterMessaging) RouterConnected(r *Router) { +func (self *RouterMessaging) RouterConnected(r *model.Router) { self.routerChanged(r.Id, true) } -func (self *RouterMessaging) RouterDisconnected(r *Router) { +func (self *RouterMessaging) RouterDisconnected(r *model.Router) { self.routerChanged(r.Id, false) } @@ -116,7 +119,7 @@ func (self *RouterMessaging) routerChanged(routerId string, connected bool) { func (self *RouterMessaging) queueEvent(evt routerEvent) { select { case self.eventsC <- evt: - case <-self.managers.network.GetCloseNotify(): + case <-self.env.GetCloseNotifyChannel(): } } @@ -129,7 +132,7 @@ func (self *RouterMessaging) run() { case evt := <-self.eventsC: evt.handle(self) case <-ticker.C: - case <-self.managers.network.GetCloseNotify(): + case <-self.env.GetCloseNotifyChannel(): return } @@ -180,7 +183,7 @@ func (self *RouterMessaging) syncStates() { notifyRouterId := k updates := v changes := &ctrl_pb.PeerStateChanges{} - notifyRouter := self.managers.Routers.getConnected(notifyRouterId) + notifyRouter := self.managers.Router.GetConnected(notifyRouterId) if notifyRouter == nil { // if the router disconnected, we're going to sync everything anyway, so clear anything pending here delete(self.routerUpdates, k) @@ -198,7 +201,7 @@ func (self *RouterMessaging) syncStates() { } for routerId := range updates.changedRouters { - router := self.managers.Routers.getConnected(routerId) + router := self.managers.Router.GetConnected(routerId) if router != nil { changes.Changes = append(changes.Changes, &ctrl_pb.PeerStateChange{ Id: routerId, @@ -207,7 +210,7 @@ func (self *RouterMessaging) syncStates() { Listeners: router.Listeners, }) } else { - exists, err := self.managers.Routers.Exists(routerId) + exists, err := self.managers.Router.Exists(routerId) if exists && err == nil { changes.Changes = append(changes.Changes, &ctrl_pb.PeerStateChange{ Id: routerId, @@ -267,7 +270,7 @@ func (self *RouterMessaging) sendTerminatorValidationRequests() { } func (self *RouterMessaging) sendTerminatorValidationRequest(routerId string, updates *terminatorValidations) { - notifyRouter := self.managers.Routers.getConnected(routerId) + notifyRouter := self.managers.Router.GetConnected(routerId) if notifyRouter == nil { // if the router disconnected, we're going to sync everything anyway, so clear anything pending here delete(self.terminatorValidations, routerId) @@ -343,7 +346,7 @@ func (self *RouterMessaging) sendTerminatorValidationRequest(routerId string, up } } -func (self *RouterMessaging) generateMockResponseForV1(r *Router, validations *terminatorValidations) { +func (self *RouterMessaging) generateMockResponseForV1(r *model.Router, validations *terminatorValidations) { handler := &terminatorValidationRespReceived{ router: r, changeCtx: change.New(), // won't be used since we're marking things valid @@ -362,7 +365,7 @@ func (self *RouterMessaging) generateMockResponseForV1(r *Router, validations *t self.queueEvent(handler) } -func (self *RouterMessaging) NewValidationResponseHandler(n *Network, r *Router) channel.ReceiveHandlerF { +func (self *RouterMessaging) NewValidationResponseHandler(n *Network, r *model.Router) channel.ReceiveHandlerF { return func(m *channel.Message, ch channel.Channel) { log := pfxlog.Logger().WithField("routerId", r.Id) resp := &ctrl_pb.ValidateTerminatorsV2Response{} @@ -383,7 +386,7 @@ func (self *RouterMessaging) NewValidationResponseHandler(n *Network, r *Router) } } -func (self *RouterMessaging) ValidateRouterTerminators(terminators []*Terminator) { +func (self *RouterMessaging) ValidateRouterTerminators(terminators []*model.Terminator) { self.queueEvent(&validateTerminators{ terminators: terminators, }) @@ -420,7 +423,7 @@ func (self *routerChangedEvent) handle(c *RouterMessaging) { WithField("connected", self.connected). Info("calculating router updates for router") - routers := c.managers.Routers.allConnected() + routers := c.managers.Router.AllConnected() var sourceRouterState *routerUpdates for _, router := range routers { @@ -468,7 +471,7 @@ func (self *routerPeerChangesSendDone) handle(c *RouterMessaging) { } type validateTerminators struct { - terminators []*Terminator + terminators []*model.Terminator } func (self *validateTerminators) handle(c *RouterMessaging) { @@ -491,7 +494,7 @@ func (self *validateTerminators) handle(c *RouterMessaging) { } type terminatorValidationRespReceived struct { - router *Router + router *model.Router changeCtx *change.Context resp *ctrl_pb.ValidateTerminatorsV2Response success bool @@ -511,7 +514,7 @@ func (self *terminatorValidationRespReceived) DeleteInvalid(n *Network) { } if len(toDelete) > 0 { - if err := n.Managers.Terminators.DeleteBatch(toDelete, self.changeCtx); err != nil { + if err := n.Managers.Terminator.DeleteBatch(toDelete, self.changeCtx); err != nil { for _, terminatorId := range toDelete { log.WithField("terminatorId", terminatorId). WithError(err). @@ -548,7 +551,7 @@ func (self *routerMessagingInspectEvent) handle(c *RouterMessaging) { result := &inspect.RouterMessagingState{} getRouterName := func(routerId string) string { - if router, _ := c.managers.Routers.Read(routerId); router != nil { + if router, _ := c.managers.Router.Read(routerId); router != nil { return router.Name } return "" diff --git a/controller/network/routesender.go b/controller/network/routesender.go index 84d34e9dc..da6425b38 100644 --- a/controller/network/routesender.go +++ b/controller/network/routesender.go @@ -19,14 +19,15 @@ package network import ( "fmt" "github.com/openziti/ziti/controller/change" + "github.com/openziti/ziti/controller/model" "time" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2/protobufs" - "github.com/openziti/ziti/controller/xt" "github.com/openziti/ziti/common/ctrl_msg" "github.com/openziti/ziti/common/logcontext" "github.com/openziti/ziti/common/pb/ctrl_pb" + "github.com/openziti/ziti/controller/xt" cmap "github.com/orcaman/concurrent-map/v2" "github.com/sirupsen/logrus" ) @@ -63,10 +64,10 @@ type routeSender struct { in chan *RouteStatus attendance map[string]bool serviceCounters ServiceCounters - terminators *TerminatorManager + terminators *model.TerminatorManager } -func newRouteSender(circuitId string, timeout time.Duration, serviceCounters ServiceCounters, terminators *TerminatorManager) *routeSender { +func newRouteSender(circuitId string, timeout time.Duration, serviceCounters ServiceCounters, terminators *model.TerminatorManager) *routeSender { return &routeSender{ circuitId: circuitId, timeout: timeout, @@ -77,7 +78,7 @@ func newRouteSender(circuitId string, timeout time.Duration, serviceCounters Ser } } -func (self *routeSender) route(attempt uint32, path *Path, routeMsgs []*ctrl_pb.Route, strategy xt.Strategy, terminator xt.Terminator, ctx logcontext.Context) (peerData xt.PeerData, cleanups map[string]struct{}, err CircuitError) { +func (self *routeSender) route(attempt uint32, path *model.Path, routeMsgs []*ctrl_pb.Route, strategy xt.Strategy, terminator xt.Terminator, ctx logcontext.Context) (peerData xt.PeerData, cleanups map[string]struct{}, err CircuitError) { logger := pfxlog.ChannelLogger(logcontext.EstablishPath).Wire(ctx) // send route messages @@ -127,7 +128,7 @@ attendance: return peerData, nil, nil } -func (self *routeSender) handleRouteSend(attempt uint32, path *Path, strategy xt.Strategy, status *RouteStatus, terminator xt.Terminator, logger *pfxlog.Builder) (peerData xt.PeerData, cleanups map[string]struct{}, err CircuitError) { +func (self *routeSender) handleRouteSend(attempt uint32, path *model.Path, strategy xt.Strategy, status *RouteStatus, terminator xt.Terminator, logger *pfxlog.Builder) (peerData xt.PeerData, cleanups map[string]struct{}, err CircuitError) { if status.Success == (status.ErrorCode != nil) { logger.Errorf("route status success and error code differ. Success: %v ErrorCode: %v", status.Success, status.ErrorCode) } @@ -180,7 +181,7 @@ func (self *routeSender) handleRouteSend(attempt uint32, path *Path, strategy xt failureCause = CircuitFailureRouterErrInvalidTerminator } else { self.serviceCounters.ServiceMisconfiguredTerminator(terminator.GetServiceId(), terminator.GetId()) - self.terminators.handlePrecedenceChange(terminator.GetId(), xt.Precedences.Failed) + self.terminators.HandlePrecedenceChange(terminator.GetId(), xt.Precedences.Failed) failureCause = CircuitFailureRouterErrMisconfiguredTerminator } case ctrl_msg.ErrorTypeDialTimedOut: @@ -208,7 +209,7 @@ func (self *routeSender) handleRouteSend(attempt uint32, path *Path, strategy xt return nil, nil, nil } -func (self *routeSender) sendRoute(r *Router, routeMsg *ctrl_pb.Route, ctx logcontext.Context) { +func (self *routeSender) sendRoute(r *model.Router, routeMsg *ctrl_pb.Route, ctx logcontext.Context) { logger := pfxlog.ChannelLogger(logcontext.EstablishPath).Wire(ctx).WithField("routerId", r.Id) envelope := protobufs.MarshalTyped(routeMsg).WithTimeout(3 * time.Second) @@ -219,7 +220,7 @@ func (self *routeSender) sendRoute(r *Router, routeMsg *ctrl_pb.Route, ctx logco } } -func (self *routeSender) cleanups(path *Path) map[string]struct{} { +func (self *routeSender) cleanups(path *model.Path) map[string]struct{} { cleanups := make(map[string]struct{}) for _, r := range path.Nodes { success, found := self.attendance[r.Id] @@ -231,7 +232,7 @@ func (self *routeSender) cleanups(path *Path) map[string]struct{} { } type RouteStatus struct { - Router *Router + Router *model.Router CircuitId string Attempt uint32 Success bool diff --git a/controller/network/routesender_test.go b/controller/network/routesender_test.go index a7b1dc1ec..6782bdace 100644 --- a/controller/network/routesender_test.go +++ b/controller/network/routesender_test.go @@ -1,23 +1,23 @@ package network import ( + "github.com/openziti/ziti/controller/model" "testing" "github.com/michaelquigley/pfxlog" - "github.com/openziti/ziti/controller/db" + "github.com/openziti/ziti/common/ctrl_msg" "github.com/openziti/ziti/controller/xt" "github.com/openziti/ziti/controller/xt_smartrouting" - "github.com/openziti/ziti/common/ctrl_msg" ) func TestRouteSender_DestroysTerminatorWhenInvalidOnHandleRouteSendAndWeControl(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) @@ -25,8 +25,8 @@ func TestRouteSender_DestroysTerminatorWhenInvalidOnHandleRouteSendAndWeControl( router1 := entityHelper.addTestRouter() router2 := entityHelper.addTestRouter() - path := &Path{ - Nodes: []*Router{router1, router2}, + path := &model.Path{ + Nodes: []*model.Router{router1, router2}, } svc := entityHelper.addTestService("svc") @@ -40,7 +40,7 @@ func TestRouteSender_DestroysTerminatorWhenInvalidOnHandleRouteSendAndWeControl( rs := routeSender{ serviceCounters: network, - terminators: network.Terminators, + terminators: network.Terminator, attendance: make(map[string]bool), } @@ -58,19 +58,19 @@ func TestRouteSender_DestroysTerminatorWhenInvalidOnHandleRouteSendAndWeControl( ctx.Nil(peerData) ctx.Empty(cleanup) - newTerm, err := network.Terminators.Read(term.Id) + newTerm, err := network.Terminator.Read(term.Id) ctx.Error(err) ctx.Nil(newTerm) } func TestRouteSender_SetPrecidenceToNilTerminatorWhenInvalidOnHandleRouteSendAndWeDontControl(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) ctx.NoError(err) entityHelper := newTestEntityHelper(ctx, network) @@ -78,8 +78,8 @@ func TestRouteSender_SetPrecidenceToNilTerminatorWhenInvalidOnHandleRouteSendAnd router1 := entityHelper.addTestRouter() router2 := entityHelper.addTestRouter() - path := &Path{ - Nodes: []*Router{router1, router2}, + path := &model.Path{ + Nodes: []*model.Router{router1, router2}, } svc := entityHelper.addTestService("svc") @@ -93,7 +93,7 @@ func TestRouteSender_SetPrecidenceToNilTerminatorWhenInvalidOnHandleRouteSendAnd rs := routeSender{ serviceCounters: network, - terminators: network.Terminators, + terminators: network.Terminator, attendance: make(map[string]bool), } @@ -111,7 +111,7 @@ func TestRouteSender_SetPrecidenceToNilTerminatorWhenInvalidOnHandleRouteSendAnd ctx.Nil(peerData) ctx.Empty(cleanup) - newTerm, err := network.Terminators.Read(term.Id) + newTerm, err := network.Terminator.Read(term.Id) ctx.NoError(err) ctx.NotNil(newTerm) diff --git a/controller/network/smart.go b/controller/network/smart.go index 5e31ef9e3..2838111e2 100644 --- a/controller/network/smart.go +++ b/controller/network/smart.go @@ -17,6 +17,8 @@ package network import ( + "github.com/openziti/ziti/controller/config" + "github.com/openziti/ziti/controller/model" log "github.com/sirupsen/logrus" "sort" "time" @@ -31,8 +33,8 @@ func (network *Network) smart() { candidates := network.getRerouteCandidates() for _, update := range candidates { - if retry := network.smartReroute(update.circuit, update.path, time.Now().Add(DefaultOptionsRouteTimeout)); retry { - go network.rerouteCircuitWithTries(update.circuit, DefaultOptionsCreateCircuitRetries) + if retry := network.smartReroute(update.circuit, update.path, time.Now().Add(config.DefaultOptionsRouteTimeout)); retry { + go network.rerouteCircuitWithTries(update.circuit, config.DefaultOptionsCreateCircuitRetries) } } } @@ -53,7 +55,7 @@ func (network *Network) getRerouteCandidates() []*newCircuitPath { circuitLatencies := make(map[string]int64) var orderedCircuits []string for _, s := range circuits { - circuitLatencies[s.Id] = s.cost(minRouterCost) + circuitLatencies[s.Id] = s.Path.Cost(minRouterCost) orderedCircuits = append(orderedCircuits, s.Id) } @@ -81,8 +83,8 @@ func (network *Network) getRerouteCandidates() []*newCircuitPath { if circuit, found := network.GetCircuit(sId); found { if updatedPath, err := network.UpdatePath(circuit.Path); err == nil { pathChanged := !updatedPath.EqualPath(circuit.Path) - oldCost := circuit.Path.cost(minRouterCost) - newCost := updatedPath.cost(minRouterCost) + oldCost := circuit.Path.Cost(minRouterCost) + newCost := updatedPath.Cost(minRouterCost) costDelta := oldCost - newCost log.Tracef("old cost: %v, new cost: %v, delta: %v", oldCost, newCost, costDelta) if count < ceiling && pathChanged && costDelta >= int64(network.options.Smart.MinCostDelta) { @@ -101,6 +103,6 @@ func (network *Network) getRerouteCandidates() []*newCircuitPath { } type newCircuitPath struct { - circuit *Circuit - path *Path + circuit *model.Circuit + path *model.Path } diff --git a/controller/network/smart_test.go b/controller/network/smart_test.go index 3e4922a7c..0b0a0f0b7 100644 --- a/controller/network/smart_test.go +++ b/controller/network/smart_test.go @@ -1,20 +1,20 @@ package network import ( + "github.com/openziti/ziti/controller/model" "testing" "time" "github.com/google/uuid" "github.com/openziti/transport/v2/tcp" "github.com/openziti/ziti/common/logcontext" - "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/models" "github.com/openziti/ziti/controller/xt" "github.com/stretchr/testify/assert" ) func TestSmartRerouteMinCostDelta(t *testing.T) { - ctx := db.NewTestContext(t) + ctx := model.NewTestContext(t) defer ctx.Cleanup() config := newTestConfig(ctx) @@ -22,31 +22,31 @@ func TestSmartRerouteMinCostDelta(t *testing.T) { config.options.Smart.MinCostDelta = 15 defer close(config.closeNotify) - network, err := NewNetwork(config) + network, err := NewNetwork(config, ctx) assert.Nil(t, err) addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) assert.Nil(t, err) - r0 := newRouterForTest("r0", "", transportAddr, nil, 0, true) - network.Routers.markConnected(r0) + r0 := model.NewRouterForTest("r0", "", transportAddr, nil, 0, true) + network.Router.MarkConnected(r0) - r1 := newRouterForTest("r1", "", transportAddr, nil, 15, false) - network.Routers.markConnected(r1) + r1 := model.NewRouterForTest("r1", "", transportAddr, nil, 15, false) + network.Router.MarkConnected(r1) - r2 := newRouterForTest("r2", "", transportAddr, nil, 0, false) - network.Routers.markConnected(r2) + r2 := model.NewRouterForTest("r2", "", transportAddr, nil, 0, false) + network.Router.MarkConnected(r2) - r3 := newRouterForTest("r3", "", transportAddr, nil, 0, true) - network.Routers.markConnected(r3) + r3 := model.NewRouterForTest("r3", "", transportAddr, nil, 0, true) + network.Router.MarkConnected(r3) newPathTestLink(network, "l0", r0, r1) link1 := newPathTestLink(network, "l1", r0, r2) newPathTestLink(network, "l2", r1, r3) newPathTestLink(network, "l3", r2, r3) - svc := &Service{ + svc := &model.Service{ BaseEntity: models.BaseEntity{Id: "svc"}, Name: "svc", TerminatorStrategy: "smartrouting", @@ -54,7 +54,7 @@ func TestSmartRerouteMinCostDelta(t *testing.T) { lc := logcontext.NewContext() - svc.Terminators = []*Terminator{ + svc.Terminators = []*model.Terminator{ { BaseEntity: models.BaseEntity{Id: "t0"}, Service: "svc", @@ -77,16 +77,16 @@ func TestSmartRerouteMinCostDelta(t *testing.T) { assert.Equal(t, "l1", path.Links[0].Id) assert.Equal(t, "l3", path.Links[1].Id) - assert.Equal(t, int64(32), path.cost(network.options.MinRouterCost)) + assert.Equal(t, int64(32), path.Cost(network.options.MinRouterCost)) - circuit := &Circuit{ + circuit := &model.Circuit{ Id: uuid.NewString(), ServiceId: svc.Id, Path: path, Terminator: terminator, CreatedAt: time.Now(), } - network.circuitController.add(circuit) + network.Circuit.Add(circuit) // r0 - r1 - r3 = 10 + 1 + 10 + 1 + 10 = 32 // r0 - r2 - r3 = 10 + 1 + 15 + 1 + 10 = 37 diff --git a/controller/network/util_test.go b/controller/network/util_test.go index d0c63c28d..dc6538464 100644 --- a/controller/network/util_test.go +++ b/controller/network/util_test.go @@ -19,6 +19,7 @@ package network import ( "fmt" "github.com/openziti/ziti/controller/change" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" "github.com/openziti/ziti/controller/xt_smartrouting" @@ -27,13 +28,13 @@ import ( "github.com/openziti/ziti/controller/db" ) -func newTestEntityHelper(ctx *db.TestContext, network *Network) *testEntityHelper { +func newTestEntityHelper(ctx *model.TestContext, network *Network) *testEntityHelper { addr := "tcp:0.0.0.0:0" transportAddr, err := tcp.AddressParser{}.Parse(addr) ctx.NoError(err) return &testEntityHelper{ - ctx: ctx, + ctx: ctx.TestContext, network: network, transportAddr: transportAddr, } @@ -48,17 +49,17 @@ type testEntityHelper struct { transportAddr transport.Address } -func (self *testEntityHelper) addTestRouter() *Router { - router := newRouterForTest(fmt.Sprintf("router-%03d", self.routerIdx), "", self.transportAddr, nil, 0, false) - self.network.Routers.markConnected(router) - self.ctx.NoError(self.network.Routers.Create(router, change.New())) +func (self *testEntityHelper) addTestRouter() *model.Router { + router := model.NewRouterForTest(fmt.Sprintf("router-%03d", self.routerIdx), "", self.transportAddr, nil, 0, false) + self.network.Router.MarkConnected(router) + self.ctx.NoError(self.network.Router.Create(router, change.New())) self.routerIdx++ return router } -func (self *testEntityHelper) addTestTerminator(serviceName string, routerName string, instanceId string, isSystem bool) *Terminator { +func (self *testEntityHelper) addTestTerminator(serviceName string, routerName string, instanceId string, isSystem bool) *model.Terminator { id := fmt.Sprintf("terminator-#%d", self.terminatorIdx) - term := &Terminator{ + term := &model.Terminator{ BaseEntity: models.BaseEntity{ Id: id, IsSystem: isSystem, @@ -68,20 +69,20 @@ func (self *testEntityHelper) addTestTerminator(serviceName string, routerName s InstanceId: instanceId, Address: "ToDo", } - self.ctx.NoError(self.network.Terminators.Create(term, change.New())) + self.ctx.NoError(self.network.Terminator.Create(term, change.New())) self.terminatorIdx++ return term } -func (self *testEntityHelper) addTestService(serviceName string) *Service { +func (self *testEntityHelper) addTestService(serviceName string) *model.Service { id := fmt.Sprintf("service-#%d", self.serviceIdx) - svc := &Service{ + svc := &model.Service{ BaseEntity: models.BaseEntity{Id: id}, Name: serviceName, TerminatorStrategy: xt_smartrouting.Name, } self.serviceIdx++ - self.ctx.NoError(self.network.Services.Create(svc, change.New())) + self.ctx.NoError(self.network.Service.Create(svc, change.New())) return svc } diff --git a/controller/oidc_auth/storage.go b/controller/oidc_auth/storage.go index 0a577e22c..abf2df4d1 100644 --- a/controller/oidc_auth/storage.go +++ b/controller/oidc_auth/storage.go @@ -147,7 +147,7 @@ func NewStorage(kid string, publicKey crypto.PublicKey, privateKey crypto.Privat // start will run Clean every 10 seconds func (s *HybridStorage) start() { s.startOnce.Do(func() { - closeNotify := s.env.GetHostController().GetCloseNotifyChannel() + closeNotify := s.env.GetCloseNotifyChannel() ticker := time.NewTicker(10 * time.Second) go func() { for { @@ -624,7 +624,7 @@ func (s *HybridStorage) SignatureAlgorithms(context.Context) ([]jose.SignatureAl // KeySet implements the op.Storage interface func (s *HybridStorage) KeySet(_ context.Context) ([]op.Key, error) { - signers := s.env.GetHostController().GetPeerSigners() + signers := s.env.GetPeerSigners() for _, cert := range signers { kid := fmt.Sprintf("%s", sha1.Sum(cert.Raw)) diff --git a/controller/raft/raft.go b/controller/raft/raft.go index fb57ea0ce..3381d602d 100644 --- a/controller/raft/raft.go +++ b/controller/raft/raft.go @@ -20,13 +20,13 @@ import ( "crypto/x509" "encoding/json" "fmt" - "github.com/hashicorp/go-hclog" "github.com/mitchellh/mapstructure" "github.com/openziti/foundation/v2/concurrenz" "github.com/openziti/foundation/v2/rate" "github.com/openziti/foundation/v2/versions" "github.com/openziti/transport/v2" "github.com/openziti/ziti/common/pb/cmd_pb" + "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/event" "github.com/openziti/ziti/controller/peermsg" "os" @@ -51,92 +51,6 @@ import ( "github.com/sirupsen/logrus" ) -type Config struct { - Recover bool - DataDir string - MinClusterSize uint32 - AdvertiseAddress transport.Address - BootstrapMembers []string - CommandHandlerOptions struct { - MaxQueueSize uint16 - } - - SnapshotInterval *time.Duration - SnapshotThreshold *uint32 - TrailingLogs *uint32 - MaxAppendEntries *uint32 - - ElectionTimeout *time.Duration - CommitTimeout *time.Duration - HeartbeatTimeout *time.Duration - LeaderLeaseTimeout *time.Duration - - LogLevel *string - Logger hclog.Logger -} - -func (self *Config) Configure(conf *raft.Config) { - if self.SnapshotThreshold != nil { - conf.SnapshotThreshold = uint64(*self.SnapshotThreshold) - } - - if self.SnapshotInterval != nil { - conf.SnapshotInterval = *self.SnapshotInterval - } - - if self.TrailingLogs != nil { - conf.TrailingLogs = uint64(*self.TrailingLogs) - } - - if self.MaxAppendEntries != nil { - conf.MaxAppendEntries = int(*self.MaxAppendEntries) - } - - if self.CommitTimeout != nil { - conf.CommitTimeout = *self.CommitTimeout - } - - if self.ElectionTimeout != nil { - conf.ElectionTimeout = *self.ElectionTimeout - } - - if self.HeartbeatTimeout != nil { - conf.HeartbeatTimeout = *self.HeartbeatTimeout - } - - if self.LeaderLeaseTimeout != nil { - conf.LeaderLeaseTimeout = *self.LeaderLeaseTimeout - } - - if self.LogLevel != nil { - conf.LogLevel = *self.LogLevel - } - - conf.Logger = self.Logger -} - -func (self *Config) ConfigureReloadable(conf *raft.ReloadableConfig) { - if self.SnapshotThreshold != nil { - conf.SnapshotThreshold = uint64(*self.SnapshotThreshold) - } - - if self.SnapshotInterval != nil { - conf.SnapshotInterval = *self.SnapshotInterval - } - - if self.TrailingLogs != nil { - conf.TrailingLogs = uint64(*self.TrailingLogs) - } - - if self.ElectionTimeout != nil { - conf.ElectionTimeout = *self.ElectionTimeout - } - - if self.HeartbeatTimeout != nil { - conf.HeartbeatTimeout = *self.HeartbeatTimeout - } -} - type RouterDispatchCallback func(*raft.Configuration) error type ClusterEvent uint32 @@ -195,7 +109,7 @@ type Env interface { GetId() *identity.TokenId GetVersionProvider() versions.VersionProvider GetCommandRateLimiterConfig() command.RateLimiterConfig - GetRaftConfig() *Config + GetRaftConfig() *config.RaftConfig GetMetricsRegistry() metrics.Registry GetEventDispatcher() event.Dispatcher GetCloseNotify() <-chan struct{} @@ -219,7 +133,7 @@ func NewController(env Env, migrationMgr MigrationManager) *Controller { // Controller manages RAFT related state and operations type Controller struct { env Env - Config *Config + Config *config.RaftConfig Mesh mesh.Mesh Raft *raft.Raft Fsm *BoltDbFsm @@ -577,7 +491,7 @@ func (self *Controller) Init() error { conf := raft.DefaultConfig() conf.LocalID = raft.ServerID(self.env.GetId().Token) conf.NoSnapshotRestoreOnStart = true - raftConfig.Configure(conf) + self.Configure(raftConfig, conf) // Create the log store and stable store. raftBoltFile := path.Join(raftConfig.DataDir, "raft.db") @@ -633,7 +547,7 @@ func (self *Controller) Init() error { } rc := r.ReloadableConfig() - raftConfig.ConfigureReloadable(&rc) + self.ConfigureReloadable(raftConfig, &rc) if err = r.ReloadConfig(rc); err != nil { return errors.Wrap(err, "error reloading raft configuration") } @@ -645,6 +559,68 @@ func (self *Controller) Init() error { return nil } +func (self *Controller) Configure(ctrlConfig *config.RaftConfig, conf *raft.Config) { + if ctrlConfig.SnapshotThreshold != nil { + conf.SnapshotThreshold = uint64(*ctrlConfig.SnapshotThreshold) + } + + if ctrlConfig.SnapshotInterval != nil { + conf.SnapshotInterval = *ctrlConfig.SnapshotInterval + } + + if ctrlConfig.TrailingLogs != nil { + conf.TrailingLogs = uint64(*ctrlConfig.TrailingLogs) + } + + if ctrlConfig.MaxAppendEntries != nil { + conf.MaxAppendEntries = int(*ctrlConfig.MaxAppendEntries) + } + + if ctrlConfig.CommitTimeout != nil { + conf.CommitTimeout = *ctrlConfig.CommitTimeout + } + + if ctrlConfig.ElectionTimeout != nil { + conf.ElectionTimeout = *ctrlConfig.ElectionTimeout + } + + if ctrlConfig.HeartbeatTimeout != nil { + conf.HeartbeatTimeout = *ctrlConfig.HeartbeatTimeout + } + + if ctrlConfig.LeaderLeaseTimeout != nil { + conf.LeaderLeaseTimeout = *ctrlConfig.LeaderLeaseTimeout + } + + if ctrlConfig.LogLevel != nil { + conf.LogLevel = *ctrlConfig.LogLevel + } + + conf.Logger = ctrlConfig.Logger +} + +func (self *Controller) ConfigureReloadable(ctrlConfig *config.RaftConfig, conf *raft.ReloadableConfig) { + if ctrlConfig.SnapshotThreshold != nil { + conf.SnapshotThreshold = uint64(*ctrlConfig.SnapshotThreshold) + } + + if ctrlConfig.SnapshotInterval != nil { + conf.SnapshotInterval = *ctrlConfig.SnapshotInterval + } + + if ctrlConfig.TrailingLogs != nil { + conf.TrailingLogs = uint64(*ctrlConfig.TrailingLogs) + } + + if ctrlConfig.ElectionTimeout != nil { + conf.ElectionTimeout = *ctrlConfig.ElectionTimeout + } + + if ctrlConfig.HeartbeatTimeout != nil { + conf.HeartbeatTimeout = *ctrlConfig.HeartbeatTimeout + } +} + func (self *Controller) validateCert() { var certs []*x509.Certificate for _, cert := range self.env.GetId().ServerCert() { diff --git a/controller/server/client-api.go b/controller/server/client-api.go index f29f8f449..4288b7ad7 100644 --- a/controller/server/client-api.go +++ b/controller/server/client-api.go @@ -43,6 +43,7 @@ type ClientApiFactory struct { func (factory ClientApiFactory) Validate(config *xweb.InstanceConfig) error { clientApiFound := false + edgeConfig := factory.appEnv.GetConfig().Edge for _, webListener := range config.ServerConfigs { for _, api := range webListener.APIs { @@ -53,12 +54,12 @@ func (factory ClientApiFactory) Validate(config *xweb.InstanceConfig) error { return errors.Errorf("could not read xweb web listener [%s]'s CA file [%s] to retrieve CA PEMs: %v", webListener.Name, webListener.Identity.GetConfig().CA, err) } - factory.appEnv.Config.AddCaPems(caBytes) + edgeConfig.AddCaPems(caBytes) } if !clientApiFound && api.Binding() == controller.ClientApiBinding { for _, bindPoint := range webListener.BindPoints { - if bindPoint.Address == factory.appEnv.Config.Api.Address { + if bindPoint.Address == edgeConfig.Api.Address { factory.appEnv.SetServerCert(webListener.Identity.ServerCert()[0]) clientApiFound = true break @@ -68,10 +69,10 @@ func (factory ClientApiFactory) Validate(config *xweb.InstanceConfig) error { } } - factory.appEnv.Config.RefreshCas() + edgeConfig.RefreshCas() if !clientApiFound { - return errors.Errorf("could not find [edge.api.address] value [%s] as a bind point any instance of ApiConfig [%s]", factory.appEnv.Config.Api.Address, controller.ClientApiBinding) + return errors.Errorf("could not find [edge.api.address] value [%s] as a bind point any instance of ApiConfig [%s]", edgeConfig.Api.Address, controller.ClientApiBinding) } return nil diff --git a/controller/server/controller.go b/controller/server/controller.go index 2bf3ffc04..7a565113b 100644 --- a/controller/server/controller.go +++ b/controller/server/controller.go @@ -21,7 +21,6 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" "github.com/openziti/storage/boltz" - "github.com/openziti/ziti/common/config" "github.com/openziti/ziti/common/pb/edge_ctrl_pb" runner2 "github.com/openziti/ziti/common/runner" "github.com/openziti/ziti/controller/api_impl" @@ -39,7 +38,7 @@ import ( ) type Controller struct { - config *edgeconfig.Config + config *edgeconfig.EdgeConfig AppEnv *env.AppEnv xmgmt *submgmt xctrl *subctrl @@ -58,20 +57,16 @@ const ( ZitiInstanceId = "ziti-instance-id" ) -func NewController(cfg config.Configurable, host env.HostController) (*Controller, error) { - c := &Controller{} - - if err := cfg.Configure(c); err != nil { - return nil, fmt.Errorf("failed to load configuration: %s", err) +func NewController(host env.HostController) (*Controller, error) { + c := &Controller{ + config: host.GetConfig().Edge, + AppEnv: host.GetEnv(), } if !c.IsEnabled() { return c, nil } - c.AppEnv = env.NewAppEnv(c.config, host) - - c.AppEnv.TraceManager = env.NewTraceManager(host.GetCloseNotifyChannel()) c.AppEnv.HostController.GetNetwork().AddCapability("ziti.edge") pfxlog.Logger().Infof("edge controller instance id: %s", c.AppEnv.InstanceId) @@ -187,7 +182,7 @@ func (c *Controller) LoadConfig(cfgmap map[interface{}]interface{}) error { return nil } - parsedConfig, err := edgeconfig.LoadFromMap(cfgmap) + parsedConfig, err := edgeconfig.LoadEdgeConfigFromMap(cfgmap) if err != nil { return fmt.Errorf("error loading edge controller configuration: %s", err.Error()) } @@ -204,7 +199,7 @@ func (c *Controller) Enabled() bool { func (c *Controller) initializeAuthModules() { c.initModulesOnce.Do(func() { c.AppEnv.AuthRegistry.Add(model.NewAuthModuleUpdb(c.AppEnv)) - c.AppEnv.AuthRegistry.Add(model.NewAuthModuleCert(c.AppEnv, c.AppEnv.GetConfig().CaPems())) + c.AppEnv.AuthRegistry.Add(model.NewAuthModuleCert(c.AppEnv, c.AppEnv.GetConfig().Edge.CaPems())) c.AppEnv.AuthRegistry.Add(model.NewAuthModuleExtJwt(c.AppEnv)) c.AppEnv.EnrollRegistry.Add(model.NewEnrollModuleCa(c.AppEnv)) diff --git a/controller/settings.go b/controller/settings.go index fe84954c2..38d3af55b 100644 --- a/controller/settings.go +++ b/controller/settings.go @@ -4,9 +4,10 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" "github.com/openziti/channel/v2/protobufs" - "github.com/openziti/ziti/controller/network" - "github.com/openziti/ziti/controller/raft" "github.com/openziti/ziti/common/pb/ctrl_pb" + config2 "github.com/openziti/ziti/controller/config" + "github.com/openziti/ziti/controller/model" + "github.com/openziti/ziti/controller/raft" "google.golang.org/protobuf/proto" ) @@ -14,15 +15,15 @@ import ( // Settings are a map of int32 -> []byte data. The type should be used to determine how the setting's []byte // array is consumed. type OnConnectSettingsHandler struct { - config *Config + config *config2.Config settings map[int32][]byte } -func (o *OnConnectSettingsHandler) RouterDisconnected(r *network.Router) { +func (o *OnConnectSettingsHandler) RouterDisconnected(r *model.Router) { //do nothing, satisfy interface } -func (o OnConnectSettingsHandler) RouterConnected(r *network.Router) { +func (o OnConnectSettingsHandler) RouterConnected(r *model.Router) { if len(o.settings) > 0 { settingsMsg := &ctrl_pb.Settings{ Data: map[int32][]byte{}, @@ -62,11 +63,11 @@ func NewOnConnectCtrlAddressesUpdateHandler(ctrlAddress string, raft *raft.Contr } } -func (o *OnConnectCtrlAddressesUpdateHandler) RouterDisconnected(r *network.Router) { +func (o *OnConnectCtrlAddressesUpdateHandler) RouterDisconnected(r *model.Router) { //do nothing, satisfy interface } -func (o OnConnectCtrlAddressesUpdateHandler) RouterConnected(r *network.Router) { +func (o OnConnectCtrlAddressesUpdateHandler) RouterConnected(r *model.Router) { log := pfxlog.Logger().WithFields(map[string]interface{}{ "routerId": r.Id, "channel": r.Control.LogicalName(), diff --git a/controller/subcmd/init.go b/controller/subcmd/init.go index 214a59379..74ff7ba64 100644 --- a/controller/subcmd/init.go +++ b/controller/subcmd/init.go @@ -19,10 +19,11 @@ package subcmd import ( "errors" "github.com/michaelquigley/pfxlog" - "github.com/openziti/ziti/controller/server" - "github.com/openziti/ziti/controller" "github.com/openziti/foundation/v2/term" "github.com/openziti/foundation/v2/versions" + "github.com/openziti/ziti/controller" + "github.com/openziti/ziti/controller/config" + "github.com/openziti/ziti/controller/server" "github.com/spf13/cobra" "strconv" "strings" @@ -143,7 +144,7 @@ func validatePasswordLength(password string) error { } func configureController(configPath string, versionProvider versions.VersionProvider) *server.Controller { - config, err := controller.LoadConfig(configPath) + config, err := config.LoadConfig(configPath) if err != nil { pfxlog.Logger().WithError(err).Fatalf("could not read configuration file [%s]", configPath) @@ -154,7 +155,7 @@ func configureController(configPath string, versionProvider versions.VersionProv panic(err) } - edgeController, err := server.NewController(config, fabricController) + edgeController, err := server.NewController(fabricController) if err != nil { panic(err) diff --git a/controller/sync_strats/marshal.go b/controller/sync_strats/marshal.go index c52b1776b..7d6857a68 100644 --- a/controller/sync_strats/marshal.go +++ b/controller/sync_strats/marshal.go @@ -25,7 +25,7 @@ import ( func apiSessionToProto(ae *env.AppEnv, token, identityId, apiSessionId string) (*edge_ctrl_pb.ApiSession, error) { var result *edge_ctrl_pb.ApiSession - err := ae.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := ae.GetDb().View(func(tx *bbolt.Tx) error { var err error result, err = apiSessionToProtoWithTx(tx, ae, token, identityId, apiSessionId) return err diff --git a/controller/sync_strats/rtx.go b/controller/sync_strats/rtx.go index 9074bf887..b8b2b318f 100644 --- a/controller/sync_strats/rtx.go +++ b/controller/sync_strats/rtx.go @@ -22,7 +22,6 @@ import ( "github.com/openziti/ziti/common/eid" "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/network" cmap "github.com/orcaman/concurrent-map/v2" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -36,7 +35,7 @@ type RouterSender struct { env.RouterState Id string EdgeRouter *model.EdgeRouter - Router *network.Router + Router *model.Router send chan *channel.Message closeNotify chan struct{} running atomic.Bool @@ -47,7 +46,7 @@ type RouterSender struct { sync.Mutex } -func newRouterSender(edgeRouter *model.EdgeRouter, router *network.Router, sendBufferSize int) *RouterSender { +func newRouterSender(edgeRouter *model.EdgeRouter, router *model.Router, sendBufferSize int) *RouterSender { rtx := &RouterSender{ Id: eid.New(), EdgeRouter: edgeRouter, @@ -139,7 +138,7 @@ func (m *routerTxMap) GetState(id string) env.RouterStateValues { return rtx.GetState() } -func (m *routerTxMap) Remove(r *network.Router) { +func (m *routerTxMap) Remove(r *model.Router) { var rtx *RouterSender m.internalMap.RemoveCb(r.Id, func(key string, v *RouterSender, exists bool) bool { if !exists { diff --git a/controller/sync_strats/sync_instant.go b/controller/sync_strats/sync_instant.go index 9b06503be..da78fc1e8 100644 --- a/controller/sync_strats/sync_instant.go +++ b/controller/sync_strats/sync_instant.go @@ -39,7 +39,6 @@ import ( "github.com/openziti/ziti/controller/env" "github.com/openziti/ziti/controller/handler_edge_ctrl" "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/network" cmap "github.com/orcaman/concurrent-map/v2" "github.com/pkg/errors" "go.etcd.io/bbolt" @@ -212,7 +211,7 @@ func (strategy *InstantStrategy) Initialize(logSize uint64, bufferSize uint) err updateHandler: strategy.ControllerUpdate, } - strategy.ae.GetDbProvider().GetDb().AddTxCompleteListener(strategy.completeChangeSet) + strategy.ae.GetDb().AddTxCompleteListener(strategy.completeChangeSet) return nil } @@ -291,7 +290,7 @@ func (strategy *InstantStrategy) Stop() { } } -func (strategy *InstantStrategy) RouterConnected(edgeRouter *model.EdgeRouter, router *network.Router) { +func (strategy *InstantStrategy) RouterConnected(edgeRouter *model.EdgeRouter, router *model.Router) { log := pfxlog.Logger().WithField("sync_strategy", strategy.Type()). WithField("routerId", router.Id). WithField("routerName", router.Name). @@ -322,7 +321,7 @@ func (strategy *InstantStrategy) RouterConnected(edgeRouter *model.EdgeRouter, r strategy.routerConnectedQueue <- rtx } -func (strategy *InstantStrategy) RouterDisconnected(router *network.Router) { +func (strategy *InstantStrategy) RouterDisconnected(router *model.Router) { log := pfxlog.Logger().WithField("sync_strategy", strategy.Type()). WithField("routerId", router.Id). WithField("routerName", router.Name). @@ -649,7 +648,7 @@ func (strategy *InstantStrategy) synchronize(rtx *RouterSender) { logger.Info("started synchronizing edge router") chunkSize := 100 - err := strategy.ae.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := strategy.ae.GetDb().View(func(tx *bbolt.Tx) error { var apiSessions []*edge_ctrl_pb.ApiSession state := &InstantSyncState{ @@ -906,7 +905,7 @@ func (strategy *InstantStrategy) BuildPublicKeys(tx *bbolt.Tx) error { strategy.HandlePublicKeyEvent(newEvent, newModel) } - caPEMs := strategy.ae.Config.CaPems() + caPEMs := strategy.ae.GetConfig().Edge.CaPems() caCerts := nfPem.PemBytesToCertificates(caPEMs) for _, caCert := range caCerts { @@ -950,7 +949,7 @@ func (strategy *InstantStrategy) BuildPublicKeys(tx *bbolt.Tx) error { } func (strategy *InstantStrategy) BuildAll() error { - err := strategy.ae.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { + err := strategy.ae.GetDb().View(func(tx *bbolt.Tx) error { if err := strategy.BuildIdentities(tx); err != nil { return err } @@ -1498,7 +1497,7 @@ func (p *NonHaIndexProvider) load() { defer p.lock.Unlock() ctx := boltz.NewMutateContext(context.Background()) - err := p.ae.GetDbProvider().GetDb().Update(ctx, func(ctx boltz.MutateContext) error { + err := p.ae.GetDb().Update(ctx, func(ctx boltz.MutateContext) error { zdb, err := ctx.Tx().CreateBucketIfNotExists([]byte(ZdbKey)) if err != nil { @@ -1535,7 +1534,7 @@ func (p *NonHaIndexProvider) NextIndex(ctx boltz.MutateContext) (uint64, error) } updateCtx := boltz.NewMutateContext(context.Background()) - err := p.ae.GetDbProvider().GetDb().Update(updateCtx, func(updateCtx boltz.MutateContext) error { + err := p.ae.GetDb().Update(updateCtx, func(updateCtx boltz.MutateContext) error { zdb := updateCtx.Tx().Bucket([]byte(ZdbKey)) newIndex := p.index + 1 diff --git a/tests/ca_traffic_test.go b/tests/ca_traffic_test.go index f6e24248f..f62200f24 100644 --- a/tests/ca_traffic_test.go +++ b/tests/ca_traffic_test.go @@ -282,7 +282,7 @@ func Test_CA_Auth_Two_Identities_Diff_Certs(t *testing.T) { ServerCert: "", ServerKey: "", AltServerCerts: nil, - CA: id.StoragePem + ":" + string(ctx.EdgeController.AppEnv.Config.CaPems()), + CA: id.StoragePem + ":" + string(ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems()), }, ConfigTypes: nil, } @@ -318,7 +318,7 @@ func Test_CA_Auth_Two_Identities_Diff_Certs(t *testing.T) { ServerCert: "", ServerKey: "", AltServerCerts: nil, - CA: id.StoragePem + ":" + string(ctx.EdgeController.AppEnv.Config.CaPems()), + CA: id.StoragePem + ":" + string(ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems()), }, ConfigTypes: nil, } diff --git a/tests/context.go b/tests/context.go index 8036e62d9..96f4b8ce4 100644 --- a/tests/context.go +++ b/tests/context.go @@ -17,6 +17,7 @@ package tests import ( + "github.com/openziti/ziti/controller/config" "io" "net" "net/http" @@ -133,10 +134,10 @@ type TestContext struct { edgeRouterEntity *edgeRouter transitRouterEntity *transitRouter - router *router.Router + routers []*router.Router testing *testing.T LogLevel string - ControllerConfig *controller.Config + ControllerConfig *config.Config } var defaultTestContext = &TestContext{ @@ -364,7 +365,7 @@ func (ctx *TestContext) StartServerFor(testDb string, clean bool) { ctx.Req.NoError(err) log.Info("loading config") - config, err := controller.LoadConfig(ControllerConfFile) + config, err := config.LoadConfig(ControllerConfFile) ctx.Req.NoError(err) ctx.ControllerConfig = config @@ -374,7 +375,7 @@ func (ctx *TestContext) StartServerFor(testDb string, clean bool) { ctx.Req.NoError(err) log.Info("creating edge controller") - ctx.EdgeController, err = server.NewController(config, ctx.fabricController) + ctx.EdgeController, err = server.NewController(ctx.fabricController) ctx.Req.NoError(err) ctx.EdgeController.Initialize() @@ -469,30 +470,24 @@ func (ctx *TestContext) createEnrollAndStartTransitRouter() { func (ctx *TestContext) startTransitRouter() { config, err := router.LoadConfig(TransitRouterConfFile) ctx.Req.NoError(err) - ctx.router = router.Create(config, NewVersionProviderTest()) + newRouter := router.Create(config, NewVersionProviderTest()) + ctx.routers = append(ctx.routers, newRouter) - ctx.Req.NoError(ctx.router.Start()) + ctx.Req.NoError(newRouter.Start()) } func (ctx *TestContext) CreateEnrollAndStartTunnelerEdgeRouter(roleAttributes ...string) { - ctx.shutdownRouter() + ctx.shutdownRouters() ctx.createAndEnrollEdgeRouter(true, roleAttributes...) ctx.startEdgeRouter() } func (ctx *TestContext) CreateEnrollAndStartEdgeRouter(roleAttributes ...string) { - ctx.shutdownRouter() + ctx.shutdownRouters() ctx.createAndEnrollEdgeRouter(false, roleAttributes...) ctx.startEdgeRouter() } -func (ctx *TestContext) shutdownRouter() { - if ctx.router != nil { - ctx.Req.NoError(ctx.router.Shutdown()) - ctx.router = nil - } -} - func (ctx *TestContext) startEdgeRouter() { configFile := EdgeRouterConfFile if ctx.edgeRouterEntity.isTunnelerEnabled { @@ -500,18 +495,19 @@ func (ctx *TestContext) startEdgeRouter() { } config, err := router.LoadConfig(configFile) ctx.Req.NoError(err) - ctx.router = router.Create(config, NewVersionProviderTest()) + newRouter := router.Create(config, NewVersionProviderTest()) + ctx.routers = append(ctx.routers, newRouter) - xgressEdgeFactory := xgress_edge.NewFactory(config, ctx.router, ctx.router.GetStateManager()) + xgressEdgeFactory := xgress_edge.NewFactory(config, newRouter, newRouter.GetStateManager()) xgress.GlobalRegistry().Register(common.EdgeBinding, xgressEdgeFactory) - xgressEdgeTunnelFactory := xgress_edge_tunnel.NewFactory(ctx.router, config, ctx.router.GetStateManager()) + xgressEdgeTunnelFactory := xgress_edge_tunnel.NewFactory(newRouter, config, newRouter.GetStateManager()) xgress.GlobalRegistry().Register(common.TunnelBinding, xgressEdgeTunnelFactory) - ctx.Req.NoError(ctx.router.RegisterXrctrl(xgressEdgeFactory)) - ctx.Req.NoError(ctx.router.RegisterXrctrl(xgressEdgeTunnelFactory)) - ctx.Req.NoError(ctx.router.RegisterXrctrl(ctx.router.GetStateManager())) - ctx.Req.NoError(ctx.router.Start()) + ctx.Req.NoError(newRouter.RegisterXrctrl(xgressEdgeFactory)) + ctx.Req.NoError(newRouter.RegisterXrctrl(xgressEdgeTunnelFactory)) + ctx.Req.NoError(newRouter.RegisterXrctrl(newRouter.GetStateManager())) + ctx.Req.NoError(newRouter.Start()) } func (ctx *TestContext) EnrollIdentity(identityId string) *ziti.Config { @@ -565,7 +561,7 @@ func (ctx *TestContext) RequireAdminClientApiLogin() { func (ctx *TestContext) Teardown() { pfxlog.Logger().Info("tearing down test context") - ctx.shutdownRouter() + ctx.shutdownRouters() if ctx.EdgeController != nil { ctx.EdgeController.Shutdown() ctx.EdgeController = nil @@ -880,6 +876,13 @@ func (ctx *TestContext) WrapConn(conn edge.Conn, err error) *TestConn { } } +func (ctx *TestContext) shutdownRouters() { + for _, r := range ctx.routers { + ctx.Req.NoError(r.Shutdown()) + } + ctx.routers = nil +} + type TestConn struct { edge.Conn ctx *TestContext diff --git a/tests/control.go b/tests/control.go index 9ba51b3e4..c4108e16e 100644 --- a/tests/control.go +++ b/tests/control.go @@ -6,12 +6,12 @@ import ( "github.com/openziti/transport/v2" "github.com/openziti/ziti/common/capabilities" "github.com/openziti/ziti/common/pb/ctrl_pb" - "github.com/openziti/ziti/controller" + "github.com/openziti/ziti/controller/config" "math/big" ) func (ctx *FabricTestContext) NewControlChannelListener() channel.UnderlayListener { - config, err := controller.LoadConfig(FabricControllerConfFile) + config, err := config.LoadConfig(FabricControllerConfFile) ctx.Req.NoError(err) ctx.Req.NoError(config.Db.Close()) diff --git a/tests/data_flow_test.go b/tests/data_flow_test.go index 68d4bf87f..749fe7096 100644 --- a/tests/data_flow_test.go +++ b/tests/data_flow_test.go @@ -1,5 +1,4 @@ //go:build dataflow -// +build dataflow /* Copyright NetFoundry Inc. diff --git a/tests/enrollment_router_test.go b/tests/enrollment_router_test.go index 839dce44f..553e44326 100644 --- a/tests/enrollment_router_test.go +++ b/tests/enrollment_router_test.go @@ -73,7 +73,7 @@ func Test_RouterEnrollment(t *testing.T) { ctx.Req.NotNil(cert) ctx.Req.NotNil(pk) - caPems := ctx.EdgeController.AppEnv.Config.CaPems() + caPems := ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems() caCerts, err := parsePEMBundle(caPems) ctx.Req.NoError(err) @@ -97,7 +97,7 @@ func Test_RouterEnrollment(t *testing.T) { ctx.Req.NotNil(cert) ctx.Req.NotNil(pk) - caPems := ctx.EdgeController.AppEnv.Config.CaPems() + caPems := ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems() caCerts, err := parsePEMBundle(caPems) ctx.Req.NoError(err) @@ -324,7 +324,7 @@ func Test_RouterEnrollment(t *testing.T) { ctx.Req.NotNil(cert) ctx.Req.NotNil(pk) - caPems := ctx.EdgeController.AppEnv.Config.CaPems() + caPems := ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems() caCerts, err := parsePEMBundle(caPems) ctx.Req.NoError(err) @@ -348,7 +348,7 @@ func Test_RouterEnrollment(t *testing.T) { ctx.Req.NotNil(cert) ctx.Req.NotNil(pk) - caPems := ctx.EdgeController.AppEnv.Config.CaPems() + caPems := ctx.EdgeController.AppEnv.GetConfig().Edge.CaPems() caCerts, err := parsePEMBundle(caPems) ctx.Req.NoError(err) diff --git a/tests/fabric_context.go b/tests/fabric_context.go index 4327ea001..c45810456 100644 --- a/tests/fabric_context.go +++ b/tests/fabric_context.go @@ -28,6 +28,7 @@ import ( id "github.com/openziti/identity" "github.com/openziti/identity/certtools" "github.com/openziti/ziti/controller/api_impl" + "github.com/openziti/ziti/controller/config" "github.com/openziti/ziti/controller/rest_client" restClientRouter "github.com/openziti/ziti/controller/rest_client/router" "github.com/openziti/ziti/controller/rest_model" @@ -77,7 +78,7 @@ type FabricTestContext struct { routers []*router.Router testing *testing.T LogLevel string - ControllerConfig *controller.Config + ControllerConfig *config.Config } func NewFabricTestContext(t *testing.T) *FabricTestContext { @@ -179,7 +180,7 @@ func (ctx *FabricTestContext) StartServerFor(test string, clean bool) { ctx.Req.NoError(err) log.Info("loading config") - config, err := controller.LoadConfig(FabricControllerConfFile) + config, err := config.LoadConfig(FabricControllerConfFile) ctx.Req.NoError(err) ctx.ControllerConfig = config diff --git a/tests/mfa_ziti_test.go b/tests/mfa_ziti_test.go index c834b9101..4b7f63784 100644 --- a/tests/mfa_ziti_test.go +++ b/tests/mfa_ziti_test.go @@ -25,8 +25,8 @@ import ( "github.com/Jeffail/gabs" "github.com/dgryski/dgoogauth" "github.com/google/uuid" - "github.com/openziti/ziti/controller/apierror" "github.com/openziti/foundation/v2/errorz" + "github.com/openziti/ziti/controller/apierror" "image/png" "net/http" "net/url" @@ -213,7 +213,7 @@ func Test_MFA(t *testing.T) { ctx.Req.NoError(err) ctx.Req.Equal(mfaUrl.Host, "totp") - ctx.Req.Equal(mfaUrl.Path, "/"+ctx.EdgeController.AppEnv.Config.Totp.Hostname+":"+mfaStartedIdentityName) + ctx.Req.Equal(mfaUrl.Path, "/"+ctx.EdgeController.AppEnv.GetConfig().Edge.Totp.Hostname+":"+mfaStartedIdentityName) ctx.Req.Equal(mfaUrl.Scheme, "otpauth") }) }) diff --git a/tests/transit_router_test.go b/tests/transit_router_test.go index 0b4bb5373..947c0ea2e 100644 --- a/tests/transit_router_test.go +++ b/tests/transit_router_test.go @@ -21,8 +21,8 @@ package tests import ( "github.com/openziti/ziti/controller/change" + "github.com/openziti/ziti/controller/model" "github.com/openziti/ziti/controller/models" - "github.com/openziti/ziti/controller/network" "testing" "time" ) @@ -41,8 +41,7 @@ func Test_TransitRouters(t *testing.T) { t.Run("transit routers can be created, enrolled, and started", func(t *testing.T) { ctx.testContextChanged(t) ctx.createEnrollAndStartTransitRouter() - - ctx.Req.NoError(ctx.router.Shutdown()) + ctx.shutdownRouters() }) t.Run("transit routers can be created, enrolled, and listed", func(t *testing.T) { @@ -106,7 +105,7 @@ func Test_TransitRouters(t *testing.T) { ctx.testContextChanged(t) fp := "f6fc1c03175f674f1f0b505a9ff930e5" - fabTxRouter := &network.Router{ + fabTxRouter := &model.Router{ BaseEntity: models.BaseEntity{ Id: "uMvqq", CreatedAt: time.Now(), @@ -116,7 +115,7 @@ func Test_TransitRouters(t *testing.T) { Name: "uMvqq", Fingerprint: &fp, } - err := ctx.fabricController.GetNetwork().Routers.Create(fabTxRouter, change.New()) + err := ctx.fabricController.GetNetwork().Router.Create(fabTxRouter, change.New()) ctx.Req.NoError(err, "could not create router at fabric level") body := ctx.AdminManagementSession.requireQuery("transit-routers") diff --git a/ziti/cmd/create/create_config.go b/ziti/cmd/create/create_config.go index 0a688a24a..aedb2de3e 100644 --- a/ziti/cmd/create/create_config.go +++ b/ziti/cmd/create/create_config.go @@ -26,7 +26,6 @@ import ( "github.com/openziti/channel/v2" foundation "github.com/openziti/transport/v2" fabXweb "github.com/openziti/xweb/v2" - fabCtrl "github.com/openziti/ziti/controller" edge "github.com/openziti/ziti/controller/config" fabForwarder "github.com/openziti/ziti/router/forwarder" "github.com/sirupsen/logrus" @@ -51,7 +50,7 @@ type CreateConfigOptions struct { } type ConfigTemplateValues struct { - ZitiHome string + ZitiHome string HostnameOrNetworkName string Controller ControllerTemplateValues @@ -256,9 +255,9 @@ func (data *ConfigTemplateValues) PopulateConfigValues() { data.Controller.Ctrl.AdvertisedPort = cmdHelper.GetCtrlAdvertisedPort() data.Controller.Database.DatabaseFile = cmdHelper.GetCtrlDatabaseFile() // healthChecks: - data.Controller.HealthChecks.Interval = fabCtrl.DefaultHealthChecksBoltCheckInterval - data.Controller.HealthChecks.Timeout = fabCtrl.DefaultHealthChecksBoltCheckTimeout - data.Controller.HealthChecks.InitialDelay = fabCtrl.DefaultHealthChecksBoltCheckInitialDelay + data.Controller.HealthChecks.Interval = edge.DefaultHealthChecksBoltCheckInterval + data.Controller.HealthChecks.Timeout = edge.DefaultHealthChecksBoltCheckTimeout + data.Controller.HealthChecks.InitialDelay = edge.DefaultHealthChecksBoltCheckInitialDelay // edge: data.Controller.EdgeApi.APIActivityUpdateBatchSize = edge.DefaultEdgeApiActivityUpdateBatchSize data.Controller.EdgeApi.APIActivityUpdateInterval = edge.DefaultEdgeAPIActivityUpdateInterval diff --git a/ziti/cmd/database/add_debug_admin.go b/ziti/cmd/database/add_debug_admin.go index d69fdc587..017d2c356 100644 --- a/ziti/cmd/database/add_debug_admin.go +++ b/ziti/cmd/database/add_debug_admin.go @@ -25,7 +25,6 @@ import ( "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/db" "github.com/openziti/ziti/controller/model" - "github.com/openziti/ziti/controller/network" "github.com/spf13/cobra" ) @@ -42,9 +41,8 @@ func NewAddDebugAdminAction() *cobra.Command { } type addDebugAdminAction struct { - db boltz.Db - stores *db.Stores - managers *network.Managers + db boltz.Db + stores *db.Stores } func (action *addDebugAdminAction) GetDb() boltz.Db { @@ -55,10 +53,6 @@ func (action *addDebugAdminAction) GetStores() *db.Stores { return action.stores } -func (action *addDebugAdminAction) GetManagers() *network.Managers { - return action.managers -} - func (action *addDebugAdminAction) noError(err error) { if err != nil { panic(err) @@ -69,23 +63,14 @@ func (action *addDebugAdminAction) run(dbFile, username, password string) { boltDb, err := db.Open(dbFile) action.noError(err) - fabricStores, err := db.InitStores(boltDb, command.NoOpRateLimiter{}) + stores, err := db.InitStores(boltDb, command.NoOpRateLimiter{}) action.noError(err) - dispatcher := &command.LocalDispatcher{ - EncodeDecodeCommands: false, - } - controllers := network.NewManagers(nil, dispatcher, boltDb, fabricStores, nil) - dbProvider := &addDebugAdminAction{ - db: boltDb, - stores: fabricStores, - managers: controllers, + db: boltDb, + stores: stores, } - stores, err := db.InitStores(boltDb, command.NoOpRateLimiter{}) - action.noError(err) - id := "debug-admin" name := fmt.Sprintf("debug admin (%v)", uuid.NewString()) ctx := change.New().SetChangeAuthorType("cli.debug-db").NewMutateContext() diff --git a/ziti/controller/delete_sessions.go b/ziti/controller/delete_sessions.go index 0cb844eeb..b9f353eea 100644 --- a/ziti/controller/delete_sessions.go +++ b/ziti/controller/delete_sessions.go @@ -20,7 +20,7 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/storage/boltz" "github.com/openziti/ziti/common/version" - "github.com/openziti/ziti/controller" + "github.com/openziti/ziti/controller/config" fabricdb "github.com/openziti/ziti/controller/db" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -56,7 +56,7 @@ const ( ) func deleteSessionsFromConfig(_ *cobra.Command, args []string) { - if config, err := controller.LoadConfig(args[0]); err == nil { + if config, err := config.LoadConfig(args[0]); err == nil { deleteSessions(config.Db) } else { panic(err) diff --git a/ziti/controller/run.go b/ziti/controller/run.go index f901254d5..5b9838722 100644 --- a/ziti/controller/run.go +++ b/ziti/controller/run.go @@ -18,6 +18,7 @@ package controller import ( "fmt" + "github.com/openziti/ziti/controller/config" "net" "os" "os/signal" @@ -25,9 +26,9 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/agent" - "github.com/openziti/ziti/controller/server" - "github.com/openziti/ziti/controller" "github.com/openziti/ziti/common/version" + "github.com/openziti/ziti/controller" + "github.com/openziti/ziti/controller/server" "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -50,7 +51,7 @@ func run(cmd *cobra.Command, args []string) { WithField("build-date", version.GetBuildDate()). WithField("revision", version.GetRevision()) - config, err := controller.LoadConfig(args[0]) + config, err := config.LoadConfig(args[0]) if err != nil { startLogger.WithError(err).Error("error starting ziti-controller") panic(err) @@ -65,7 +66,7 @@ func run(cmd *cobra.Command, args []string) { panic(err) } - edgeController, err := server.NewController(config, fabricController) + edgeController, err := server.NewController(fabricController) if err != nil { panic(err) diff --git a/zititest/zitilab/models/db_builder.go b/zititest/zitilab/models/db_builder.go index d8c3ca858..8a64d6a38 100644 --- a/zititest/zitilab/models/db_builder.go +++ b/zititest/zitilab/models/db_builder.go @@ -6,7 +6,6 @@ import ( "github.com/openziti/storage/boltz" "github.com/openziti/ziti/controller/command" "github.com/openziti/ziti/controller/db" - "github.com/openziti/ziti/controller/network" "github.com/openziti/ziti/zititest/zitilab" "github.com/pkg/errors" "go.etcd.io/bbolt" @@ -37,10 +36,6 @@ func (self *ZitiDbBuilder) GetStores() *db.Stores { return self.stores } -func (self *ZitiDbBuilder) GetManagers() *network.Managers { - panic("should not be needed") -} - func (self *ZitiDbBuilder) Build(m *model.Model) error { dbFile := self.Strategy.GetDbFile(m)