Skip to content

Commit

Permalink
support add cache for credscan index
Browse files Browse the repository at this point in the history
  • Loading branch information
teowa committed Apr 15, 2024
1 parent 6330b96 commit e58a82f
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ armstrong.exe

# test output
coverage/test_coverage_report*.md

coverage/testdata/index.json
27 changes: 20 additions & 7 deletions commands/credential_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,25 @@ import (
)

type CredentialScanCommand struct {
workingDir string
swaggerRepoPath string
verbose bool
workingDir string
swaggerRepoPath string
swaggerIndexFile string
verbose bool
}

func (c *CredentialScanCommand) flags() *flag.FlagSet {
fs := defaultFlagSet("test")
fs.BoolVar(&c.verbose, "v", false, "whether show terraform logs")
fs.StringVar(&c.workingDir, "working-dir", "", "path to Terraform configuration files")
fs.StringVar(&c.workingDir, "working-dir", "", "path to directory containing Terraform configuration files")
fs.StringVar(&c.swaggerRepoPath, "swagger-repo", "", "path to the swagger repo specification directory")
fs.StringVar(&c.swaggerIndexFile, "swagger-index-file", "", "path to the swagger index file, omit this will use the online swagger index file or locally build index")
fs.Usage = func() { logrus.Error(c.Help()) }
return fs
}

func (c CredentialScanCommand) Help() string {
helpText := `
Usage: armstrong credscan [-v] [-working-dir <path to directory containing Terraform configuration files>] [-swagger-repo <path to the swagger repo specification directory>]
Usage: armstrong credscan [-v] [-working-dir <path to directory containing Terraform configuration files>] [-swagger-repo <path to the swagger repo specification directory>] [-swagger-index-file <path to the swagger index file>]
` + c.Synopsis() + "\n\n" + helpForFlags(c.flags())

return strings.TrimSpace(helpText)
Expand Down Expand Up @@ -91,6 +93,17 @@ func (c CredentialScanCommand) Execute() int {

c.swaggerRepoPath += "/"
}
if c.swaggerIndexFile != "" {
c.swaggerIndexFile, err = filepath.Abs(c.swaggerIndexFile)
if err != nil {
logrus.Errorf("swagger index file path %q is invalid: %+v", c.swaggerIndexFile, err)
return 1
}

if _, err := os.Stat(c.swaggerIndexFile); os.IsNotExist(err) {
logrus.Infof("swagger index file %q does not exist, will try to build or download index", c.swaggerIndexFile)
}
}

tfFiles, err := hcl.FindTfFiles(wd)
if err != nil {
Expand Down Expand Up @@ -210,7 +223,7 @@ func (c CredentialScanCommand) Execute() int {
var swaggerModel *coverage.SwaggerModel
if c.swaggerRepoPath != "" {
logrus.Infof("scan based on local swagger repo: %s", c.swaggerRepoPath)
swaggerModel, err = coverage.GetModelInfoFromLocalIndex(mockedResourceId, apiVersion, c.swaggerRepoPath)
swaggerModel, err = coverage.GetModelInfoFromLocalIndex(mockedResourceId, apiVersion, c.swaggerRepoPath, c.swaggerIndexFile)
if err != nil {
credScanErr := makeCredScanError(
azapiResource,
Expand All @@ -223,7 +236,7 @@ func (c CredentialScanCommand) Execute() int {
continue
}
} else {
swaggerModel, err = coverage.GetModelInfoFromIndex(mockedResourceId, apiVersion)
swaggerModel, err = coverage.GetModelInfoFromIndex(mockedResourceId, apiVersion, c.swaggerIndexFile)
if err != nil {
credScanErr := makeCredScanError(
azapiResource,
Expand Down
7 changes: 7 additions & 0 deletions coverage/coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"os"
"path"
"strings"
"testing"

"github.com/azure/armstrong/coverage"
Expand All @@ -21,6 +22,10 @@ type testCase struct {
resourceType string
}

func normarlizePath(path string) string {
return strings.ReplaceAll(path, string(os.PathSeparator), "/")
}

func TestCoverage_ResourceGroup(t *testing.T) {
tc := testCase{
name: "ResourceGroup",
Expand Down Expand Up @@ -1402,6 +1407,7 @@ func testCoverage(t *testing.T, tc testCase) (*coverage.Model, error) {
swaggerModel, err := coverage.GetModelInfoFromIndex(
tc.apiPath,
tc.apiVersion,
"",
)

t.Logf("swaggerModel: %+v", swaggerModel)
Expand Down Expand Up @@ -1477,6 +1483,7 @@ func testCredScan(t *testing.T, tc testCase) (*map[string]string, error) {
swaggerModel, err := coverage.GetModelInfoFromIndex(
tc.apiPath,
tc.apiVersion,
"",
)

t.Logf("swaggerModel: %+v", swaggerModel)
Expand Down
3 changes: 1 addition & 2 deletions coverage/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package coverage

import (
"fmt"
"net/url"
"path/filepath"
"strings"

Expand Down Expand Up @@ -327,7 +326,7 @@ func SchemaNamePathFromRef(swaggerPath string, ref openapiSpec.Ref) (schemaName
schemaPath = swaggerPath
} else {
swaggerPath, _ := filepath.Split(swaggerPath)
schemaPath, _ = url.JoinPath(swaggerPath, schemaPath)
schemaPath = swaggerPath + schemaPath
}

fragments := strings.Split(refUrl.Fragment, "/")
Expand Down
4 changes: 2 additions & 2 deletions coverage/expand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestExpandAll(t *testing.T) {
if strings.HasSuffix(azureRepoDir, "specification") {
azureRepoDir += "/"
}
if !strings.HasSuffix(azureRepoDir, "specification/") {
if !strings.HasSuffix(normarlizePath(azureRepoDir), "specification/") {
t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/\"")
}

Expand Down Expand Up @@ -207,7 +207,7 @@ func testExpandAll(t *testing.T, azureRepoDir, testResultPath string) result {
}

func getRefList(t *testing.T) []jsonreference.Ref {
index, err := coverage.GetIndex()
index, err := coverage.GetIndex("")
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions coverage/from_local_spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func GetModelInfoFromLocalSpecFile(resourceId, apiVersion, swaggerPath string) (
}

for pathKey, pathItem := range paths.Paths {
if !isPathKeyMatchWithResourceId(pathKey, resourceId) {
if !IsPathKeyMatchWithResourceId(pathKey, resourceId) {
continue
}

Expand Down Expand Up @@ -95,7 +95,7 @@ func GetModelInfoFromLocalSpecFile(resourceId, apiVersion, swaggerPath string) (
return nil, nil
}

func isPathKeyMatchWithResourceId(pathKey, resourceId string) bool {
func IsPathKeyMatchWithResourceId(pathKey, resourceId string) bool {
pathParts := strings.Split(strings.Trim(pathKey, "/"), "/")
resourceIdParts := strings.Split(strings.Trim(resourceId, "/"), "/")
i := len(pathParts) - 1
Expand Down
16 changes: 9 additions & 7 deletions coverage/from_local_spec_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package coverage
package coverage_test

import (
"os"
"path"
"testing"

"github.com/azure/armstrong/coverage"
)

func Test_isPathKeyMatchWithResourceId(t *testing.T) {
Expand Down Expand Up @@ -35,7 +37,7 @@ func Test_isPathKeyMatchWithResourceId(t *testing.T) {
}
for _, testcase := range testcases {
t.Logf("testcase: %+v", testcase)
actual := isPathKeyMatchWithResourceId(testcase.PathKey, testcase.ResourceId)
actual := coverage.IsPathKeyMatchWithResourceId(testcase.PathKey, testcase.ResourceId)
if actual != testcase.Expected {
t.Fatalf("expected %v, got %v", testcase.Expected, actual)
}
Expand All @@ -51,12 +53,12 @@ func Test_GetModelInfoFromLocalDir(t *testing.T) {
testcases := []struct {
ResourceId string
ApiVersion string
Expected *SwaggerModel
Expected *coverage.SwaggerModel
}{
{
ResourceId: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/test-resources/providers/Microsoft.Automation/automationAccounts/test-automation-account",
ApiVersion: "2022-08-08",
Expected: &SwaggerModel{
Expected: &coverage.SwaggerModel{
ApiPath: "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Automation/automationAccounts/{automationAccountName}",
ModelName: "AutomationAccountCreateOrUpdateParameters",
SwaggerPath: path.Join(swaggerPath, "account.json"),
Expand All @@ -65,7 +67,7 @@ func Test_GetModelInfoFromLocalDir(t *testing.T) {
{
ResourceId: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/test-resources/providers/Microsoft.Automation/automationAccounts/test-automation-account/certificates/test-certificate",
ApiVersion: "2022-08-08",
Expected: &SwaggerModel{
Expected: &coverage.SwaggerModel{
ApiPath: "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Automation/automationAccounts/{automationAccountName}/certificates/{certificateName}",
ModelName: "CertificateCreateOrUpdateParameters",
SwaggerPath: path.Join(swaggerPath, "certificate.json"),
Expand All @@ -75,7 +77,7 @@ func Test_GetModelInfoFromLocalDir(t *testing.T) {

for _, testcase := range testcases {
t.Logf("testcase: %+v", testcase.ResourceId)
actual, err := GetModelInfoFromLocalDir(testcase.ResourceId, testcase.ApiVersion, swaggerPath)
actual, err := coverage.GetModelInfoFromLocalDir(testcase.ResourceId, testcase.ApiVersion, swaggerPath)
if err != nil {
t.Fatalf("get model info from local dir error: %+v", err)
}
Expand All @@ -88,7 +90,7 @@ func Test_GetModelInfoFromLocalDir(t *testing.T) {
if actual.ModelName != testcase.Expected.ModelName {
t.Fatalf("expected modelName %s, got %s", testcase.Expected.ModelName, actual.ModelName)
}
if actual.SwaggerPath != testcase.Expected.SwaggerPath {
if normarlizePath(actual.SwaggerPath) != normarlizePath(testcase.Expected.SwaggerPath) {
t.Fatalf("expected swaggerPath %s, got %s", testcase.Expected.SwaggerPath, actual.SwaggerPath)
}
}
Expand Down
95 changes: 85 additions & 10 deletions coverage/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,28 @@ const (

var indexCache *azidx.Index

func GetIndexFromLocalDir(swaggerRepo string) (*azidx.Index, error) {
func GetIndexFromLocalDir(swaggerRepo, indexFilePath string) (*azidx.Index, error) {
if indexCache != nil {
return indexCache, nil
}

logrus.Infof("building index from from local swagger %s", swaggerRepo)
if indexFilePath != "" {
if _, err := os.Stat(indexFilePath); err == nil {
byteValue, _ := os.ReadFile(indexFilePath)

var index azidx.Index
if err := json.Unmarshal(byteValue, &index); err != nil {
return nil, fmt.Errorf("unmarshal index file: %+v", err)
}
indexCache = &index

logrus.Infof("load index from cache file %s", indexFilePath)

return indexCache, nil
}
}

logrus.Infof("building index from from local swagger %s, it might take several minutes", swaggerRepo)
index, err := azidx.BuildIndex(swaggerRepo, "")
if err != nil {
logrus.Error(fmt.Sprintf("failed to build index: %+v", err))
Expand All @@ -39,19 +55,53 @@ func GetIndexFromLocalDir(swaggerRepo string) (*azidx.Index, error) {

indexCache = index

if indexFilePath != "" {
jsonBytes, err := json.Marshal(&index)
if err != nil {
logrus.Warningf("failed to marshal index: %+v", err)
return index, nil
}

err = os.WriteFile(indexFilePath, jsonBytes, 0644)
if err != nil {
logrus.Warningf("failed to write index cache file %s: %+v", indexFilePath, err)
return index, nil
}

logrus.Infof("index successfully saved to cache file %s", indexFilePath)
}

return index, nil
}

func GetIndex() (*azidx.Index, error) {
func GetIndex(indexFilePath string) (*azidx.Index, error) {
if indexCache != nil {
return indexCache, nil
}

if indexFilePath != "" {
if _, err := os.Stat(indexFilePath); err == nil {
byteValue, _ := os.ReadFile(indexFilePath)

var index azidx.Index
if err := json.Unmarshal(byteValue, &index); err != nil {
return nil, fmt.Errorf("unmarshal index file: %+v", err)
}
indexCache = &index

logrus.Infof("load index from cache file %s", indexFilePath)

return indexCache, nil
}
}

resp, err := http.Get(indexFileURL)
if err != nil {
return nil, fmt.Errorf("get index file from %v: %+v", indexFileURL, err)
}

logrus.Infof("downloading index file from %s", indexFileURL)

defer resp.Body.Close()

b, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -86,6 +136,23 @@ func GetIndex() (*azidx.Index, error) {
indexCache = &index

logrus.Infof("load index based commit: https://github.com/Azure/azure-rest-api-specs/tree/%s", index.Commit)

if indexFilePath != "" {
jsonBytes, err := json.Marshal(&index)
if err != nil {
logrus.Warningf("failed to marshal index: %+v", err)
return indexCache, nil
}

err = os.WriteFile(indexFilePath, jsonBytes, 0644)
if err != nil {
logrus.Warningf("failed to write index cache file %s: %+v", indexFilePath, err)
return indexCache, nil
}

logrus.Infof("index successfully saved to cache file %s", indexFilePath)
}

return indexCache, nil
}

Expand All @@ -95,8 +162,10 @@ type SwaggerModel struct {
SwaggerPath string
}

func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) {
index, err := GetIndex()
// GetModelInfoFromIndex will try to download online index from https://github.com/teowa/azure-rest-api-index-file, and get model info from it
// if the index is already downloaded as in {indexFilePath}, it will use the cached index
func GetModelInfoFromIndex(resourceId, apiVersion, indexFilePath string) (*SwaggerModel, error) {
index, err := GetIndex(indexFilePath)
if err != nil {
return nil, err
}
Expand All @@ -123,7 +192,7 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error)
}

// GetModelInfoFromLocalIndex tries to build index from local swagger repo and get model info from it
func GetModelInfoFromLocalIndex(resourceId, apiVersion, swaggerRepo string) (*SwaggerModel, error) {
func GetModelInfoFromLocalIndex(resourceId, apiVersion, swaggerRepo, indexCacheFile string) (*SwaggerModel, error) {
swaggerRepo, err := filepath.Abs(swaggerRepo)
if err != nil {
return nil, fmt.Errorf("swagger repo path %q is invalid: %+v", swaggerRepo, err)
Expand All @@ -141,7 +210,7 @@ func GetModelInfoFromLocalIndex(resourceId, apiVersion, swaggerRepo string) (*Sw

swaggerRepo += "/"

index, err := GetIndexFromLocalDir(swaggerRepo)
index, err := GetIndexFromLocalDir(swaggerRepo, indexCacheFile)
if err != nil {
return nil, fmt.Errorf("build index from local dir %s: %+v", swaggerRepo, err)
}
Expand Down Expand Up @@ -170,7 +239,13 @@ func GetModelInfoFromLocalIndex(resourceId, apiVersion, swaggerRepo string) (*Sw
func GetModelInfoFromIndexRef(ref openapispec.Ref, swaggerRepo string) (*SwaggerModel, error) {
_, swaggerPath := SchemaNamePathFromRef(swaggerRepo, ref)

relativeBase := swaggerRepo + strings.Split(ref.GetURL().Path, "/")[0]
seperator := "/"
// in windows the ref might use backslashes
if strings.Contains(ref.GetURL().Path, string(os.PathSeparator)) {
seperator = string(os.PathSeparator)
}

relativeBase := swaggerRepo + strings.Split(ref.GetURL().Path, seperator)[0]
operation, err := openapispec.ResolvePathItemWithBase(nil, ref, &openapispec.ExpandOptions{RelativeBase: relativeBase})
if err != nil {
return nil, err
Expand Down Expand Up @@ -233,7 +308,7 @@ func MockResourceIDFromType(azapiResourceType string) (string, string) {
return fmt.Sprintf("%s/%s/providers/%s/%s", subscritionSeg, resourceGroupSeg, resourceProvider, typeIds), apiVersion
}

func GetModelInfoFromIndexWithType(azapiResourceType string) (*SwaggerModel, error) {
func GetModelInfoFromIndexWithType(azapiResourceType, indexCacheFile string) (*SwaggerModel, error) {
resourceId, apiVersion := MockResourceIDFromType(azapiResourceType)
return GetModelInfoFromIndex(resourceId, apiVersion)
return GetModelInfoFromIndex(resourceId, apiVersion, indexCacheFile)
}
Loading

0 comments on commit e58a82f

Please sign in to comment.