From c5cfe4ac16cc5f2126bab810477aded9719b70ef Mon Sep 17 00:00:00 2001 From: MatteoPologruto <109663225+MatteoPologruto@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:41:28 +0100 Subject: [PATCH] Add metadata retrieved from the `context` to the user agent when a new HTTP client is created (#2789) * Set the extra user agent when a new rpc instance is created * Add integration test * Moved user-agent extraction deep in configuration.HttpClient This allows the extraction of the user-agent in a single place. Also it forces the context passing on all operations that requires access to network. * Updated integration test * Restore previous user-agent for package manager * Apply review suggestions to the integration test --------- Co-authored-by: Cristian Maglie --- commands/instances.go | 8 +-- commands/service_board_identify.go | 18 +++--- commands/service_board_identify_test.go | 15 ++--- commands/service_check_for_updates.go | 6 +- commands/service_library_download.go | 12 ++-- internal/arduino/resources/helpers_test.go | 2 +- internal/cli/configuration/network.go | 17 ++++-- internal/cli/configuration/network_test.go | 7 ++- internal/integrationtest/arduino-cli.go | 6 +- .../integrationtest/daemon/daemon_test.go | 59 +++++++++++++++++++ internal/integrationtest/http_server.go | 2 +- 11 files changed, 115 insertions(+), 37 deletions(-) diff --git a/commands/instances.go b/commands/instances.go index d5534494bae..5b5fe09fcdb 100644 --- a/commands/instances.go +++ b/commands/instances.go @@ -89,7 +89,7 @@ func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateReque } } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(ctx) if err != nil { return nil, err } @@ -377,7 +377,7 @@ func (s *arduinoCoreServerImpl) Init(req *rpc.InitRequest, stream rpc.ArduinoCor responseError(err.GRPCStatus()) continue } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(ctx) if err != nil { taskCallback(&rpc.TaskProgress{Name: i18n.Tr("Error downloading library %s", libraryRef)}) e := &cmderrors.FailedLibraryInstallError{Cause: err} @@ -498,7 +498,7 @@ func (s *arduinoCoreServerImpl) UpdateLibrariesIndex(req *rpc.UpdateLibrariesInd } // Perform index update - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(stream.Context()) if err != nil { return err } @@ -608,7 +608,7 @@ func (s *arduinoCoreServerImpl) UpdateIndex(req *rpc.UpdateIndexRequest, stream } } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(stream.Context()) if err != nil { downloadCB.Start(u, i18n.Tr("Downloading index: %s", filepath.Base(URL.Path))) downloadCB.End(false, i18n.Tr("Invalid network configuration: %s", err)) diff --git a/commands/service_board_identify.go b/commands/service_board_identify.go index 2f4b1f54dce..787de81cee3 100644 --- a/commands/service_board_identify.go +++ b/commands/service_board_identify.go @@ -48,7 +48,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar defer release() props := properties.NewFromHashmap(req.GetProperties()) - res, err := identify(pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection()) + res, err := identify(ctx, pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection()) if err != nil { return nil, err } @@ -58,7 +58,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar } // identify returns a list of boards checking first the installed platforms or the Cloud API -func identify(pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) { +func identify(ctx context.Context, pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) { if properties == nil { return nil, nil } @@ -90,7 +90,7 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings // if installed cores didn't recognize the board, try querying // the builder API if the board is a USB device port if len(boards) == 0 && !skipCloudAPI && !settings.SkipCloudApiForBoardDetection() { - items, err := identifyViaCloudAPI(properties, settings) + items, err := identifyViaCloudAPI(ctx, properties, settings) if err != nil { // this is bad, but keep going logrus.WithError(err).Debug("Error querying builder API") @@ -119,14 +119,14 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings return boards, nil } -func identifyViaCloudAPI(props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func identifyViaCloudAPI(ctx context.Context, props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { // If the port is not USB do not try identification via cloud if !props.ContainsKey("vid") || !props.ContainsKey("pid") { return nil, nil } logrus.Debug("Querying builder API for board identification...") - return cachedAPIByVidPid(props.Get("vid"), props.Get("pid"), settings) + return cachedAPIByVidPid(ctx, props.Get("vid"), props.Get("pid"), settings) } var ( @@ -134,7 +134,7 @@ var ( validVidPid = regexp.MustCompile(`0[xX][a-fA-F\d]{4}`) ) -func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func cachedAPIByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { var resp []*rpc.BoardListItem cacheKey := fmt.Sprintf("cache.builder-api.v3/boards/byvid/pid/%s/%s", vid, pid) @@ -148,7 +148,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp } } - resp, err := apiByVidPid(vid, pid, settings) // Perform API requrest + resp, err := apiByVidPid(ctx, vid, pid, settings) // Perform API requrest if err == nil { if cachedResp, err := json.Marshal(resp); err == nil { @@ -160,7 +160,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp return resp, err } -func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func apiByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { // ensure vid and pid are valid before hitting the API if !validVidPid.MatchString(vid) { return nil, errors.New(i18n.Tr("Invalid vid value: '%s'", vid)) @@ -173,7 +173,7 @@ func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.Boar req, _ := http.NewRequest("GET", url, nil) req.Header.Set("Content-Type", "application/json") - httpClient, err := settings.NewHttpClient() + httpClient, err := settings.NewHttpClient(ctx) if err != nil { return nil, fmt.Errorf("%s: %w", i18n.Tr("failed to initialize http client"), err) } diff --git a/commands/service_board_identify_test.go b/commands/service_board_identify_test.go index 98dc8e40278..31687359885 100644 --- a/commands/service_board_identify_test.go +++ b/commands/service_board_identify_test.go @@ -16,6 +16,7 @@ package commands import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -48,7 +49,7 @@ func TestGetByVidPid(t *testing.T) { vidPidURL = ts.URL settings := configuration.NewSettings() - res, err := apiByVidPid("0xf420", "0XF069", settings) + res, err := apiByVidPid(context.Background(), "0xf420", "0XF069", settings) require.Nil(t, err) require.Len(t, res, 1) require.Equal(t, "Arduino/Genuino MKR1000", res[0].GetName()) @@ -56,7 +57,7 @@ func TestGetByVidPid(t *testing.T) { // wrong vid (too long), wrong pid (not an hex value) - _, err = apiByVidPid("0xfffff", "0xDEFG", settings) + _, err = apiByVidPid(context.Background(), "0xfffff", "0xDEFG", settings) require.NotNil(t, err) } @@ -69,7 +70,7 @@ func TestGetByVidPidNotFound(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NoError(t, err) require.Empty(t, res) } @@ -84,7 +85,7 @@ func TestGetByVidPid5xx(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NotNil(t, err) require.Equal(t, "the server responded with status 500 Internal Server Error", err.Error()) require.Len(t, res, 0) @@ -99,7 +100,7 @@ func TestGetByVidPidMalformedResponse(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NotNil(t, err) require.Equal(t, "wrong format in server response", err.Error()) require.Len(t, res, 0) @@ -107,7 +108,7 @@ func TestGetByVidPidMalformedResponse(t *testing.T) { func TestBoardDetectionViaAPIWithNonUSBPort(t *testing.T) { settings := configuration.NewSettings() - items, err := identifyViaCloudAPI(properties.NewMap(), settings) + items, err := identifyViaCloudAPI(context.Background(), properties.NewMap(), settings) require.NoError(t, err) require.Empty(t, items) } @@ -156,7 +157,7 @@ func TestBoardIdentifySorting(t *testing.T) { defer release() settings := configuration.NewSettings() - res, err := identify(pme, idPrefs, settings, true) + res, err := identify(context.Background(), pme, idPrefs, settings, true) require.NoError(t, err) require.NotNil(t, res) require.Len(t, res, 4) diff --git a/commands/service_check_for_updates.go b/commands/service_check_for_updates.go index 2b5c7a51b4e..cce5a4a02a1 100644 --- a/commands/service_check_for_updates.go +++ b/commands/service_check_for_updates.go @@ -43,7 +43,7 @@ func (s *arduinoCoreServerImpl) CheckForArduinoCLIUpdates(ctx context.Context, r inventory.WriteStore() }() - latestVersion, err := semver.Parse(s.getLatestRelease()) + latestVersion, err := semver.Parse(s.getLatestRelease(ctx)) if err != nil { return nil, err } @@ -82,8 +82,8 @@ func (s *arduinoCoreServerImpl) shouldCheckForUpdate(currentVersion *semver.Vers // getLatestRelease queries the official Arduino download server for the latest release, // if there are no errors or issues a version string is returned, in all other case an empty string. -func (s *arduinoCoreServerImpl) getLatestRelease() string { - client, err := s.settings.NewHttpClient() +func (s *arduinoCoreServerImpl) getLatestRelease(ctx context.Context) string { + client, err := s.settings.NewHttpClient(ctx) if err != nil { return "" } diff --git a/commands/service_library_download.go b/commands/service_library_download.go index 2384d59396f..4253be8cca1 100644 --- a/commands/service_library_download.go +++ b/commands/service_library_download.go @@ -82,11 +82,15 @@ func (s *arduinoCoreServerImpl) LibraryDownload(req *rpc.LibraryDownloadRequest, }) } -func downloadLibrary(ctx context.Context, downloadsDir *paths.Path, libRelease *librariesindex.Release, - downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, queryParameter string, settings *configuration.Settings) error { - +func downloadLibrary( + ctx context.Context, + downloadsDir *paths.Path, libRelease *librariesindex.Release, + downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, + queryParameter string, + settings *configuration.Settings, +) error { taskCB(&rpc.TaskProgress{Name: i18n.Tr("Downloading %s", libRelease)}) - config, err := settings.DownloaderConfig() + config, err := settings.DownloaderConfig(ctx) if err != nil { return &cmderrors.FailedDownloadError{Message: i18n.Tr("Can't download library"), Cause: err} } diff --git a/internal/arduino/resources/helpers_test.go b/internal/arduino/resources/helpers_test.go index 611de8dd518..ad1d6805254 100644 --- a/internal/arduino/resources/helpers_test.go +++ b/internal/arduino/resources/helpers_test.go @@ -55,7 +55,7 @@ func TestDownloadApplyUserAgentHeaderUsingConfig(t *testing.T) { settings := configuration.NewSettings() settings.Set("network.user_agent_ext", goldUserAgentValue) - config, err := settings.DownloaderConfig() + config, err := settings.DownloaderConfig(context.Background()) require.NoError(t, err) err = r.Download(context.Background(), tmp, config, "", func(progress *rpc.DownloadProgress) {}, "") require.NoError(t, err) diff --git a/internal/cli/configuration/network.go b/internal/cli/configuration/network.go index c570d0a3b82..43b502a03fb 100644 --- a/internal/cli/configuration/network.go +++ b/internal/cli/configuration/network.go @@ -16,18 +16,21 @@ package configuration import ( + "context" "errors" "fmt" "net/http" "net/url" "os" "runtime" + "strings" "time" "github.com/arduino/arduino-cli/commands/cmderrors" "github.com/arduino/arduino-cli/internal/i18n" "github.com/arduino/arduino-cli/internal/version" "go.bug.st/downloader/v2" + "google.golang.org/grpc/metadata" ) // UserAgent returns the user agent (mainly used by HTTP clients) @@ -84,17 +87,23 @@ func (settings *Settings) NetworkProxy() (*url.URL, error) { } // NewHttpClient returns a new http client for use in the arduino-cli -func (settings *Settings) NewHttpClient() (*http.Client, error) { +func (settings *Settings) NewHttpClient(ctx context.Context) (*http.Client, error) { proxy, err := settings.NetworkProxy() if err != nil { return nil, err } + userAgent := settings.UserAgent() + if md, ok := metadata.FromIncomingContext(ctx); ok { + if extraUserAgent := strings.Join(md.Get("user-agent"), " "); extraUserAgent != "" { + userAgent += " " + extraUserAgent + } + } return &http.Client{ Transport: &httpClientRoundTripper{ transport: &http.Transport{ Proxy: http.ProxyURL(proxy), }, - userAgent: settings.UserAgent(), + userAgent: userAgent, }, Timeout: settings.ConnectionTimeout(), }, nil @@ -111,8 +120,8 @@ func (h *httpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, e } // DownloaderConfig returns the downloader configuration based on current settings. -func (settings *Settings) DownloaderConfig() (downloader.Config, error) { - httpClient, err := settings.NewHttpClient() +func (settings *Settings) DownloaderConfig(ctx context.Context) (downloader.Config, error) { + httpClient, err := settings.NewHttpClient(ctx) if err != nil { return downloader.Config{}, &cmderrors.InvalidArgumentError{ Message: i18n.Tr("Could not connect via HTTP"), diff --git a/internal/cli/configuration/network_test.go b/internal/cli/configuration/network_test.go index 68cc84fdd59..563c0414589 100644 --- a/internal/cli/configuration/network_test.go +++ b/internal/cli/configuration/network_test.go @@ -16,6 +16,7 @@ package configuration_test import ( + "context" "fmt" "io" "net/http" @@ -35,7 +36,7 @@ func TestUserAgentHeader(t *testing.T) { settings := configuration.NewSettings() require.NoError(t, settings.Set("network.user_agent_ext", "test-user-agent")) - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", ts.URL, nil) @@ -59,7 +60,7 @@ func TestProxy(t *testing.T) { settings := configuration.NewSettings() settings.Set("network.proxy", ts.URL) - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", "http://arduino.cc", nil) @@ -83,7 +84,7 @@ func TestConnectionTimeout(t *testing.T) { if timeout != 0 { require.NoError(t, settings.Set("network.connection_timeout", "2s")) } - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", "http://arduino.cc", nil) diff --git a/internal/integrationtest/arduino-cli.go b/internal/integrationtest/arduino-cli.go index c157a57d7a2..5236449d09a 100644 --- a/internal/integrationtest/arduino-cli.go +++ b/internal/integrationtest/arduino-cli.go @@ -450,7 +450,11 @@ func (cli *ArduinoCLI) StartDaemon(verbose bool) string { for retries := 5; retries > 0; retries-- { time.Sleep(time.Second) - conn, err := grpc.NewClient(cli.daemonAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.NewClient( + cli.daemonAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUserAgent("cli-test/0.0.0"), + ) if err != nil { connErr = err continue diff --git a/internal/integrationtest/daemon/daemon_test.go b/internal/integrationtest/daemon/daemon_test.go index 47225784361..c90c5c7ac61 100644 --- a/internal/integrationtest/daemon/daemon_test.go +++ b/internal/integrationtest/daemon/daemon_test.go @@ -20,6 +20,10 @@ import ( "errors" "fmt" "io" + "maps" + "net/http" + "net/http/httptest" + "strings" "testing" "time" @@ -555,6 +559,61 @@ func TestDaemonCoreUpgradePlatform(t *testing.T) { }) } +func TestDaemonUserAgent(t *testing.T) { + env, cli := integrationtest.CreateEnvForDaemon(t) + defer env.CleanUp() + + // Set up an http server to serve our custom index file + // The user-agent is tested inside the HTTPServeFile function + test_index := paths.New("..", "testdata", "test_index.json") + url := env.HTTPServeFile(8000, test_index) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Test that the user-agent contains metadata from the context when the CLI is in daemon mode + userAgent := r.Header.Get("User-Agent") + + require.Contains(t, userAgent, "cli-test/0.0.0") + require.Contains(t, userAgent, "grpc-go") + // Depends on how we built the client we may have git-snapshot or 0.0.0-git in dev releases + require.Condition(t, func() (success bool) { + return strings.Contains(userAgent, "arduino-cli/git-snapshot") || + strings.Contains(userAgent, "arduino-cli/0.0.0-git") + }) + + proxiedReq, err := http.NewRequest(r.Method, url.String(), r.Body) + require.NoError(t, err) + maps.Copy(proxiedReq.Header, r.Header) + + proxiedResp, err := http.DefaultTransport.RoundTrip(proxiedReq) + require.NoError(t, err) + defer proxiedResp.Body.Close() + + // Copy the headers from the proxy response to the original response + maps.Copy(r.Header, proxiedReq.Header) + w.WriteHeader(proxiedResp.StatusCode) + io.Copy(w, proxiedResp.Body) + })) + defer ts.Close() + + grpcInst := cli.Create() + require.NoError(t, grpcInst.Init("", "", func(ir *commands.InitResponse) { + fmt.Printf("INIT> %v\n", ir.GetMessage()) + })) + + // Set extra indexes + additionalURL := ts.URL + "/test_index.json" + err := cli.SetValue("board_manager.additional_urls", fmt.Sprintf(`["%s"]`, additionalURL)) + require.NoError(t, err) + + { + cl, err := grpcInst.UpdateIndex(context.Background(), false) + require.NoError(t, err) + res, err := analyzeUpdateIndexClient(t, cl) + require.NoError(t, err) + require.Len(t, res, 2) + require.True(t, res[additionalURL].GetSuccess()) + } +} + func analyzeUpdateIndexClient(t *testing.T, cl commands.ArduinoCoreService_UpdateIndexClient) (map[string]*commands.DownloadProgressEnd, error) { analyzer := NewDownloadProgressAnalyzer(t) for { diff --git a/internal/integrationtest/http_server.go b/internal/integrationtest/http_server.go index c5f06e9557a..dc6fd98e099 100644 --- a/internal/integrationtest/http_server.go +++ b/internal/integrationtest/http_server.go @@ -27,6 +27,7 @@ import ( // HTTPServeFile spawn an http server that serve a single file. The server // is started on the given port. The URL to the file and a cleanup function are returned. func (env *Environment) HTTPServeFile(port uint16, path *paths.Path) *url.URL { + t := env.T() mux := http.NewServeMux() mux.HandleFunc("/"+path.Base(), func(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, path.String()) @@ -36,7 +37,6 @@ func (env *Environment) HTTPServeFile(port uint16, path *paths.Path) *url.URL { Handler: mux, } - t := env.T() fileURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d/%s", port, path.Base())) require.NoError(t, err)