diff --git a/cmd/client.go b/cmd/client.go index 70a1489..93a5f8f 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -5,7 +5,6 @@ import ( "os" "github.com/cloudradar-monitoring/rportcli/internal/pkg/api" - "github.com/cloudradar-monitoring/rportcli/internal/pkg/client" "github.com/cloudradar-monitoring/rportcli/internal/pkg/config" @@ -18,7 +17,8 @@ import ( ) func init() { - addPaginationFlags(clientsListCmd, api.ClientsLimitDefault) + addClientsPaginationFlags(clientsListCmd) + addClientsSearchFlag(clientsListCmd) clientsCmd.AddCommand(clientsListCmd) clientCmd.Flags().StringP(controllers.ClientNameFlag, "n", "", "Get client by name") clientCmd.Flags().BoolP("all", "a", false, "Show client info with additional details") @@ -46,14 +46,9 @@ var clientsListCmd = &cobra.Command{ Format: getOutputFormat(), } - clientSearch := &client.Search{ - DataProvider: rportAPI, - } - clientsController := &controllers.ClientController{ Rport: rportAPI, ClientRenderer: cr, - ClientSearch: clientSearch, } ctx, cancel := buildContext(context.Background()) @@ -88,13 +83,9 @@ var clientCmd = &cobra.Command{ Writer: os.Stdout, Format: getOutputFormat(), } - clientSearch := &client.Search{ - DataProvider: rportAPI, - } clientsController := &controllers.ClientController{ Rport: rportAPI, ClientRenderer: cr, - ClientSearch: clientSearch, } ctx, cancel := buildContext(context.Background()) @@ -103,3 +94,12 @@ var clientCmd = &cobra.Command{ return clientsController.Client(ctx, params, clientID, clientName) }, } + +func addClientsPaginationFlags(cmd *cobra.Command) { + cmd.Flags().IntP(api.PaginationLimit, "", api.ClientsLimitDefault, "Number of clients to fetch") + cmd.Flags().IntP(api.PaginationOffset, "", 0, "Offset for clients fetch") +} + +func addClientsSearchFlag(cmd *cobra.Command) { + cmd.Flags().StringP("search", "", "", "Search clients on all fields, supports wildcards (*).") +} diff --git a/cmd/command.go b/cmd/command.go index d827564..df710fc 100644 --- a/cmd/command.go +++ b/cmd/command.go @@ -11,7 +11,6 @@ import ( "github.com/breathbath/go_utils/v2/pkg/env" options "github.com/breathbath/go_utils/v2/pkg/config" - "github.com/cloudradar-monitoring/rportcli/internal/pkg/client" "github.com/cloudradar-monitoring/rportcli/internal/pkg/output" @@ -75,9 +74,6 @@ var executeCmd = &cobra.Command{ } rportAPI := buildRport(params) - clientSearch := &client.Search{ - DataProvider: rportAPI, - } isFullJobOutput := params.ReadBool(controllers.IsFullOutput, false) cmdExecutor := &controllers.CommandsController{ @@ -88,7 +84,7 @@ var executeCmd = &cobra.Command{ Format: getOutputFormat(), IsFullOutput: isFullJobOutput, }, - ClientSearch: clientSearch, + Rport: rportAPI, }, } @@ -105,10 +101,10 @@ func getCommandRequirements() []config.ParameterRequirement { Help: "Enter comma separated client IDs", Validate: config.RequiredValidate, Description: "[required] Comma separated client ids for which the command should be executed. " + - "Alternatively use -n to execute a command by client name(s)", + "Alternatively use -n to execute a command by client name(s), or use --search flag.", ShortName: "d", IsEnabled: func(providedParams *options.ParameterBag) bool { - return providedParams.ReadString(controllers.ClientNameFlag, "") == "" + return providedParams.ReadString(controllers.ClientNameFlag, "") == "" && providedParams.ReadString(controllers.SearchFlag, "") == "" }, IsRequired: true, }, @@ -117,6 +113,10 @@ func getCommandRequirements() []config.ParameterRequirement { Description: "Comma separated client names for which the command should be executed", ShortName: "n", }, + { + Field: controllers.SearchFlag, + Description: "Search clients on all fields, supports wildcards (*).", + }, { Field: controllers.Command, Help: "Enter command", diff --git a/cmd/pagination.go b/cmd/pagination.go deleted file mode 100644 index d911961..0000000 --- a/cmd/pagination.go +++ /dev/null @@ -1,11 +0,0 @@ -package cmd - -import ( - "github.com/cloudradar-monitoring/rportcli/internal/pkg/api" - "github.com/spf13/cobra" -) - -func addPaginationFlags(cmd *cobra.Command, defaultLimit int) { - cmd.Flags().IntP(api.PaginationLimit, "", defaultLimit, "Number of items to fetch") - cmd.Flags().IntP(api.PaginationOffset, "", 0, "Offset for fetch") -} diff --git a/cmd/script.go b/cmd/script.go index c32c5ac..71cdea6 100644 --- a/cmd/script.go +++ b/cmd/script.go @@ -11,7 +11,6 @@ import ( "github.com/breathbath/go_utils/v2/pkg/env" options "github.com/breathbath/go_utils/v2/pkg/config" - "github.com/cloudradar-monitoring/rportcli/internal/pkg/client" "github.com/cloudradar-monitoring/rportcli/internal/pkg/output" @@ -75,9 +74,6 @@ var executeScript = &cobra.Command{ } rportAPI := buildRport(params) - clientSearch := &client.Search{ - DataProvider: rportAPI, - } isFullJobOutput := params.ReadBool(controllers.IsFullOutput, false) cmdExecutor := &controllers.ScriptsController{ @@ -88,7 +84,7 @@ var executeScript = &cobra.Command{ Format: getOutputFormat(), IsFullOutput: isFullJobOutput, }, - ClientSearch: clientSearch, + Rport: rportAPI, }, } @@ -105,10 +101,10 @@ func getScriptRequirements() []config.ParameterRequirement { Help: "Enter comma separated client IDs", Validate: config.RequiredValidate, Description: "[required] Comma separated client ids on which the script should be executed. " + - "Alternatively use -n to execute a script by client name(s)", + "Alternatively use -n to execute a script by client name(s), or use --search flag.", ShortName: "d", IsEnabled: func(providedParams *options.ParameterBag) bool { - return providedParams.ReadString(controllers.ClientNameFlag, "") == "" + return providedParams.ReadString(controllers.ClientNameFlag, "") == "" && providedParams.ReadString(controllers.SearchFlag, "") == "" }, IsRequired: true, }, @@ -117,6 +113,10 @@ func getScriptRequirements() []config.ParameterRequirement { Description: "Comma separated client names on which the script should be executed", ShortName: "n", }, + { + Field: controllers.SearchFlag, + Description: "Search clients on all fields, supports wildcards (*).", + }, { Field: controllers.Script, Help: "Enter script path", diff --git a/cmd/tunnel.go b/cmd/tunnel.go index b2bb82c..9cb3675 100644 --- a/cmd/tunnel.go +++ b/cmd/tunnel.go @@ -7,11 +7,8 @@ import ( "os/signal" "syscall" - "github.com/cloudradar-monitoring/rportcli/internal/pkg/api" "github.com/cloudradar-monitoring/rportcli/internal/pkg/rdp" - "github.com/cloudradar-monitoring/rportcli/internal/pkg/client" - options "github.com/breathbath/go_utils/v2/pkg/config" "github.com/cloudradar-monitoring/rportcli/internal/pkg/config" @@ -24,6 +21,10 @@ import ( ) func init() { + addClientsPaginationFlags(tunnelListCmd) + addClientsSearchFlag(tunnelListCmd) + tunnelListCmd.Flags().StringP(controllers.ClientNameFlag, "n", "", "Get tunnels of a client by name") + tunnelListCmd.Flags().StringP(controllers.ClientID, "c", "", "Get tunnels of a client by client id") tunnelsCmd.AddCommand(tunnelListCmd) config.DefineCommandInputs(tunnelDeleteCmd, getDeleteTunnelRequirements()) @@ -32,10 +33,6 @@ func init() { config.DefineCommandInputs(tunnelCreateCmd, getCreateTunnelRequirements()) tunnelsCmd.AddCommand(tunnelCreateCmd) - addPaginationFlags(tunnelListCmd, api.ClientsLimitDefault) - tunnelListCmd.Flags().StringP(controllers.ClientNameFlag, "n", "", "Get tunnels of a client by name") - tunnelListCmd.Flags().StringP(controllers.ClientID, "c", "", "Get tunnels of a client by client id") - rootCmd.AddCommand(tunnelsCmd) } @@ -60,15 +57,10 @@ var tunnelListCmd = &cobra.Command{ Format: getOutputFormat(), } - clientSearch := &client.Search{ - DataProvider: rportAPI, - } - tunnelController := &controllers.TunnelController{ Rport: rportAPI, TunnelRenderer: tr, IPProvider: rportAPI, - ClientSearch: clientSearch, SSHFunc: utils.RunSSH, RDPWriter: &rdp.FileWriter{}, RDPExecutor: &rdp.Executor{ @@ -303,10 +295,6 @@ func createTunnelController(params *options.ParameterBag) *controllers.TunnelCon Format: getOutputFormat(), } - clientSearch := &client.Search{ - DataProvider: rportAPI, - } - rdpExecutor := &rdp.Executor{ CommandProvider: rdp.CommandProvider, StdErr: os.Stderr, @@ -316,7 +304,6 @@ func createTunnelController(params *options.ParameterBag) *controllers.TunnelCon Rport: rportAPI, TunnelRenderer: tr, IPProvider: rportAPI, - ClientSearch: clientSearch, SSHFunc: utils.RunSSH, RDPWriter: &rdp.FileWriter{}, RDPExecutor: rdpExecutor, diff --git a/internal/pkg/api/clients.go b/internal/pkg/api/clients.go index 486e038..7d12828 100644 --- a/internal/pkg/api/clients.go +++ b/internal/pkg/api/clients.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" url2 "net/url" @@ -12,6 +13,7 @@ import ( const ( ClientsURL = "/api/v1/clients" + ClientURL = "/api/v1/clients/%s" ClientsLimitDefault = 50 ClientsLimitMax = 500 ) @@ -20,15 +22,16 @@ type ClientsResponse struct { Data []*models.Client } -func (rp *Rport) Clients(ctx context.Context, pagination Pagination) (cr *ClientsResponse, err error) { +func (rp *Rport) Clients(ctx context.Context, pagination Pagination, filters Filters) (cr *ClientsResponse, err error) { var req *http.Request u, err := url2.Parse(url.JoinURL(rp.BaseURL, ClientsURL)) if err != nil { return nil, err } q := u.Query() - q.Set("fields[clients]", "id,name,timezone,tunnels,address,hostname,os_kernel,connection_state") + q.Set("fields[clients]", "id,name,timezone,tunnels,address,hostname,os_kernel,connection_state,disconnected_at") pagination.Apply(q) + filters.Apply(q) u.RawQuery = q.Encode() req, err = http.NewRequestWithContext( @@ -47,12 +50,33 @@ func (rp *Rport) Clients(ctx context.Context, pagination Pagination) (cr *Client return } -func (rp *Rport) GetClients(ctx context.Context) (cls []*models.Client, err error) { - var cr *ClientsResponse - cr, err = rp.Clients(ctx, NewPaginationWithLimit(ClientsLimitMax)) +type ClientResponse struct { + Data *models.Client +} +func (rp *Rport) Client(ctx context.Context, id string) (*models.Client, error) { + var req *http.Request + u, err := url2.Parse(url.JoinURL(rp.BaseURL, fmt.Sprintf(ClientURL, id))) if err != nil { - return + return nil, err + } + q := u.Query() + u.RawQuery = q.Encode() + + req, err = http.NewRequestWithContext( + ctx, + http.MethodGet, + u.String(), + nil, + ) + if err != nil { + return nil, err + } + + cr := &ClientResponse{} + _, err = rp.CallBaseClient(req, cr) + if err != nil { + return nil, err } return cr.Data, nil diff --git a/internal/pkg/api/clients_test.go b/internal/pkg/api/clients_test.go index b669b26..42e8502 100644 --- a/internal/pkg/api/clients_test.go +++ b/internal/pkg/api/clients_test.go @@ -11,6 +11,7 @@ import ( "github.com/cloudradar-monitoring/rportcli/internal/pkg/utils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var clientsStub = []*models.Client{ @@ -83,7 +84,7 @@ func TestClientsList(t *testing.T) { authHeader := r.Header.Get("Authorization") assert.Equal(t, "Basic bG9nMTE2Njo1NjQzMjI=", authHeader) - assert.Equal(t, ClientsURL+"?fields%5Bclients%5D=id%2Cname%2Ctimezone%2Ctunnels%2Caddress%2Chostname%2Cos_kernel%2Cconnection_state&page%5Blimit%5D=500&page%5Boffset%5D=0", r.URL.String()) + assert.Equal(t, ClientsURL+"?fields%5Bclients%5D=id%2Cname%2Ctimezone%2Ctunnels%2Caddress%2Chostname%2Cos_kernel%2Cconnection_state%2Cdisconnected_at&filter%5Bname%5D=abc&page%5Blimit%5D=500&page%5Boffset%5D=0", r.URL.String()) jsonEnc := json.NewEncoder(rw) e := jsonEnc.Encode(ClientsResponse{Data: clientsStub}) assert.NoError(t, e) @@ -98,36 +99,34 @@ func TestClientsList(t *testing.T) { }, }) - clientsResp, err := cl.Clients(context.Background(), NewPaginationWithLimit(ClientsLimitMax)) - assert.NoError(t, err) - if err != nil { - return - } + clientsResp, err := cl.Clients(context.Background(), NewPaginationWithLimit(ClientsLimitMax), NewFilters("name", "abc")) + require.NoError(t, err) assert.Equal(t, clientsStub, clientsResp.Data) } -func TestGetClientsList(t *testing.T) { +func TestClientGet(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + assert.Equal(t, "Basic bG9nMTE2Njo1NjQzMjI=", authHeader) + + assert.Equal(t, ClientsURL+"/test-client", r.URL.String()) jsonEnc := json.NewEncoder(rw) - e := jsonEnc.Encode(ClientsResponse{Data: clientsStub}) + e := jsonEnc.Encode(ClientResponse{Data: clientsStub[0]}) assert.NoError(t, e) })) defer srv.Close() cl := New(srv.URL, &utils.StorageBasicAuth{ AuthProvider: func() (login, pass string, err error) { - login = "log1155" - pass = "564314" + login = "log1166" + pass = "564322" return }, }) - actualClients, err := cl.GetClients(context.Background()) - assert.NoError(t, err) - - expectedClients := make([]*models.Client, 0, len(clientsStub)) - expectedClients = append(expectedClients, clientsStub...) + client, err := cl.Client(context.Background(), "test-client") + require.NoError(t, err) - assert.Equal(t, expectedClients, actualClients) + assert.Equal(t, clientsStub[0], client) } diff --git a/internal/pkg/api/filters.go b/internal/pkg/api/filters.go new file mode 100644 index 0000000..f8bee2b --- /dev/null +++ b/internal/pkg/api/filters.go @@ -0,0 +1,25 @@ +package api + +import ( + "fmt" + "net/url" +) + +type Filters map[string]string + +// NewFilters constructs filters from key value pairs. Keys with empty values are ignored. +func NewFilters(keyValues ...string) Filters { + f := make(map[string]string) + for i := 0; 2*i+1 < len(keyValues); i++ { + f[keyValues[2*i]] = keyValues[2*i+1] + } + return f +} + +func (f Filters) Apply(q url.Values) { + for k, v := range f { + if v != "" { + q.Set(fmt.Sprintf("filter[%s]", k), v) + } + } +} diff --git a/internal/pkg/api/filters_test.go b/internal/pkg/api/filters_test.go new file mode 100644 index 0000000..dda061a --- /dev/null +++ b/internal/pkg/api/filters_test.go @@ -0,0 +1,23 @@ +package api + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilters(t *testing.T) { + filters := NewFilters( + "name", "johny", + "id", "", + "*", "*abc*", + ) + + q := url.Values{} + filters.Apply(q) + + assert.Equal(t, q.Get("filter[name]"), "johny") + assert.False(t, q.Has("id")) + assert.Equal(t, q.Get("filter[*]"), "*abc*") +} diff --git a/internal/pkg/client/search.go b/internal/pkg/client/search.go deleted file mode 100644 index e497623..0000000 --- a/internal/pkg/client/search.go +++ /dev/null @@ -1,71 +0,0 @@ -package client - -import ( - "context" - "fmt" - "strings" - - options "github.com/breathbath/go_utils/v2/pkg/config" - - "github.com/cloudradar-monitoring/rportcli/internal/pkg/models" -) - -type DataProvider interface { - GetClients(ctx context.Context) (cls []*models.Client, err error) -} - -type Search struct { - DataProvider DataProvider -} - -func (s *Search) Search(ctx context.Context, term string, params *options.ParameterBag) (foundCls []*models.Client, err error) { - cls, err := s.DataProvider.GetClients(ctx) - if err != nil { - return foundCls, err - } - - foundCls = s.findInClientsList(cls, term) - return -} - -func (s *Search) FindOne(ctx context.Context, searchTerm string, params *options.ParameterBag) (*models.Client, error) { - clients, err := s.Search(ctx, searchTerm, params) - if err != nil { - return &models.Client{}, err - } - - if len(clients) == 0 { - return &models.Client{}, fmt.Errorf("unknown client '%s'", searchTerm) - } - - if len(clients) == 1 { - return clients[0], nil - } - - return &models.Client{}, fmt.Errorf("client identified by '%s' is ambiguous, use a more precise name or use the client id", searchTerm) -} - -func (s *Search) findInClientsList(cls []*models.Client, term string) (foundCls []*models.Client) { - terms := strings.Split(term, ",") - for i := range terms { - terms[i] = strings.ToLower(terms[i]) - } - - foundCls = make([]*models.Client, 0) - for i := range cls { - cl := cls[i] - curClientName := strings.ToLower(cl.Name) - curClientID := strings.ToLower(cl.ID) - - for i := range terms { - curTerm := terms[i] - if strings.HasPrefix(curClientName, curTerm) { - foundCls = append(foundCls, cl) - } else if strings.HasPrefix(curClientID, curTerm) { - foundCls = append(foundCls, cl) - } - } - } - - return -} diff --git a/internal/pkg/client/search_test.go b/internal/pkg/client/search_test.go deleted file mode 100644 index 03ead0f..0000000 --- a/internal/pkg/client/search_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package client - -import ( - "context" - "errors" - "testing" - - options "github.com/breathbath/go_utils/v2/pkg/config" - - "github.com/cloudradar-monitoring/rportcli/internal/pkg/models" - "github.com/stretchr/testify/assert" -) - -var clientsList = []*models.Client{ - { - ID: "1", - Name: "my tiny client", - }, - { - ID: "2", - Name: "my Tiny nice client", - }, - { - ID: "3", - Name: "$100 usd client", - }, -} - -type DataProviderMock struct { - clientsToGive []*models.Client - errToGive error -} - -func (dpm *DataProviderMock) GetClients(ctx context.Context) (cls []*models.Client, err error) { - return dpm.clientsToGive, dpm.errToGive -} - -func TestFindClientsFromDataProvider(t *testing.T) { - search := Search{ - DataProvider: &DataProviderMock{ - clientsToGive: clientsList, - }, - } - - foundCls, err := search.Search(context.Background(), "my tiny", &options.ParameterBag{}) - assert.NoError(t, err) - assert.Len(t, foundCls, 2) - assert.Equal(t, foundCls, []*models.Client{ - { - ID: "1", - Name: "my tiny client", - }, - { - ID: "2", - Name: "my Tiny nice client", - }, - }) - - foundCls2, err2 := search.Search(context.Background(), "my tiny client,$100 usd client", &options.ParameterBag{}) - assert.NoError(t, err2) - assert.Equal(t, foundCls2, []*models.Client{ - { - ID: "1", - Name: "my tiny client", - }, - { - ID: "3", - Name: "$100 usd client", - }, - }) -} - -func TestDataProviderError(t *testing.T) { - search := Search{ - DataProvider: &DataProviderMock{ - clientsToGive: clientsList, - errToGive: errors.New("some load error"), - }, - } - - _, err := search.Search(context.Background(), "$100", &options.ParameterBag{}) - assert.EqualError(t, err, "some load error") -} - -func TestFindByAmbiguousClientName(t *testing.T) { - search := Search{ - DataProvider: &DataProviderMock{ - clientsToGive: clientsList, - }, - } - - _, err := search.FindOne(context.Background(), "my tiny", &options.ParameterBag{}) - assert.EqualError(t, err, `client identified by 'my tiny' is ambiguous, use a more precise name or use the client id`) -} diff --git a/internal/pkg/controllers/client.go b/internal/pkg/controllers/client.go index 59c3885..57e864b 100644 --- a/internal/pkg/controllers/client.go +++ b/internal/pkg/controllers/client.go @@ -12,6 +12,7 @@ import ( const ( ClientNameFlag = "name" + SearchFlag = "search" ) type ClientRenderer interface { @@ -21,12 +22,15 @@ type ClientRenderer interface { type ClientController struct { Rport *api.Rport - ClientSearch ClientSearch ClientRenderer ClientRenderer } func (cc *ClientController) Clients(ctx context.Context, params *options.ParameterBag) error { - clResp, err := cc.Rport.Clients(ctx, api.NewPaginationFromParams(params)) + clResp, err := cc.Rport.Clients( + ctx, + api.NewPaginationFromParams(params), + api.NewFilters("*", params.ReadString(SearchFlag, "")), + ) if err != nil { return err } @@ -42,22 +46,23 @@ func (cc *ClientController) Client(ctx context.Context, params *options.Paramete renderDetails := params.ReadBool("all", false) if id != "" { - clResp, err := cc.Rport.Clients(ctx, api.NewPaginationWithLimit(api.ClientsLimitMax)) + client, err := cc.Rport.Client(ctx, id) if err != nil { return err } - for _, cl := range clResp.Data { - if cl.ID == id { - return cc.ClientRenderer.RenderClient(cl, renderDetails) - } - } - } else { - cl, err := cc.ClientSearch.FindOne(ctx, name, params) - if err != nil { - return err - } - return cc.ClientRenderer.RenderClient(cl, renderDetails) + return cc.ClientRenderer.RenderClient(client, renderDetails) + } + + clients, err := cc.Rport.Clients(ctx, api.NewPaginationWithLimit(2), api.NewFilters("name", name)) + if err != nil { + return err + } + if len(clients.Data) < 1 { + return fmt.Errorf("unknown client with name %q", name) + } + if len(clients.Data) > 1 { + return fmt.Errorf("client with name %q is ambiguous, use a more precise name or use the client id", name) } - return fmt.Errorf("client not found by the provided id '%s' or name '%s'", id, name) + return cc.ClientRenderer.RenderClient(clients.Data[0], renderDetails) } diff --git a/internal/pkg/controllers/client_test.go b/internal/pkg/controllers/client_test.go index 97c9142..71794bc 100644 --- a/internal/pkg/controllers/client_test.go +++ b/internal/pkg/controllers/client_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" @@ -96,7 +95,7 @@ func TestClientsController(t *testing.T) { } func TestClientFoundByIDController(t *testing.T) { - srv := startClientsServer() + srv := startClientServer() defer srv.Close() cl := api.New(srv.URL, nil) @@ -132,15 +131,9 @@ func TestClientFoundByNameController(t *testing.T) { cl := api.New(srv.URL, nil) buf := bytes.Buffer{} - clSearch := &ClientSearchMock{ - searchTermGiven: "", - clientsToGive: []*models.Client{clientStub}, - errorToGive: nil, - } clController := ClientController{ Rport: cl, ClientRenderer: &ClientRendererMock{Writer: &buf}, - ClientSearch: clSearch, } err := clController.Client(context.Background(), &options.ParameterBag{}, "", "Client 123") @@ -156,28 +149,6 @@ func TestClientFoundByNameController(t *testing.T) { ) } -func TestClientNotFoundController(t *testing.T) { - srv := startClientsServer() - defer srv.Close() - - cl := api.New(srv.URL, nil) - buf := bytes.Buffer{} - - clController := ClientController{ - Rport: cl, - ClientRenderer: &ClientRendererMock{Writer: &buf}, - ClientSearch: &ClientSearchMock{ - errorToGive: errors.New("client not found by the provided id '434' or name ''"), - }, - } - - err := clController.Client(context.Background(), &options.ParameterBag{}, "434", "") - assert.EqualError(t, err, `client not found by the provided id '434' or name ''`) - - err = clController.Client(context.Background(), &options.ParameterBag{}, "", "some unknown name") - assert.EqualError(t, err, `client not found by the provided id '434' or name ''`) -} - func TestInvalidInputForClients(t *testing.T) { clController := ClientController{} @@ -197,3 +168,15 @@ func startClientsServer() *httptest.Server { return srv } + +func startClientServer() *httptest.Server { + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + jsonEnc := json.NewEncoder(rw) + e := jsonEnc.Encode(api.ClientResponse{Data: clientStub}) + if e != nil { + rw.WriteHeader(500) + } + })) + + return srv +} diff --git a/internal/pkg/controllers/commands_test.go b/internal/pkg/controllers/commands_test.go index a4d77dd..6221c30 100644 --- a/internal/pkg/controllers/commands_test.go +++ b/internal/pkg/controllers/commands_test.go @@ -131,93 +131,6 @@ func TestCommandExecutionByClientIDsSuccess(t *testing.T) { assert.True(t, rw.isClosed) } -func TestCommandExecutionByClientNameSuccess(t *testing.T) { - jobResp := models.Job{Jid: "987"} - jobRespBytes, err := json.Marshal(jobResp) - assert.NoError(t, err) - if err != nil { - return - } - - rw := &ReadWriterMock{ - itemsToRead: []ReadChunk{ - { - Output: jobRespBytes, - }, - { - Err: io.EOF, - }, - }, - writtenItems: []string{}, - isClosed: false, - } - - jr := &JobRendererMock{} - - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{ - { - ID: "11344", - Name: "some client 11344", - }, - { - ID: "11345", - Name: "some client 11345", - }, - }, - } - - ic := &CommandsController{ - ExecutionHelper: &ExecutionHelper{ - ReadWriter: rw, - JobRenderer: jr, - ClientSearch: searchMock, - }, - } - - params := config.FromValues(map[string]string{ - ClientNameFlag: "some client 11344,some client 11345", - Command: "cmd", - Timeout: "1", - ExecConcurrently: "1", - Interpreter: "cmd", - }) - err = ic.Start(context.Background(), params) - - assert.NoError(t, err) - - assert.Len(t, rw.writtenItems, 1) - expectedCommandInput := `{"client_ids":["11344","11345"],"is_sudo":false,"execute_concurrently":true,"abort_on_error":false,"timeout_sec":1,"command":"cmd","script":"","cwd":"","interpreter":"cmd"}` - assert.Equal(t, expectedCommandInput, rw.writtenItems[0]) -} - -func TestCommandExecutionClientNotFoundByName(t *testing.T) { - rw := &ReadWriterMock{ - itemsToRead: []ReadChunk{{Err: io.EOF}}, - writtenItems: []string{}, - } - - jr := &JobRendererMock{} - - searchMock := &ClientSearchMock{clientsToGive: []*models.Client{}} - - ic := &CommandsController{ - ExecutionHelper: &ExecutionHelper{ - ReadWriter: rw, - JobRenderer: jr, - ClientSearch: searchMock, - }, - } - - params := config.FromValues(map[string]string{ - ClientNameFlag: "some client 11349", - Command: "cmd", - }) - err := ic.Start(context.Background(), params) - - assert.EqualError(t, err, "unknown client(s) 'some client 11349'") -} - func TestInvalidInputForCommand(t *testing.T) { cc := &CommandsController{ ExecutionHelper: &ExecutionHelper{}, @@ -231,7 +144,7 @@ func TestInvalidInputForCommand(t *testing.T) { CheckPort: "1", }) err := cc.Start(context.Background(), params) - assert.EqualError(t, err, "no client id nor name provided") + assert.EqualError(t, err, "no client ids, names or search provided") } func TestCommandExecutionWithInvalidResponse(t *testing.T) { diff --git a/internal/pkg/controllers/execHelper.go b/internal/pkg/controllers/execHelper.go index a2038d1..8093764 100644 --- a/internal/pkg/controllers/execHelper.go +++ b/internal/pkg/controllers/execHelper.go @@ -13,6 +13,7 @@ import ( options "github.com/breathbath/go_utils/v2/pkg/config" io2 "github.com/breathbath/go_utils/v2/pkg/io" + "github.com/cloudradar-monitoring/rportcli/internal/pkg/api" "github.com/cloudradar-monitoring/rportcli/internal/pkg/models" "github.com/sirupsen/logrus" ) @@ -55,9 +56,9 @@ type JobRenderer interface { } type ExecutionHelper struct { - ClientSearch ClientSearch - JobRenderer JobRenderer - ReadWriter ReadWriter + JobRenderer JobRenderer + ReadWriter ReadWriter + Rport *api.Rport } func (eh *ExecutionHelper) execute(ctx context.Context, params *options.ParameterBag, scriptPayload, interpreter string) error { @@ -127,31 +128,37 @@ func (eh *ExecutionHelper) sendCommand(wsCmd *models.WsScriptCommand) error { } func (eh *ExecutionHelper) getClientIDs(ctx context.Context, params *options.ParameterBag) (clientIDs string, err error) { - clientIDs = params.ReadString(ClientIDs, "") - clientName := params.ReadString(ClientNameFlag, "") - - if clientIDs == "" && clientName == "" { - return "", errors.New("no client id nor name provided") + ids := params.ReadString(ClientIDs, "") + if ids != "" { + return ids, nil + } + names := params.ReadString(ClientNameFlag, "") + search := params.ReadString(SearchFlag, "") + if ids == "" && names == "" && search == "" { + return "", errors.New("no client ids, names or search provided") } - if clientIDs == "" { - clients, err := eh.ClientSearch.Search(ctx, clientName, params) - if err != nil { - return "", err - } - - if len(clients) == 0 { - return "", fmt.Errorf("unknown client(s) '%s'", clientName) - } + clients, err := eh.Rport.Clients( + ctx, + api.NewPaginationWithLimit(api.ClientsLimitMax), + api.NewFilters( + "name", names, + "*", search, + ), + ) + if err != nil { + return "", err + } - for i := range clients { - cl := clients[i] - clientIDs += cl.ID + "," + for _, cl := range clients.Data { + if cl.DisconnectedAt != "" { + continue } - - clientIDs = strings.Trim(clientIDs, ",") + clientIDs += cl.ID + "," } + clientIDs = strings.Trim(clientIDs, ",") + return clientIDs, nil } diff --git a/internal/pkg/controllers/mocks.go b/internal/pkg/controllers/mocks.go index 82307b6..b870b88 100644 --- a/internal/pkg/controllers/mocks.go +++ b/internal/pkg/controllers/mocks.go @@ -1,13 +1,5 @@ package controllers -import ( - "context" - - options "github.com/breathbath/go_utils/v2/pkg/config" - - "github.com/cloudradar-monitoring/rportcli/internal/pkg/models" -) - type PromptReaderMock struct { ReadCount int PasswordReadCount int @@ -40,22 +32,3 @@ func (prm *PromptReaderMock) ReadPassword() (string, error) { func (prm *PromptReaderMock) Output(text string) { prm.Inputs = append(prm.Inputs, text) } - -type ClientSearchMock struct { - searchTermGiven string - clientsToGive []*models.Client - errorToGive error -} - -func (csm *ClientSearchMock) Search(ctx context.Context, term string, params *options.ParameterBag) (foundCls []*models.Client, err error) { - csm.searchTermGiven = term - return csm.clientsToGive, csm.errorToGive -} - -func (csm *ClientSearchMock) FindOne(ctx context.Context, searchTerm string, params *options.ParameterBag) (*models.Client, error) { - csm.searchTermGiven = searchTerm - if len(csm.clientsToGive) > 0 { - return csm.clientsToGive[0], csm.errorToGive - } - return &models.Client{}, csm.errorToGive -} diff --git a/internal/pkg/controllers/tunnel.go b/internal/pkg/controllers/tunnel.go index 7f3fdca..5402ffb 100644 --- a/internal/pkg/controllers/tunnel.go +++ b/internal/pkg/controllers/tunnel.go @@ -7,7 +7,6 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/cloudradar-monitoring/rportcli/internal/pkg/config" @@ -63,36 +62,25 @@ type TunnelController struct { Rport *api.Rport TunnelRenderer TunnelRenderer IPProvider IPProvider - ClientSearch ClientSearch SSHFunc func(sshParams []string) error RDPWriter RDPFileWriter RDPExecutor RDPExecutor } func (tc *TunnelController) Tunnels(ctx context.Context, params *options.ParameterBag) error { - clientID := params.ReadString(ClientID, "") - clientName := params.ReadString(ClientNameFlag, "") - - var err error - var clients []*models.Client - if clientID != "" || clientName != "" { - searchTerm := clientID - if clientName != "" { - searchTerm = clientName - } - - clients, err = tc.ClientSearch.Search(ctx, searchTerm, params) - if err != nil { - return err - } - } else { - var clResp *api.ClientsResponse - clResp, err = tc.Rport.Clients(ctx, api.NewPaginationFromParams(params)) - if err != nil { - return err - } - clients = clResp.Data + clResp, err := tc.Rport.Clients( + ctx, + api.NewPaginationFromParams(params), + api.NewFilters( + "id", params.ReadString(ClientID, ""), + "name", params.ReadString(ClientNameFlag, ""), + "*", params.ReadString(SearchFlag, ""), + ), + ) + if err != nil { + return err } + clients := clResp.Data tunnels := make([]*models.Tunnel, 0) for _, cl := range clients { @@ -107,31 +95,13 @@ func (tc *TunnelController) Tunnels(ctx context.Context, params *options.Paramet } func (tc *TunnelController) Delete(ctx context.Context, params *options.ParameterBag) error { - clientID := params.ReadString(ClientID, "") - tunnelID := params.ReadString(TunnelID, "") - clientName := params.ReadString(ClientNameFlag, "") - - if clientID == "" && clientName == "" { - return errors.New("no client id nor name provided") - } - - if clientID == "" { - clients, err := tc.ClientSearch.Search(ctx, clientName, params) - if err != nil { - return err - } - - if len(clients) == 0 { - return fmt.Errorf("unknown client '%s'", clientName) - } - - if len(clients) != 1 { - return fmt.Errorf("client identified by '%s' is ambiguous, use a more precise name or use the client id", clientName) - } - clientID = clients[0].ID + clientID, _, err := tc.getClientIDAndClientName(ctx, params) + if err != nil { + return err } - err := tc.Rport.DeleteTunnel(ctx, clientID, tunnelID, params.ReadBool(ForceDeletion, false)) + tunnelID := params.ReadString(TunnelID, "") + err = tc.Rport.DeleteTunnel(ctx, clientID, tunnelID, params.ReadBool(ForceDeletion, false)) if err != nil { if strings.Contains(err.Error(), "tunnel is still active") { return fmt.Errorf("%v, use -f to delete it anyway", err) @@ -157,17 +127,29 @@ func (tc *TunnelController) getClientIDAndClientName( err = errors.New("no client id nor name provided") return } + if clientID != "" && clientName != "" { + err = errors.New("both client id and name provided") + return + } if clientID != "" { return } - client, err := tc.ClientSearch.FindOne(ctx, clientName, params) + clients, err := tc.Rport.Clients(ctx, api.NewPaginationWithLimit(2), api.NewFilters("name", clientName)) if err != nil { return } - return client.ID, clientName, nil + if len(clients.Data) < 1 { + return "", "", fmt.Errorf("unknown client with name %q", clientName) + } + if len(clients.Data) > 1 { + return "", "", fmt.Errorf("client with name %q is ambidguous, use a more precise name or use the client id", clientName) + } + + client := clients.Data[0] + return client.ID, client.Name, nil } func (tc *TunnelController) Create(ctx context.Context, params *options.ParameterBag) error { @@ -305,7 +287,7 @@ func (tc *TunnelController) launchHelperFlowIfNeeded( return tc.startSSHFlow(ctx, tunnelCreated, params, deleteTunnelParams) } - return tc.startRDPFlow(ctx, tunnelCreated, params, clientName, clientID) + return tc.startRDPFlow(tunnelCreated, params, clientName) } func (tc *TunnelController) finishSSHFlow(ctx context.Context, deleteTunnelParams *options.ParameterBag, prevErr error) error { @@ -405,31 +387,15 @@ func (tc *TunnelController) extractPortAndHost( } func (tc *TunnelController) startRDPFlow( - ctx context.Context, tunnelCreated *models.TunnelCreated, params *options.ParameterBag, - clientName, clientID string, + clientName string, ) error { port, host, err := tc.extractPortAndHost(tunnelCreated, params) if err != nil { return err } - if clientName == "" { - logrus.Debug("since client name is not provided, will try to find a client by id " + clientID) - clients, e := tc.ClientSearch.Search(ctx, clientID, params) - if e != nil { - return e - } - if len(clients) == 0 || clients[0].Name == "" { - clientName = fmt.Sprint(time.Now().Unix()) - } else { - clientName = clients[0].Name - } - - logrus.Debugf("found client name %s", clientName) - } - rdpFileInput := models.FileInput{ Address: fmt.Sprintf("%s:%s", host, port), ScreenHeight: params.ReadInt(RDPHeight, 0), diff --git a/internal/pkg/controllers/tunnel_test.go b/internal/pkg/controllers/tunnel_test.go index f87afd0..de91ae1 100644 --- a/internal/pkg/controllers/tunnel_test.go +++ b/internal/pkg/controllers/tunnel_test.go @@ -12,7 +12,6 @@ import ( "testing" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" options "github.com/breathbath/go_utils/v2/pkg/config" @@ -124,7 +123,6 @@ func TestTunnelsController(t *testing.T) { isSSHExecuted = true return nil }, - ClientSearch: &ClientSearchMock{}, } assert.False(t, isSSHExecuted) @@ -141,63 +139,6 @@ func TestTunnelsController(t *testing.T) { ) } -func TestTunnelsControllerByClient(t *testing.T) { - srv := startClientsServer() - defer srv.Close() - - apiAuth := &utils.StorageBasicAuth{ - AuthProvider: func() (login, pass string, err error) { - login = "log1455" - pass = "pass1446" - return - }, - } - cl := api.New(srv.URL, apiAuth) - buf := bytes.Buffer{} - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{ - { - ID: "cl2", - Name: "client 354351", - Tunnels: []*models.Tunnel{ - { - ID: "23", - }, - }, - }, - }, - } - tController := &TunnelController{ - Rport: cl, - TunnelRenderer: &TunnelRendererMock{Writer: &buf}, - SSHFunc: func(sshParams []string) error { - return nil - }, - ClientSearch: searchMock, - } - - paramProv := options.NewMapValuesProvider(map[string]interface{}{ - ClientNameFlag: "client 354351", - }) - err := tController.Tunnels(context.Background(), options.New(paramProv)) - require.NoError(t, err) - - assert.Equal( - t, - `[{"id":"23","client_id":"cl2","client_name":"client 354351","lhost":"","lport":"","rhost":"","rport":"","lport_random":false,"scheme":"","acl":"","idle_timeout_minutes":0}]`, - buf.String(), - ) - - assert.Equal(t, "client 354351", searchMock.searchTermGiven) - - paramProv2 := options.NewMapValuesProvider(map[string]interface{}{ - ClientID: "cl2", - }) - err = tController.Tunnels(context.Background(), options.New(paramProv2)) - require.NoError(t, err) - assert.Equal(t, "cl2", searchMock.searchTermGiven) -} - func TestTunnelDeleteByClientIDController(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(t, "Basic bG9nMTU1OnBhc3MxNTU=", r.Header.Get("Authorization")) @@ -221,7 +162,6 @@ func TestTunnelDeleteByClientIDController(t *testing.T) { tController := TunnelController{ Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &buf}, - ClientSearch: &ClientSearchMock{}, SSHFunc: func(sshParams []string) error { isSSHExecuted = true return nil @@ -240,102 +180,6 @@ func TestTunnelDeleteByClientIDController(t *testing.T) { assert.Equal(t, `{"status":"Tunnel successfully deleted"}`, buf.String()) } -func TestTunnelDeleteByClientNameController(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodDelete, r.Method) - assert.Equal(t, "/api/v1/clients/cl2/tunnels/tun4", r.URL.String()) - rw.WriteHeader(http.StatusNoContent) - })) - defer srv.Close() - - apiAuth := &utils.StorageBasicAuth{ - AuthProvider: func() (login, pass string, err error) { - login = "log24124" - pass = "pass341324" - return - }, - } - cl := api.New(srv.URL, apiAuth) - buf := bytes.Buffer{} - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{ - { - ID: "cl2", - Name: "some client", - }, - }, - } - tController := TunnelController{ - Rport: cl, - TunnelRenderer: &TunnelRendererMock{Writer: &buf}, - ClientSearch: searchMock, - SSHFunc: func(sshParams []string) error { - return nil - }, - } - - params := options.New(options.NewMapValuesProvider(map[string]interface{}{ - ClientID: "", - TunnelID: "tun4", - ClientNameFlag: "some client", - })) - - err := tController.Delete(context.Background(), params) - assert.NoError(t, err) - assert.Equal(t, `{"status":"Tunnel successfully deleted"}`, buf.String()) -} - -func TestTunnelDeleteByAmbiguousClientName(t *testing.T) { - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{ - { - ID: "cl1", - Name: "some client 1", - }, - { - ID: "cl2", - Name: "some client 2", - }, - }, - } - tController := TunnelController{ - ClientSearch: searchMock, - SSHFunc: func(sshParams []string) error { - return nil - }, - } - - params := options.New(options.NewMapValuesProvider(map[string]interface{}{ - ClientID: "", - TunnelID: "tun3", - ClientNameFlag: "some client", - })) - - err := tController.Delete(context.Background(), params) - assert.EqualError(t, err, `client identified by 'some client' is ambiguous, use a more precise name or use the client id`) -} - -func TestTunnelDeleteNotFoundClientName(t *testing.T) { - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{}, - } - tController := TunnelController{ - ClientSearch: searchMock, - SSHFunc: func(sshParams []string) error { - return nil - }, - } - - params := options.New(options.NewMapValuesProvider(map[string]interface{}{ - ClientID: "", - TunnelID: "tun5", - ClientNameFlag: "some client", - })) - - err := tController.Delete(context.Background(), params) - assert.EqualError(t, err, `unknown client 'some client'`) -} - func TestInvalidInputForTunnelDelete(t *testing.T) { tController := TunnelController{} params := options.New(options.NewMapValuesProvider(map[string]interface{}{ @@ -413,76 +257,6 @@ func TestTunnelCreateWithClientID(t *testing.T) { assert.Equal(t, expectedOutput, buf.String()) } -func TestTunnelCreateWithClientName(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/v1/clients/444/tunnels?acl=3.4.5.7&check_port=1&local=lohost2%3A3301&remote=rhost4%3A3345&scheme=ssh&skip-idle-timeout=1", r.URL.String()) - jsonEnc := json.NewEncoder(rw) - e := jsonEnc.Encode(api.TunnelCreatedResponse{Data: &models.TunnelCreated{ - ID: "444", - Lhost: "lohost2", - Lport: "3301", - Rhost: "rhost4", - Rport: "3345", - LportRandom: true, - Scheme: utils.SSH, - ACL: "3.4.5.7", - }}) - assert.NoError(t, e) - })) - defer srv.Close() - - apiAuth := &utils.StorageBasicAuth{ - AuthProvider: func() (login, pass string, err error) { return "someloggg", "somepaaas", nil }, - } - - buf := bytes.Buffer{} - - cl := api.New(srv.URL, apiAuth) - - searchMock := &ClientSearchMock{ - clientsToGive: []*models.Client{ - { - ID: "444", - Name: "some client 444", - }, - }, - } - - isSSHExecuted := false - tController := TunnelController{ - Rport: cl, - TunnelRenderer: &TunnelRendererMock{Writer: &buf}, - IPProvider: IPProviderMock{ - IP: "3.4.5.7", - }, - ClientSearch: searchMock, - SSHFunc: func(sshParams []string) error { - isSSHExecuted = true - return nil - }, - } - assert.False(t, isSSHExecuted) - - params := config.FromValues(map[string]string{ - ClientID: "", - ClientNameFlag: "some client 444", - Local: "lohost2:3301", - Remote: "rhost4:3345", - Scheme: utils.SSH, - CheckPort: "1", - config.ServerURL: "http://11.11.11.11:33", - SkipIdleTimeout: "1", - }) - err := tController.Create(context.Background(), params) - assert.NoError(t, err) - - expectedOutput := fmt.Sprintf( - `{"id":"444","client_id":"some client 444","client_name":"","lhost":"lohost2","lport":"3301","rhost":"rhost4","rport":"3345","lport_random":true,"scheme":"ssh","acl":"3.4.5.7","usage":"ssh -p 3301 11.11.11.11 -l ${USER}","idle_timeout_minutes":0,"rport_server":"%s"}`, - srv.URL, - ) - assert.Equal(t, expectedOutput, buf.String()) -} - func TestInvalidInputForTunnelCreate(t *testing.T) { tController := TunnelController{} params := config.FromValues(map[string]string{ @@ -497,21 +271,6 @@ func TestInvalidInputForTunnelCreate(t *testing.T) { assert.EqualError(t, err, "no client id nor name provided") } -func TestTunnelCreateNotFoundClientName(t *testing.T) { - searchMock := &ClientSearchMock{ - errorToGive: errors.New("unknown client 'some client'"), - } - tController := TunnelController{ - ClientSearch: searchMock, - } - - params := config.FromValues(map[string]string{ - ClientNameFlag: "some client", - }) - err := tController.Create(context.Background(), params) - assert.EqualError(t, err, `unknown client 'some client'`) -} - func TestTunnelCreateWithSchemeDiscovery(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPut { @@ -539,15 +298,12 @@ func TestTunnelCreateWithSchemeDiscovery(t *testing.T) { cl := api.New(srv.URL, apiAuth) - searchMock := &ClientSearchMock{clientsToGive: []*models.Client{}} - tController := TunnelController{ Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &buf}, IPProvider: IPProviderMock{ IP: "3.4.5.8", }, - ClientSearch: searchMock, SSHFunc: func(sshParams []string) error { return nil }, @@ -603,15 +359,12 @@ func TestTunnelCreateWithPortDiscovery(t *testing.T) { cl := api.New(srv.URL, apiAuth) - searchMock := &ClientSearchMock{clientsToGive: []*models.Client{}} - tController := TunnelController{ Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &buf}, IPProvider: IPProviderMock{ IP: "3.4.5.9", }, - ClientSearch: searchMock, SSHFunc: func(sshParams []string) error { return nil }, @@ -690,8 +443,6 @@ func TestTunnelCreateWithSSH(t *testing.T) { cl := api.New(srv.URL, apiAuth) - searchMock := &ClientSearchMock{clientsToGive: []*models.Client{}} - isSSHCalled := false tController := TunnelController{ Rport: cl, @@ -699,7 +450,6 @@ func TestTunnelCreateWithSSH(t *testing.T) { IPProvider: IPProviderMock{ IP: "3.4.5.10", }, - ClientSearch: searchMock, SSHFunc: func(sshParams []string) error { isSSHCalled = true assert.Equal(t, []string{"rport-url.com", "-p", "22", "-l", "root", "-i", "somefile"}, sshParams) @@ -772,7 +522,6 @@ func TestTunnelCreateWithSSHFailure(t *testing.T) { IPProvider: IPProviderMock{ IP: "3.4.5.16", }, - ClientSearch: &ClientSearchMock{clientsToGive: []*models.Client{}}, SSHFunc: func(sshParams []string) error { return errors.New("ssh failure") }, @@ -829,9 +578,8 @@ func TestTunnelCreateWithRDP(t *testing.T) { IPProvider: IPProviderMock{ IP: "3.4.5.166", }, - ClientSearch: &ClientSearchMock{clientsToGive: []*models.Client{}}, - RDPWriter: fileWriter, - RDPExecutor: rdpExecutor, + RDPWriter: fileWriter, + RDPExecutor: rdpExecutor, } params := config.FromValues(map[string]string{ @@ -846,6 +594,8 @@ func TestTunnelCreateWithRDP(t *testing.T) { IdleTimeoutMinutes: "5", }) err := tController.Create(context.Background(), params) + assert.NoError(t, err) + expectedFileInput := models.FileInput{ Address: "rport-url123.com:3344", ScreenHeight: 990, @@ -856,7 +606,6 @@ func TestTunnelCreateWithRDP(t *testing.T) { assert.Equal(t, expectedFileInput.ScreenHeight, fileWriter.FileInput.ScreenHeight) assert.Equal(t, expectedFileInput.ScreenWidth, fileWriter.FileInput.ScreenWidth) assert.Equal(t, expectedFileInput.UserName, fileWriter.FileInput.UserName) - assert.NoError(t, err) expectedOutput := fmt.Sprintf( `{"id":"777","client_id":"1314","client_name":"","lhost":"lohost77","lport":"3344","rhost":"","rport":"","lport_random":false,"scheme":"rdp","acl":"","usage":"rdp://rport-url123.com:3344","idle_timeout_minutes":5,"rport_server":"%s"}`, @@ -882,7 +631,6 @@ func TestTunnelCreateWithRDPIncompatibleFlags(t *testing.T) { Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &renderBuf}, IPProvider: IPProviderMock{}, - ClientSearch: &ClientSearchMock{clientsToGive: []*models.Client{}}, RDPWriter: nil, RDPExecutor: rdpExecutor, } @@ -913,7 +661,6 @@ func TestTunnelCreateWithSSHIncompatibleFlags(t *testing.T) { Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &renderBuf}, IPProvider: IPProviderMock{}, - ClientSearch: &ClientSearchMock{clientsToGive: []*models.Client{}}, SSHFunc: func(sshParams []string) error { isSSHCalled = true return nil @@ -963,7 +710,6 @@ func TestTunnelDeleteFailureWithActiveConnections(t *testing.T) { tController := TunnelController{ Rport: cl, TunnelRenderer: &TunnelRendererMock{Writer: &buf}, - ClientSearch: &ClientSearchMock{}, SSHFunc: func(sshParams []string) error { return nil }, diff --git a/internal/pkg/utils/httpClient.go b/internal/pkg/utils/httpClient.go index d5b0702..e9c9688 100644 --- a/internal/pkg/utils/httpClient.go +++ b/internal/pkg/utils/httpClient.go @@ -61,14 +61,14 @@ func (c *BaseClient) Call(req *http.Request, target interface{}, errTarget error respBodyBytes, err = ioutil.ReadAll(resp.Body) if err != nil { logrus.Warnf("failed to read response body: %v", err) - e := c.convertResponseCodeToError(resp.StatusCode, nil) + e := c.convertResponseCodeToError(resp.StatusCode) return resp, e } err = json.Unmarshal(respBodyBytes, errTarget) if err != nil { logrus.Warnf("cannot unmarshal error response %s: %v", string(respBodyBytes), err) - e := c.convertResponseCodeToError(resp.StatusCode, nil) + e := c.convertResponseCodeToError(resp.StatusCode) return resp, e } return resp, errTarget @@ -96,27 +96,15 @@ func (c *BaseClient) Call(req *http.Request, target interface{}, errTarget error return resp, nil } -func (c *BaseClient) convertResponseCodeToError(respCode int, errTarget error) (err error) { +func (c *BaseClient) convertResponseCodeToError(respCode int) (err error) { if respCode == http.StatusNotFound { err = errors.New("the specified item doesn't exist") } else if respCode == http.StatusInternalServerError { - if errTarget != nil { - err = fmt.Errorf("operation failed %s", errTarget.Error()) - } else { - err = errors.New("operation failed") - } + err = errors.New("operation failed") } else if respCode == http.StatusBadRequest { - if errTarget != nil { - err = fmt.Errorf("invalid input provided: %s", errTarget.Error()) - } else { - err = errors.New("invalid input provided") - } + err = errors.New("invalid input provided") } else { - if errTarget != nil { - err = fmt.Errorf("unknown error: %s", errTarget.Error()) - } else { - err = errors.New("unknown error") - } + err = fmt.Errorf("unknown error: %d %s", respCode, http.StatusText(respCode)) } return err