Skip to content

Commit

Permalink
Refactor Scan Pull Request to accept PR ID as input (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
EyalDelarea authored Jul 26, 2023
1 parent 16d358e commit 25d5521
Show file tree
Hide file tree
Showing 199 changed files with 114 additions and 536 deletions.
92 changes: 38 additions & 54 deletions commands/scanpullrequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/go-git/go-git/v5"
"github.com/jfrog/gofrog/datastructures"
"os"
"os/exec"
Expand All @@ -27,58 +26,43 @@ const (
noGitHubEnvReviewersErr = "frogbot did not scan this PR, because the existing GitHub Environment named 'frogbot' doesn't have reviewers selected. Please refer to the Frogbot documentation for instructions on how to create the Environment"
)

type ScanPullRequestCmd struct{}
type ScanPullRequestCmd struct {
// Optional provided pull request details, used in scan-pull-requests command.
pullRequestDetails vcsclient.PullRequestInfo
}

// Run ScanPullRequest method only works for a single repository scan.
// Therefore, the first repository config represents the repository on which Frogbot runs, and it is the only one that matters.
func (cmd *ScanPullRequestCmd) Run(configAggregator utils.RepoAggregator, client vcsclient.VcsClient) error {
if err := utils.ValidateSingleRepoConfiguration(&configAggregator); err != nil {
return err
func (cmd *ScanPullRequestCmd) Run(configAggregator utils.RepoAggregator, client vcsclient.VcsClient) (err error) {
if err = utils.ValidateSingleRepoConfiguration(&configAggregator); err != nil {
return
}
repoConfig := &(configAggregator)[0]
if repoConfig.GitProvider == vcsutils.GitHub {
if err := verifyGitHubFrogbotEnvironment(client, repoConfig); err != nil {
return err
if err = verifyGitHubFrogbotEnvironment(client, repoConfig); err != nil {
return
}
}
if err := cmd.verifyDifferentBranches(repoConfig); err != nil {
return err
}
return scanPullRequest(repoConfig, client)
}

// Verifies current branch and target branch are not the same.
// The Current branch is the branch the action is triggered on.
// The Target branch is the branch to open pull request to.
func (cmd *ScanPullRequestCmd) verifyDifferentBranches(repoConfig *utils.Repository) error {
repo, err := git.PlainOpen(".")
if err != nil {
return err
}
ref, err := repo.Head()
if err != nil {
return err
}
currentBranch := ref.Name().Short()
defaultBranch := repoConfig.Branches[0]
if currentBranch == defaultBranch {
return fmt.Errorf(utils.ErrScanPullRequestSameBranches, currentBranch)
// PullRequestDetails can be defined already when using the scan-all-pull-requests command.
if cmd.pullRequestDetails.ID == utils.UndefinedPrID {
if cmd.pullRequestDetails, err = client.GetPullRequestByID(context.Background(), repoConfig.RepoOwner, repoConfig.RepoName, repoConfig.PullRequestID); err != nil {
return
}
}
return nil

return scanPullRequest(repoConfig, client, cmd.pullRequestDetails)
}

// By default, includeAllVulnerabilities is set to false and the scan goes as follows:
// a. Audit the dependencies of the source and the target branches.
// b. Compare the vulnerabilities found in source and target branches, and show only the new vulnerabilities added by the pull request.
// Otherwise, only the source branch is scanned and all found vulnerabilities are being displayed.
func scanPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient) error {
// Validate scan params
if len(repoConfig.Branches) == 0 {
return &utils.ErrMissingEnv{VariableName: utils.GitBaseBranchEnv}
}

func scanPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient, pullRequestDetails vcsclient.PullRequestInfo) error {
log.Info("Scanning Pull Request ID:", pullRequestDetails.ID, "Source:", pullRequestDetails.Source.Name, "Target:", pullRequestDetails.Target.Name)
log.Info("-----------------------------------------------------------")
// Audit PR code
vulnerabilitiesRows, iacRows, err := auditPullRequest(repoConfig, client)
vulnerabilitiesRows, iacRows, err := auditPullRequest(repoConfig, client, pullRequestDetails)
if err != nil {
return err
}
Expand All @@ -98,17 +82,23 @@ func scanPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient) e
return err
}

func auditPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient) ([]formats.VulnerabilityOrViolationRow, []formats.IacSecretsRow, error) {
// Downloads Pull Requests branches code and audits them
func auditPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient, pullRequestDetails vcsclient.PullRequestInfo) ([]formats.VulnerabilityOrViolationRow, []formats.IacSecretsRow, error) {
var vulnerabilitiesRows []formats.VulnerabilityOrViolationRow
var iacRows []formats.IacSecretsRow
targetBranch := repoConfig.Branches[0]
targetBranch := pullRequestDetails.Target.Name
sourceBranch := pullRequestDetails.Source.Name
for i := range repoConfig.Projects {
// Source scan details
scanDetails := utils.NewScanDetails(client, &repoConfig.Server, &repoConfig.Git).
SetProject(&repoConfig.Projects[i]).
SetXrayGraphScanParams(repoConfig.Watches, repoConfig.JFrogProjectKey).
SetMinSeverity(repoConfig.MinSeverity).
SetFixableOnly(repoConfig.FixableOnly)
sourceResults, err := auditSource(scanDetails)
SetFixableOnly(repoConfig.FixableOnly).
SetBranch(sourceBranch).
SetRepoOwner(pullRequestDetails.Source.Owner)

sourceResults, err := downloadAndAuditBranch(scanDetails)
if err != nil {
return nil, nil, err
}
Expand All @@ -123,9 +113,12 @@ func auditPullRequest(repoConfig *utils.Repository, client vcsclient.VcsClient)
iacRows = append(iacRows, xrayutils.PrepareIacs(sourceResults.ExtendedScanResults.IacScanResults)...)
continue
}
// Audit target code
scanDetails.SetFailOnInstallationErrors(*repoConfig.FailOnSecurityIssues).SetBranch(targetBranch)
targetResults, err := auditTarget(scanDetails)
// Target scan details
scanDetails.SetFailOnInstallationErrors(*repoConfig.FailOnSecurityIssues).
SetBranch(targetBranch).
SetRepoOwner(pullRequestDetails.Target.Owner)

targetResults, err := downloadAndAuditBranch(scanDetails)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -235,15 +228,6 @@ func getScanVulnerabilitiesRows(auditResults *audit.Results) ([]formats.Vulnerab
return []formats.VulnerabilityOrViolationRow{}, nil
}

func auditSource(scanSetup *utils.ScanDetails) (auditResults *audit.Results, err error) {
wd, err := os.Getwd()
if err != nil {
return
}
fullPathWds := getFullPathWorkingDirs(scanSetup.WorkingDirs, wd)
return runInstallAndAudit(scanSetup, fullPathWds...)
}

func getFullPathWorkingDirs(workingDirs []string, baseWd string) []string {
var fullPathWds []string
if len(workingDirs) != 0 {
Expand All @@ -260,9 +244,9 @@ func getFullPathWorkingDirs(workingDirs []string, baseWd string) []string {
return fullPathWds
}

func auditTarget(scanSetup *utils.ScanDetails) (auditResults *audit.Results, err error) {
func downloadAndAuditBranch(scanSetup *utils.ScanDetails) (auditResults *audit.Results, err error) {
// First download the target repo to temp dir
log.Info("Auditing the", scanSetup.Git.RepoName, "repository on the", scanSetup.Branch(), "branch")
log.Info("Auditing repository:", scanSetup.Git.RepoName, "branch:", scanSetup.Branch())
wd, cleanup, err := utils.DownloadRepoToTempDir(scanSetup.Client(), scanSetup.Branch(), scanSetup.Git)
if err != nil {
return
Expand Down
65 changes: 27 additions & 38 deletions commands/scanpullrequest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ const (
testCleanProjConfigPath = "testdata/config/frogbot-config-clean-test-proj.yml"
testProjConfigPath = "testdata/config/frogbot-config-test-proj.yml"
testProjConfigPathNoFail = "testdata/config/frogbot-config-test-proj-no-fail.yml"
testSameBranchProjConfigPath = "testdata/config/frogbot-config-test-same-branch-fail.yml"
testSourceBranchName = "pr"
testTargetBranchName = "master"
)

func TestCreateVulnerabilitiesRows(t *testing.T) {
Expand Down Expand Up @@ -504,32 +505,6 @@ func TestScanPullRequest(t *testing.T) {
testScanPullRequest(t, testProjConfigPath, "test-proj", true)
}

func TestScanPullRequestSameBranchFail(t *testing.T) {
params, restoreEnv := verifyEnv(t)
defer restoreEnv()

// Create mock GitLab server
projectName := "test-same-branch-fail"

server := httptest.NewServer(createGitLabHandler(t, projectName))
defer server.Close()

configAggregator, client := prepareConfigAndClient(t, testSameBranchProjConfigPath, server, params)
_, cleanUp := utils.PrepareTestEnvironment(t, projectName, "scanpullrequest")
defer cleanUp()

// Run "frogbot scan pull request"
var scanPullRequest ScanPullRequestCmd
err := scanPullRequest.Run(configAggregator, client)
exceptedError := fmt.Errorf(utils.ErrScanPullRequestSameBranches, "main")
assert.Equal(t, exceptedError, err)

// Check env sanitize
err = utils.SanitizeEnv()
assert.NoError(t, err)
utils.AssertSanitizedEnv(t)
}

func TestScanPullRequestNoFail(t *testing.T) {
testScanPullRequest(t, testProjConfigPathNoFail, "test-proj", false)
}
Expand Down Expand Up @@ -663,30 +638,43 @@ func TestScanPullRequestError(t *testing.T) {
// Create HTTP handler to mock GitLab server
func createGitLabHandler(t *testing.T, projectName string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.RequestURI {
// Return 200 on ping
if r.RequestURI == "/api/v4/" {
case "/api/v4/":
w.WriteHeader(http.StatusOK)
return
}

// Return test-proj.tar.gz when using DownloadRepository
if r.RequestURI == fmt.Sprintf("/api/v4/projects/jfrog%s/repository/archive.tar.gz?sha=master", "%2F"+projectName) {
// Mimic get pull request by ID
case fmt.Sprintf("/api/v4/projects/jfrog%s/merge_requests/1", "%2F"+projectName):
w.WriteHeader(http.StatusOK)
repoFile, err := os.ReadFile(filepath.Join("..", projectName+".tar.gz"))
expectedResponse, err := os.ReadFile(filepath.Join("..", "expectedPullRequestDetailsResponse.json"))
assert.NoError(t, err)
_, err = w.Write(expectedResponse)
assert.NoError(t, err)
return
// Mimic download specific branch to scan
case fmt.Sprintf("/api/v4/projects/jfrog%s/repository/archive.tar.gz?sha=%s", "%2F"+projectName, testSourceBranchName):
w.WriteHeader(http.StatusOK)
repoFile, err := os.ReadFile(filepath.Join("..", projectName, "sourceBranch.gz"))
assert.NoError(t, err)
_, err = w.Write(repoFile)
assert.NoError(t, err)
}
return
// Download repository mock
case fmt.Sprintf("/api/v4/projects/jfrog%s/repository/archive.tar.gz?sha=%s", "%2F"+projectName, testTargetBranchName):
w.WriteHeader(http.StatusOK)
repoFile, err := os.ReadFile(filepath.Join("..", projectName, "targetBranch.gz"))
assert.NoError(t, err)
_, err = w.Write(repoFile)
assert.NoError(t, err)
return
// clean-test-proj should not include any vulnerabilities so assertion is not needed.
if r.RequestURI == fmt.Sprintf("/api/v4/projects/jfrog%s/merge_requests/1/notes", "%2Fclean-test-proj") {
case fmt.Sprintf("/api/v4/projects/jfrog%s/merge_requests/1/notes", "%2Fclean-test-proj"):
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("{}"))
assert.NoError(t, err)
return
}

// Return 200 when using the REST that creates the comment
if r.RequestURI == fmt.Sprintf("/api/v4/projects/jfrog%s/merge_requests/1/notes", "%2F"+projectName) {
case fmt.Sprintf("/api/v4/projects/jfrog%s/merge_requests/1/notes", "%2F"+projectName):
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(r.Body)
assert.NoError(t, err)
Expand All @@ -707,6 +695,7 @@ func createGitLabHandler(t *testing.T, projectName string) http.HandlerFunc {
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte("{}"))
assert.NoError(t, err)
return
}
}
}
Expand Down
78 changes: 9 additions & 69 deletions commands/scanpullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ type ScanAllPullRequestsCmd struct {

func (cmd ScanAllPullRequestsCmd) Run(configAggregator utils.RepoAggregator, client vcsclient.VcsClient) error {
for _, config := range configAggregator {
log.Info("Scanning all open pull requests for repository:", config.RepoName)
log.Info("-----------------------------------------------------------")
err := scanAllPullRequests(config, client)
if err != nil {
return err
}
}

return nil
}

Expand All @@ -43,14 +44,14 @@ func scanAllPullRequests(repo utils.Repository, client vcsclient.VcsClient) (err
if e != nil {
err = errors.Join(err, fmt.Errorf(errPullRequestScan, int(pr.ID), repo.RepoName, e.Error()))
}
if shouldScan {
e = downloadAndScanPullRequest(pr, repo, client)
// If error, write it in errList and continue to the next PR.
if e != nil {
err = errors.Join(err, fmt.Errorf(errPullRequestScan, int(pr.ID), repo.RepoName, e.Error()))
}
} else {
if !shouldScan {
log.Info("Pull Request", pr.ID, "has already been scanned before. If you wish to scan it again, please comment \"rescan\".")
return
}
spr := &ScanPullRequestCmd{pullRequestDetails: pr}
if e = spr.Run(utils.RepoAggregator{repo}, client); e != nil {
// If error, write it in errList and continue to the next PR.
err = errors.Join(err, fmt.Errorf(errPullRequestScan, int(pr.ID), repo.RepoName, e.Error()))
}
}
return
Expand Down Expand Up @@ -83,64 +84,3 @@ func shouldScanPullRequest(repo utils.Repository, client vcsclient.VcsClient, pr
func isFrogbotRescanComment(comment string) bool {
return strings.Contains(strings.ToLower(strings.TrimSpace(comment)), utils.RescanRequestComment)
}

func downloadAndScanPullRequest(pr vcsclient.PullRequestInfo, repo utils.Repository, client vcsclient.VcsClient) (err error) {
// Download the pull request source ("from") branch
params := utils.Params{
Git: utils.Git{
ClientInfo: utils.ClientInfo{
GitProvider: repo.GitProvider,
VcsInfo: vcsclient.VcsInfo{APIEndpoint: repo.APIEndpoint, Token: repo.Token},
RepoOwner: repo.RepoOwner,
RepoName: pr.Source.Repository,
Branches: []string{pr.Source.Name}},
}}
frogbotParams := &utils.Repository{
Server: repo.Server,
Params: params,
}
wd, cleanup, err := utils.DownloadRepoToTempDir(client, pr.Source.Name, &frogbotParams.Git)
if err != nil {
return err
}
// Cleanup
defer func() {
err = errors.Join(err, cleanup())
}()
restoreDir, err := utils.Chdir(wd)
if err != nil {
return err
}
defer func() {
err = errors.Join(err, restoreDir())
}()
// The target branch (to) will be downloaded as part of the Frogbot scanPullRequest execution
params = utils.Params{
Scan: utils.Scan{
FailOnSecurityIssues: repo.FailOnSecurityIssues,
IncludeAllVulnerabilities: repo.IncludeAllVulnerabilities,
Projects: repo.Projects,
},
Git: utils.Git{
ClientInfo: utils.ClientInfo{
GitProvider: repo.GitProvider,
VcsInfo: vcsclient.VcsInfo{APIEndpoint: repo.APIEndpoint, Token: repo.Token},
RepoOwner: repo.RepoOwner,
Branches: []string{pr.Target.Name},
RepoName: pr.Target.Repository,
},
PullRequestID: int(pr.ID),
},
JFrogPlatform: utils.JFrogPlatform{
Watches: repo.Watches,
JFrogProjectKey: repo.JFrogProjectKey,
},
}

frogbotParams = &utils.Repository{
OutputWriter: utils.GetCompatibleOutputWriter(repo.GitProvider),
Server: repo.Server,
Params: params,
}
return scanPullRequest(frogbotParams, client)
}
Loading

0 comments on commit 25d5521

Please sign in to comment.