From 3d2573f8d7266056599de7197e07a76e7655bcbc Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Fri, 14 Jul 2023 18:47:15 +0800 Subject: [PATCH 01/10] [wip] fix discriminator resolve --- coverage/coverage.go | 42 +++++--- coverage/coverage_test.go | 207 ++++++++++++++++++++++++++++++++++---- coverage/expand.go | 89 +++++++++------- coverage/expand_test.go | 29 ++++-- coverage/index.go | 29 +++++- coverage/index_test.go | 28 +++++- report/pass_report.go | 22 +++- 7 files changed, 358 insertions(+), 88 deletions(-) diff --git a/coverage/coverage.go b/coverage/coverage.go index ef1bbcfe..953e930d 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -66,33 +66,42 @@ func (m *Model) MarkCovered(root interface{}) { } case map[string]interface{}: - if m.Discriminator != nil { + isMatchProperty := true + if m.Discriminator != nil && m.Variants != nil { for k, v := range value { if k == *m.Discriminator { - if m.Variants == nil { - log.Printf("[ERROR] unexpected discriminator %s in %s\n", k, m.Identifier) - } - if _, ok := (*m.Variants)[v.(string)]; !ok { + variant, ok := (*m.Variants)[v.(string)] + if !ok { log.Printf("[ERROR] unexpected variant %s in %s\n", v.(string), m.Identifier) } + if variant == nil { + break + } + + isMatchProperty = false (*m.Variants)[v.(string)].MarkCovered(value) + break } } } - for k, v := range value { - if m.Properties == nil { - if !m.HasAdditionalProperties && m.Discriminator == nil { - log.Printf("[WARN] unexpected key %s in %s\n", k, m.Identifier) + + if isMatchProperty { + for k, v := range value { + if m.Properties == nil { + if !m.HasAdditionalProperties { + log.Printf("[WARN] unexpected key %s in %s\n", k, m.Identifier) + } + return } - return - } - if _, ok := (*m.Properties)[k]; !ok { - if !m.HasAdditionalProperties && m.Discriminator == nil { - log.Printf("[WARN] unexpected key %s in %s\n", k, m.Identifier) + if _, ok := (*m.Properties)[k]; !ok { + if !m.HasAdditionalProperties { + log.Printf("[WARN] unexpected key %s in %s\n", k, m.Identifier) + return + } } + (*m.Properties)[k].MarkCovered(v) } - (*m.Properties)[k].MarkCovered(v) } case nil: @@ -137,6 +146,9 @@ func (m *Model) CountCoverage() (int, int) { if m.Properties != nil { for _, v := range *m.Properties { + if v.IsReadOnly { + continue + } covered, total := v.CountCoverage() m.CoveredCount += covered m.TotalCount += total diff --git a/coverage/coverage_test.go b/coverage/coverage_test.go index 0aeafde0..f810c82b 100644 --- a/coverage/coverage_test.go +++ b/coverage/coverage_test.go @@ -21,14 +21,48 @@ type testCase struct { resourceType string } -func TestCoverageResourceGroup(t *testing.T) { +func TestCoverage_DataMigrationTasks(t *testing.T) { + tc := testCase{ + name: "DataMigrationTasks", + resourceType: "Microsoft.DataMigration/services/projects/tasks@2021-06-30", + apiVersion: "2021-06-30", + apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/DmsSdkRg/providers/Microsoft.DataMigration/services/DmsSdkService/projects/DmsSdkProject/tasks/DmsSdkTask", + rawRequest: []string{`{ + "taskType": "ConnectToTarget.SqlDb", + "input": { + "targetConnectionInfo": { + "type": "SqlConnectionInfo", + "dataSource": "ssma-test-server.database.windows.net", + "authentication": "SqlAuthentication", + "encryptConnection": true, + "trustServerCertificate": true, + "userName": "testuser", + "password": "testpassword" + } + } +}`, + }, + } + + model, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } + + if model.CoveredCount != 1 { + t.Fatalf("expected CoveredCount 1, got %d", model.CoveredCount) + } +} + +func TestCoverage_ResourceGroup(t *testing.T) { tc := testCase{ name: "ResourceGroup", resourceType: "Microsoft.Resources/resourceGroups@2022-09-01", apiVersion: "2022-09-01", apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/rgName", - rawRequest: []string{ - `{"location": "westeurope"}`, + rawRequest: []string{`{ + "location": "westeurope" +}`, }, } @@ -58,7 +92,40 @@ func TestCoverageResourceGroup(t *testing.T) { } } -func TestCoverageKeyVault(t *testing.T) { +func TestCoverage_DeviceSecurityGroup(t *testing.T) { + tc := testCase{ + name: "DeviceSecurityGroup", + resourceType: "Microsoft.Security/deviceSecurityGroups@2019-08-01", + apiVersion: "2019-08-01", + apiPath: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/SampleRG/providers/Microsoft.Devices/iotHubs/sampleiothub/providers/Microsoft.Security/deviceSecurityGroups/samplesecuritygroup", + rawRequest: []string{`{ + "properties": { + "timeWindowRules": [ + { + "ruleType": "ActiveConnectionsNotInAllowedRange", + "isEnabled": true, + "minThreshold": 0, + "maxThreshold": 30, + "timeWindowSize": "PT05M" + } + ] + } +} +`, + }, + } + + model, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } + + if model.CoveredCount != 5 { + t.Fatalf("expected CoveredCount 5, got %d", model.CoveredCount) + } +} + +func TestCoverage_KeyVault(t *testing.T) { tc := testCase{ name: "KeyVault", resourceType: "Microsoft.KeyVault/vaults@2023-02-01", @@ -144,7 +211,7 @@ func TestCoverageKeyVault(t *testing.T) { } -func TestCoverageStorageAccount(t *testing.T) { +func TestCoverage_StorageAccount(t *testing.T) { tc := testCase{ name: "StorageAccount", resourceType: "Microsoft.Storage/storageAccounts@2022-09-01", @@ -211,7 +278,7 @@ func TestCoverageStorageAccount(t *testing.T) { } } -func TestCoverageVM(t *testing.T) { +func TestCoverage_VM(t *testing.T) { tc := testCase{ name: "VM", resourceType: "Microsoft.Compute/virtualMachines@2023-03-01", @@ -281,7 +348,7 @@ func TestCoverageVM(t *testing.T) { } -func TestCoverageVNet(t *testing.T) { +func TestCoverage_VNet(t *testing.T) { tc := testCase{ name: "VNet", resourceType: "Microsoft.Network/virtualNetworks@2023-02-01", @@ -318,7 +385,7 @@ func TestCoverageVNet(t *testing.T) { } } -func TestCoverageDataCollectionRule(t *testing.T) { +func TestCoverage_DataCollectionRule(t *testing.T) { tc := testCase{ name: "DataCollectionRule", resourceType: "Microsoft.Insights/dataCollectionRules@2022-06-01", @@ -644,7 +711,7 @@ func TestCoverageDataCollectionRule(t *testing.T) { } -func TestCoverageWebSite(t *testing.T) { +func TestCoverage_WebSite(t *testing.T) { tc := testCase{ name: "WebSites", resourceType: "Microsoft.Web/sites@2022-09-01", @@ -671,7 +738,7 @@ func TestCoverageWebSite(t *testing.T) { } -func TestCoverageAKS(t *testing.T) { +func TestCoverage_AKS(t *testing.T) { tc := testCase{ name: "AKS", resourceType: "Microsoft.ContainerService/ManagedClusters@2023-05-02-preview", @@ -761,7 +828,7 @@ func TestCoverageAKS(t *testing.T) { } -func TestCoverageCosmosDB(t *testing.T) { +func TestCoverage_CosmosDB(t *testing.T) { tc := testCase{ name: "CosmosDB", resourceType: "Microsoft.DocumentDB/databaseAccounts@2023-04-15", @@ -854,8 +921,8 @@ func TestCoverageCosmosDB(t *testing.T) { t.Fatalf("process coverage: %v", err) } - if model.CoveredCount != 34 { - t.Fatalf("expected CoveredCount 34, got %d", model.CoveredCount) + if model.CoveredCount != 33 { + t.Fatalf("expected CoveredCount 33, got %d", model.CoveredCount) } if model.Properties == nil { @@ -988,7 +1055,102 @@ func TestCoverageCosmosDB(t *testing.T) { } -func TestCoverageDataFactoryLinkedServices(t *testing.T) { +func TestCoverage_DataFactoryPipelines(t *testing.T) { + tc := testCase{ + name: "DataFactoryPipelines", + apiVersion: "2018-06-01", + resourceType: "Microsoft.DataFactory/factories/pipelines@2018-06-01", + apiPath: "/subscriptions/12345678-1234-1234-1234-12345678abc/resourceGroups/exampleResourceGroup/providers/Microsoft.DataFactory/factories/exampleFactoryName/pipelines/examplePipeline", + rawRequest: []string{`{ + "properties": { + "activities": [ + { + "type": "ForEach", + "typeProperties": { + "isSequential": true, + "items": { + "value": "@pipeline().parameters.OutputBlobNameList", + "type": "Expression" + }, + "activities": [ + { + "type": "Copy", + "typeProperties": { + "source": { + "type": "BlobSource" + }, + "sink": { + "type": "BlobSink" + }, + "dataIntegrationUnits": 32 + }, + "inputs": [ + { + "referenceName": "exampleDataset", + "parameters": { + "MyFolderPath": "examplecontainer", + "MyFileName": "examplecontainer.csv" + }, + "type": "DatasetReference" + } + ], + "outputs": [ + { + "referenceName": "exampleDataset", + "parameters": { + "MyFolderPath": "examplecontainer", + "MyFileName": { + "value": "@item()", + "type": "Expression" + } + }, + "type": "DatasetReference" + } + ], + "name": "ExampleCopyActivity" + } + ] + }, + "name": "ExampleForeachActivity" + } + ], + "parameters": { + "OutputBlobNameList": { + "type": "Array" + }, + "JobId": { + "type": "String" + } + }, + "variables": { + "TestVariableArray": { + "type": "Array" + } + }, + "runDimensions": { + "JobId": { + "value": "@pipeline().parameters.JobId", + "type": "Expression" + } + }, + "policy": { + "elapsedTimeMetric": { + "duration": "0.00:10:00" + } + } + } +} +`, + }, + } + + _, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } +} + +func TestCoverage_DataFactoryLinkedServices(t *testing.T) { tc := testCase{ name: "DataFactoryLinkedServices", resourceType: "Microsoft.DataFactory/factories/linkedServices@2018-06-01", @@ -1013,8 +1175,8 @@ func TestCoverageDataFactoryLinkedServices(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 3 { - t.Fatalf("expected TotalCount 3, got %d", model.CoveredCount) + if model.CoveredCount != 2 { + t.Fatalf("expected TotalCount 2, got %d", model.CoveredCount) } if model.Properties == nil { @@ -1029,8 +1191,8 @@ func TestCoverageDataFactoryLinkedServices(t *testing.T) { t.Fatalf("expected properties type property, got none") } - if !(*(*model.Properties)["properties"].Properties)["type"].IsAnyCovered { - t.Fatalf("expected properties type IsAnyCovered true, got false") + if (*(*model.Properties)["properties"].Properties)["type"].IsAnyCovered { + t.Fatalf("expected properties type IsAnyCovered false, got true") } if (*model.Properties)["properties"].Discriminator == nil { @@ -1088,6 +1250,8 @@ func testCoverage(t *testing.T, tc testCase) (*coverage.Model, error) { tc.apiVersion, ) + t.Logf("swaggerModel: %+v", swaggerModel) + if err != nil { return nil, fmt.Errorf("get model info from index: %+v", err) } @@ -1107,14 +1271,13 @@ func testCoverage(t *testing.T, tc testCase) (*coverage.Model, error) { model.MarkCovered(request) } + model.CountCoverage() + out, err := json.MarshalIndent(model, "", "\t") if err != nil { t.Error(err) } - - t.Logf("expanded model %s", string(out)) - - model.CountCoverage() + t.Logf("coverage model %s", string(out)) coverageReport := coverage.CoverageReport{ Coverages: map[coverage.ArmResource]*coverage.Model{ diff --git a/coverage/expand.go b/coverage/expand.go index ff40f59f..99e0589e 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/golang-lru/v2" ) +// http://azure.github.io/autorest/extensions/#x-ms-discriminator-value const msExtensionDiscriminator = "x-ms-discriminator-value" var ( @@ -18,7 +19,7 @@ var ( swaggerCache, _ = lru.New[string, *loads.Document](20) // {swaggerPath: {parentModelName: {childModelName: nil}}} - variantTableCache, _ = lru.New[string, map[string]map[string]interface{}](10) + allOfTableCache, _ = lru.New[string, map[string]map[string]interface{}](10) ) func loadSwagger(swaggerPath string) (*loads.Document, error) { @@ -34,8 +35,8 @@ func loadSwagger(swaggerPath string) (*loads.Document, error) { return doc, nil } -func getVariantTable(swaggerPath string) (map[string]map[string]interface{}, error) { - if vt, ok := variantTableCache.Get(swaggerPath); ok { +func getAllOfTable(swaggerPath string) (map[string]map[string]interface{}, error) { + if vt, ok := allOfTableCache.Get(swaggerPath); ok { return vt, nil } @@ -45,29 +46,27 @@ func getVariantTable(swaggerPath string) (map[string]map[string]interface{}, err } spec := doc.Spec() - variantsTable := map[string]map[string]interface{}{} + allOfTable := map[string]map[string]interface{}{} for k, v := range spec.Definitions { - if v.Extensions[msExtensionDiscriminator] != nil && len(v.AllOf) > 0 { - for _, variant := range v.AllOf { - if variant.Ref.String() != "" { - resolved, err := openapiSpec.ResolveRefWithBase(spec, &variant.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath}) - if err != nil { - log.Fatalf("[ERROR] resolve %s: %v", variant.Ref.String(), err) + if len(v.AllOf) > 0 { + for _, allOf := range v.AllOf { + if allOf.Ref.String() != "" { + modelName, relativePath := SchemaNamePathFromRef(allOf.Ref) + if relativePath != "" { + continue } - if resolved.Extensions[msExtensionDiscriminator] != nil || resolved.Discriminator != "" { - modelName, _ := SchemaNamePathFromRef(variant.Ref) - if variantsTable[modelName] == nil { - variantsTable[modelName] = map[string]interface{}{} - } - variantsTable[modelName][k] = nil + + if _, ok := allOfTable[modelName]; !ok { + allOfTable[modelName] = map[string]interface{}{} } + allOfTable[modelName][k] = nil } } } } - variantTableCache.Add(swaggerPath, variantsTable) - return variantsTable, nil + allOfTableCache.Add(swaggerPath, allOfTable) + return allOfTable, nil } func Expand(modelName, swaggerPath string) (*Model, error) { @@ -147,26 +146,29 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s properties := make(map[string]*Model) // expand ref + a := input.Ref.String() + a = a if input.Ref.String() != "" { resolved, err := openapiSpec.ResolveRefWithBase(root, &input.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath}) if err != nil { - log.Fatalf("[ERROR] resolve %s: %v", input.Ref.String(), err) + log.Fatalf("[ERROR] resolve ref %s from %s: %v", input.Ref.String(), swaggerPath, err) } + refSwaggerPath := swaggerPath modelName, relativePath := SchemaNamePathFromRef(input.Ref) if relativePath != "" { - swaggerPath = filepath.Join(filepath.Dir(swaggerPath), relativePath) - swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) + refSwaggerPath = filepath.Join(filepath.Dir(swaggerPath), relativePath) + refSwaggerPath = strings.Replace(refSwaggerPath, "https:/", "https://", 1) - doc, err := loadSwagger(swaggerPath) + doc, err := loadSwagger(refSwaggerPath) if err != nil { - log.Fatalf("[ERROR] load swagger %s: %v", swaggerPath, err) + log.Fatalf("[ERROR] load swagger %s: %v", refSwaggerPath, err) } root = doc.Spec() } - referenceModel := expandSchema(*resolved, swaggerPath, modelName, identifier, root, resolvedDiscriminator, resolvedModel) + referenceModel := expandSchema(*resolved, refSwaggerPath, modelName, identifier, root, resolvedDiscriminator, resolvedModel) if referenceModel.Properties != nil { for k, v := range *referenceModel.Properties { properties[k] = v @@ -217,6 +219,11 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s properties[k] = v } } + + // the model should be a variant if its allOf contains a discriminator + if allOf.Discriminator != nil { + output.Discriminator = allOf.Discriminator + } } if len(properties) > 0 { @@ -252,29 +259,32 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s delete(resolvedModel, modelName) // expand variants - if input.Discriminator != "" { + if input.Discriminator != "" || output.Discriminator != nil { if _, hasResolvedDiscriminator := resolvedDiscriminator[modelName]; !hasResolvedDiscriminator { - resolvedDiscriminator[modelName] = nil - variants := make(map[string]*Model) - - variantsTable, err := getVariantTable(swaggerPath) + allOfTable, err := getAllOfTable(swaggerPath) if err != nil { log.Fatalf("[ERROR] get variant table %s: %v", swaggerPath, err) } - varSet, ok := variantsTable[modelName] + + varSet, ok := allOfTable[modelName] if ok { + resolvedDiscriminator[modelName] = nil + variants := map[string]*Model{ + modelName: nil, + } + // level order traverse to find all variants for len(varSet) > 0 { tempVarSet := make(map[string]interface{}) - for variantModel := range varSet { - schema := root.(*openapiSpec.Swagger).Definitions[variantModel] - variantName := variantModel + for variantModelName := range varSet { + schema := root.(*openapiSpec.Swagger).Definitions[variantModelName] + variantName := variantModelName if variantNameRaw, ok := schema.Extensions[msExtensionDiscriminator]; ok && variantNameRaw != nil { variantName = variantNameRaw.(string) } - resolved := expandSchema(schema, swaggerPath, variantModel, identifier+"{"+variantName+"}", root, resolvedDiscriminator, resolvedModel) + resolved := expandSchema(schema, swaggerPath, variantModelName, identifier+"{"+variantName+"}", root, resolvedDiscriminator, resolvedModel) variants[variantName] = resolved - if varVarSet, ok := variantsTable[variantModel]; ok { + if varVarSet, ok := allOfTable[variantModelName]; ok { for v := range varVarSet { tempVarSet[v] = nil } @@ -282,11 +292,12 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s } varSet = tempVarSet } + delete(resolvedDiscriminator, modelName) + if input.Discriminator != "" { + output.Discriminator = &input.Discriminator + } + output.Variants = &variants } - - delete(resolvedDiscriminator, modelName) - output.Discriminator = &input.Discriminator - output.Variants = &variants } } diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 6b627752..1bfcbbce 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -2,6 +2,7 @@ package coverage_test import ( "encoding/json" + "fmt" "os" "path/filepath" "runtime" @@ -31,6 +32,7 @@ func TestExpand(t *testing.T) { // try to expand all PUT and POST models func TestExpandAll(t *testing.T) { + // e.g., AZURE_REST_REPO_DIR="/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/" azureRepoDir := os.Getenv("AZURE_REST_REPO_DIR") if azureRepoDir == "" { t.Skip("AZURE_REST_REPO_DIR is not set") @@ -82,31 +84,46 @@ func TestExpandAll(t *testing.T) { swaggerPath := filepath.Join(azureRepoDir, ref.GetURL().Path) operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: azureRepoDir + "/" + strings.Split(ref.GetURL().Path, "/")[0]}) if err != nil { - t.Error(err) - return + panic(fmt.Errorf("resolve operation %q from %s: %v", ref.String(), swaggerPath, err)) } var modelName string for _, param := range operation.Parameters { + paramRef := param.Ref + if paramRef.String() != "" { + refParam, err := openapispec.ResolveParameterWithBase(nil, param.Ref, &openapispec.ExpandOptions{RelativeBase: swaggerPath}) + if err != nil { + panic(fmt.Errorf("resolve parameter %q from %s: %v", param.Ref.String(), swaggerPath, err)) + } + + // Update the param + param = *refParam + } if param.In == "body" { + if paramRef.String() != "" { + _, paramRelativePath := coverage.SchemaNamePathFromRef(paramRef) + if paramRelativePath != "" { + swaggerPath = filepath.Join(filepath.Dir(swaggerPath), paramRelativePath) + } + } + var modelRelativePath string modelName, modelRelativePath = coverage.SchemaNamePathFromRef(param.Schema.Ref) if modelRelativePath != "" { swaggerPath = filepath.Join(filepath.Dir(swaggerPath), modelRelativePath) } + break } } // post may have no model if operation.Put != nil && modelName == "" { - t.Error("modelName is empty") - return + panic(fmt.Errorf("resolve %s from %s: modelName is empty", ref.String(), swaggerPath)) } _, err = coverage.Expand(modelName, swaggerPath) if err != nil { - t.Error(err) - return + panic(fmt.Errorf("expand %s from %s: %+v", modelName, swaggerPath, err)) } // clean up diff --git a/coverage/index.go b/coverage/index.go index 3ccde46b..cf9a3b30 100644 --- a/coverage/index.go +++ b/coverage/index.go @@ -71,7 +71,10 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) } swaggerPath := filepath.Join(azureRepoURL, ref.GetURL().Path) - operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: azureRepoURL + "/" + strings.Split(ref.GetURL().Path, "/")[0]}) + swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) + + relativeBase := azureRepoURL + strings.Split(ref.GetURL().Path, "/")[0] + operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: relativeBase}) if err != nil { return nil, err @@ -82,16 +85,38 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) var modelName string for _, param := range operation.Parameters { + paramRef := param.Ref + if paramRef.String() != "" { + refParam, err := openapispec.ResolveParameterWithBase(nil, param.Ref, &openapispec.ExpandOptions{RelativeBase: swaggerPath}) + if err != nil { + return nil, fmt.Errorf("resolve param ref %q: %v", param.Ref.String(), err) + } + + // Update the param + param = *refParam + } if param.In == "body" { + if paramRef.String() != "" { + _, paramRelativePath := SchemaNamePathFromRef(paramRef) + if paramRelativePath != "" { + swaggerPath = filepath.Join(filepath.Dir(swaggerPath), paramRelativePath) + swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) + } + } + var modelRelativePath string modelName, modelRelativePath = SchemaNamePathFromRef(param.Schema.Ref) if modelRelativePath != "" { swaggerPath = filepath.Join(filepath.Dir(swaggerPath), modelRelativePath) + swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) } + break } } - swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) + if modelName == "" { + return nil, fmt.Errorf("PUT model not found for %s:%s", swaggerPath, apiPath) + } return &SwaggerModel{ ApiPath: apiPath, diff --git a/coverage/index_test.go b/coverage/index_test.go index 098595c3..b6f8a1ca 100644 --- a/coverage/index_test.go +++ b/coverage/index_test.go @@ -6,7 +6,7 @@ import ( "github.com/ms-henglu/armstrong/coverage" ) -func TestGetModelInfoFromIndex(t *testing.T) { +func TestGetModelInfoFromIndex_DataCollectionRule(t *testing.T) { apiVersion := "2022-06-01" swaggerModel, err := coverage.GetModelInfoFromIndex( "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/test-resources/providers/Microsoft.Insights/dataCollectionRules/testDCR", @@ -31,3 +31,29 @@ func TestGetModelInfoFromIndex(t *testing.T) { t.Fatalf("expected modelSwaggerPath %s, got %s", expectedModelSwaggerPath, swaggerModel.SwaggerPath) } } + +func TestGetModelInfoFromIndex_DeviceSecurityGroups(t *testing.T) { + apiVersion := "2019-08-01" + swaggerModel, err := coverage.GetModelInfoFromIndex( + "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/SampleRG/providers/Microsoft.Devices/iotHubs/sampleiothub/providers/Microsoft.Security/deviceSecurityGroups/samplesecuritygroup", + apiVersion, + ) + if err != nil { + t.Fatalf("get model info from index error: %+v", err) + } + + expectedApiPath := "/{resourceId}/providers/Microsoft.Security/deviceSecurityGroups/{deviceSecurityGroupName}" + if swaggerModel.ApiPath != expectedApiPath { + t.Fatalf("expected apiPath %s, got %s", expectedApiPath, swaggerModel.ApiPath) + } + + expectedModelName := "DeviceSecurityGroup" + if swaggerModel.ModelName != expectedModelName { + t.Fatalf("expected modelName %s, got %s", expectedModelName, swaggerModel.ModelName) + } + + expectedModelSwaggerPath := "https://raw.githubusercontent.com/Azure/azure-rest-api-specs/main/specification/security/resource-manager/Microsoft.Security/stable/2019-08-01/deviceSecurityGroups.json" + if swaggerModel.SwaggerPath != expectedModelSwaggerPath { + t.Fatalf("expected modelSwaggerPath %s, got %s", expectedModelSwaggerPath, swaggerModel.SwaggerPath) + } +} diff --git a/report/pass_report.go b/report/pass_report.go index b2b56084..112ba786 100644 --- a/report/pass_report.go +++ b/report/pass_report.go @@ -110,15 +110,31 @@ func getReport(model *coverage.Model) []string { } if v.Variants != nil { - for variantType, variant := range *v.Variants { - out = append(out, getChildReport(fmt.Sprintf("%s{%s}", k, variantType), variant)) + for variantName, variant := range *v.Variants { + variantKey := fmt.Sprintf("%s{%s}", k, variantName) + if variant == nil { + // reference to self + out = append(out, getChildReport(variantKey, v)) + continue + } + + out = append(out, getChildReport(variantKey, variant)) } + continue } if v.Item != nil && v.Item.Variants != nil { for variantType, variant := range *v.Item.Variants { - out = append(out, getChildReport(fmt.Sprintf("%s{%s}", k, variantType), variant)) + variantKey := fmt.Sprintf("%s{%s}", k, variantType) + if variant == nil { + // reference to self + out = append(out, getChildReport(variantKey, v)) + continue + } + + out = append(out, getChildReport(variantKey, variant)) } + continue } out = append(out, getChildReport(k, v)) From 3f27175e1c6450dbcafbe8216d27e8918a15a453 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Fri, 14 Jul 2023 20:51:21 +0800 Subject: [PATCH 02/10] fix --- coverage/coverage_test.go | 76 +++++++++++++++++++++------------------ coverage/expand.go | 5 ++- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/coverage/coverage_test.go b/coverage/coverage_test.go index f810c82b..1e7cd897 100644 --- a/coverage/coverage_test.go +++ b/coverage/coverage_test.go @@ -21,39 +21,6 @@ type testCase struct { resourceType string } -func TestCoverage_DataMigrationTasks(t *testing.T) { - tc := testCase{ - name: "DataMigrationTasks", - resourceType: "Microsoft.DataMigration/services/projects/tasks@2021-06-30", - apiVersion: "2021-06-30", - apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/DmsSdkRg/providers/Microsoft.DataMigration/services/DmsSdkService/projects/DmsSdkProject/tasks/DmsSdkTask", - rawRequest: []string{`{ - "taskType": "ConnectToTarget.SqlDb", - "input": { - "targetConnectionInfo": { - "type": "SqlConnectionInfo", - "dataSource": "ssma-test-server.database.windows.net", - "authentication": "SqlAuthentication", - "encryptConnection": true, - "trustServerCertificate": true, - "userName": "testuser", - "password": "testpassword" - } - } -}`, - }, - } - - model, err := testCoverage(t, tc) - if err != nil { - t.Fatalf("process coverage: %+v", err) - } - - if model.CoveredCount != 1 { - t.Fatalf("expected CoveredCount 1, got %d", model.CoveredCount) - } -} - func TestCoverage_ResourceGroup(t *testing.T) { tc := testCase{ name: "ResourceGroup", @@ -125,6 +92,41 @@ func TestCoverage_DeviceSecurityGroup(t *testing.T) { } } +func TestCoverage_DataMigrationTasks(t *testing.T) { + tc := testCase{ + name: "DataMigrationTasks", + resourceType: "Microsoft.DataMigration/services/projects/tasks@2021-06-30", + apiVersion: "2021-06-30", + apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/DmsSdkRg/providers/Microsoft.DataMigration/services/DmsSdkService/projects/DmsSdkProject/tasks/DmsSdkTask", + rawRequest: []string{`{ + "properties": { + "taskType": "ConnectToTarget.SqlDb", + "input": { + "targetConnectionInfo": { + "type": "SqlConnectionInfo", + "dataSource": "ssma-test-server.database.windows.net", + "authentication": "SqlAuthentication", + "encryptConnection": true, + "trustServerCertificate": true, + "userName": "testuser", + "password": "testpassword" + } + } + } +}`, + }, + } + + model, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } + + if model.CoveredCount != 8 { + t.Fatalf("expected CoveredCount 8, got %d", model.CoveredCount) + } +} + func TestCoverage_KeyVault(t *testing.T) { tc := testCase{ name: "KeyVault", @@ -1261,6 +1263,12 @@ func testCoverage(t *testing.T, tc testCase) (*coverage.Model, error) { return nil, fmt.Errorf("expand model: %+v", err) } + out, err := json.MarshalIndent(model, "", "\t") + if err != nil { + t.Error(err) + } + t.Logf("expand model %s", string(out)) + for _, rq := range tc.rawRequest { request := map[string]interface{}{} err = json.Unmarshal([]byte(rq), &request) @@ -1273,7 +1281,7 @@ func testCoverage(t *testing.T, tc testCase) (*coverage.Model, error) { model.CountCoverage() - out, err := json.MarshalIndent(model, "", "\t") + out, err = json.MarshalIndent(model, "", "\t") if err != nil { t.Error(err) } diff --git a/coverage/expand.go b/coverage/expand.go index 99e0589e..505c1ce3 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -146,8 +146,6 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s properties := make(map[string]*Model) // expand ref - a := input.Ref.String() - a = a if input.Ref.String() != "" { resolved, err := openapiSpec.ResolveRefWithBase(root, &input.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath}) if err != nil { @@ -157,7 +155,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s refSwaggerPath := swaggerPath modelName, relativePath := SchemaNamePathFromRef(input.Ref) if relativePath != "" { - refSwaggerPath = filepath.Join(filepath.Dir(swaggerPath), relativePath) + refSwaggerPath = filepath.Join(filepath.Dir(refSwaggerPath), relativePath) refSwaggerPath = strings.Replace(refSwaggerPath, "https:/", "https://", 1) doc, err := loadSwagger(refSwaggerPath) @@ -282,6 +280,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if variantNameRaw, ok := schema.Extensions[msExtensionDiscriminator]; ok && variantNameRaw != nil { variantName = variantNameRaw.(string) } + resolved := expandSchema(schema, swaggerPath, variantModelName, identifier+"{"+variantName+"}", root, resolvedDiscriminator, resolvedModel) variants[variantName] = resolved if varVarSet, ok := allOfTable[variantModelName]; ok { From 1ba819ffddf3c0563ee186c05103f7632e126fb4 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Tue, 18 Jul 2023 14:25:47 +0800 Subject: [PATCH 03/10] fix ref relative path --- coverage/expand.go | 30 ++++++++++++++++++------------ coverage/expand_test.go | 13 +++---------- coverage/index.go | 17 +++-------------- 3 files changed, 24 insertions(+), 36 deletions(-) diff --git a/coverage/expand.go b/coverage/expand.go index 505c1ce3..f32c512c 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -51,8 +51,8 @@ func getAllOfTable(swaggerPath string) (map[string]map[string]interface{}, error if len(v.AllOf) > 0 { for _, allOf := range v.AllOf { if allOf.Ref.String() != "" { - modelName, relativePath := SchemaNamePathFromRef(allOf.Ref) - if relativePath != "" { + modelName, absPath := SchemaNamePathFromRef(swaggerPath, allOf.Ref) + if absPath != swaggerPath { continue } @@ -152,12 +152,8 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s log.Fatalf("[ERROR] resolve ref %s from %s: %v", input.Ref.String(), swaggerPath, err) } - refSwaggerPath := swaggerPath - modelName, relativePath := SchemaNamePathFromRef(input.Ref) - if relativePath != "" { - refSwaggerPath = filepath.Join(filepath.Dir(refSwaggerPath), relativePath) - refSwaggerPath = strings.Replace(refSwaggerPath, "https:/", "https://", 1) - + modelName, refSwaggerPath := SchemaNamePathFromRef(swaggerPath, input.Ref) + if refSwaggerPath != swaggerPath { doc, err := loadSwagger(refSwaggerPath) if err != nil { log.Fatalf("[ERROR] load swagger %s: %v", refSwaggerPath, err) @@ -303,10 +299,20 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s return &output } -func SchemaNamePathFromRef(ref openapiSpec.Ref) (name string, path string) { - if ref.GetURL() == nil { +func SchemaNamePathFromRef(swaggerPath string, ref openapiSpec.Ref) (name string, path string) { + url := ref.GetURL() + if url == nil { return "", "" } - fragments := strings.Split(ref.GetURL().Fragment, "/") - return fragments[len(fragments)-1], ref.GetURL().Path + + path = url.Path + if path == "" { + path = swaggerPath + } else if !filepath.IsAbs(path) { + path = filepath.Join(filepath.Dir(swaggerPath), path) + path = strings.Replace(path, "https:/", "https://", 1) + } + + fragments := strings.Split(url.Fragment, "/") + return fragments[len(fragments)-1], path } diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 1bfcbbce..6003da12 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -15,7 +15,7 @@ import ( "github.com/ms-henglu/armstrong/coverage" ) -func TestExpand(t *testing.T) { +func TestExpand_MediaTranform(t *testing.T) { modelName := "Transform" modelSwaggerPath := "https://raw.githubusercontent.com/Azure/azure-rest-api-specs/main/specification/mediaservices/resource-manager/Microsoft.Media/Encoding/stable/2022-07-01/Encoding.json" model, err := coverage.Expand(modelName, modelSwaggerPath) @@ -101,17 +101,10 @@ func TestExpandAll(t *testing.T) { } if param.In == "body" { if paramRef.String() != "" { - _, paramRelativePath := coverage.SchemaNamePathFromRef(paramRef) - if paramRelativePath != "" { - swaggerPath = filepath.Join(filepath.Dir(swaggerPath), paramRelativePath) - } + _, swaggerPath = coverage.SchemaNamePathFromRef(swaggerPath, paramRef) } - var modelRelativePath string - modelName, modelRelativePath = coverage.SchemaNamePathFromRef(param.Schema.Ref) - if modelRelativePath != "" { - swaggerPath = filepath.Join(filepath.Dir(swaggerPath), modelRelativePath) - } + modelName, swaggerPath = coverage.SchemaNamePathFromRef(swaggerPath, param.Schema.Ref) break } } diff --git a/coverage/index.go b/coverage/index.go index cf9a3b30..52f61e47 100644 --- a/coverage/index.go +++ b/coverage/index.go @@ -7,7 +7,6 @@ import ( "log" "net/http" "net/url" - "path/filepath" "strings" openapispec "github.com/go-openapi/spec" @@ -70,8 +69,7 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) return nil, err } - swaggerPath := filepath.Join(azureRepoURL, ref.GetURL().Path) - swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) + _, swaggerPath := SchemaNamePathFromRef(azureRepoURL, openapispec.Ref{Ref: *ref}) relativeBase := azureRepoURL + strings.Split(ref.GetURL().Path, "/")[0] operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: relativeBase}) @@ -97,19 +95,10 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) } if param.In == "body" { if paramRef.String() != "" { - _, paramRelativePath := SchemaNamePathFromRef(paramRef) - if paramRelativePath != "" { - swaggerPath = filepath.Join(filepath.Dir(swaggerPath), paramRelativePath) - swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) - } + _, swaggerPath = SchemaNamePathFromRef(swaggerPath, paramRef) } - var modelRelativePath string - modelName, modelRelativePath = SchemaNamePathFromRef(param.Schema.Ref) - if modelRelativePath != "" { - swaggerPath = filepath.Join(filepath.Dir(swaggerPath), modelRelativePath) - swaggerPath = strings.Replace(swaggerPath, "https:/", "https://", 1) - } + modelName, swaggerPath = SchemaNamePathFromRef(swaggerPath, param.Schema.Ref) break } } From a69589eb2bd4bbf245496fb32d6350db08baccb6 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:28:13 +0800 Subject: [PATCH 04/10] fix --- coverage/coverage.go | 1 + coverage/coverage_test.go | 24 ++++++++++++++++++++ coverage/expand.go | 28 +++++++++++++++-------- coverage/expand_test.go | 48 +++++++++------------------------------ coverage/index.go | 39 +++++++++++++++++++------------ coverage/report.go | 6 ++--- report/pass_report.go | 6 ++--- 7 files changed, 85 insertions(+), 67 deletions(-) diff --git a/coverage/coverage.go b/coverage/coverage.go index 953e930d..7310d921 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -22,6 +22,7 @@ type Model struct { IsReadOnly bool `json:"IsReadOnly,omitempty"` IsRequired bool `json:"IsRequired,omitempty"` Item *Model `json:"Item,omitempty"` + ModelName string `json:"ModelName,omitempty"` Properties *map[string]*Model `json:"Properties,omitempty"` SourceFile string `json:"SourceFile,omitempty"` TotalCount int `json:"TotalCount,omitempty"` diff --git a/coverage/coverage_test.go b/coverage/coverage_test.go index 1e7cd897..dda8fc88 100644 --- a/coverage/coverage_test.go +++ b/coverage/coverage_test.go @@ -92,6 +92,30 @@ func TestCoverage_DeviceSecurityGroup(t *testing.T) { } } +func TestCoverage_DataMigrationServiceTasks(t *testing.T) { + // TODO: support cross file discriminator reference, e.g., https://github.com/Azure/azure-rest-api-specs/blob/0ab5469dc0d75594f5747493dcfe8774e22d728f/specification/datamigration/resource-manager/Microsoft.DataMigration/stable/2021-06-30/definitions/ServiceTasks.json#L39 + tc := testCase{ + name: "DataMigrationServiceTasks", + resourceType: "Microsoft.DataMigration/services/serviceTasks@2021-06-30", + apiVersion: "2021-06-30", + apiPath: "/subscriptions/fc04246f-04c5-437e-ac5e-206a19e7193f/resourceGroups/DmsSdkRg/providers/Microsoft.DataMigration/services/DmsSdkService/serviceTasks/DmsSdkTask", + rawRequest: []string{`{ + "properties": { + "taskType": "Service.Check.OCI", + "input": { + "serverVersion": "NA" + } + } +}`, + }, + } + + _, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } +} + func TestCoverage_DataMigrationTasks(t *testing.T) { tc := testCase{ name: "DataMigrationTasks", diff --git a/coverage/expand.go b/coverage/expand.go index f32c512c..e22eb161 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -83,18 +83,27 @@ func Expand(modelName, swaggerPath string) (*Model, error) { modelSchema, ok := spec.Definitions[modelName] if !ok { - return nil, fmt.Errorf("%s not found in the definition of %s", modelName, swaggerPath) + _, ok := spec.Parameters[modelName] + if ok { + // https://github.com/Azure/azure-rest-api-specs/blob/fef27735a1c8498d970be905bc45b2e4892fc3b0/specification/vmware/resource-manager/Microsoft.AVS/stable/2021-06-01/vmware.json#L251 + log.Printf("[WARN] Parameter %s is used as a model in %s", modelName, swaggerPath) + return &Model{}, nil + } else { + return nil, fmt.Errorf("%s not found in the definition of %s", modelName, swaggerPath) + } } output := expandSchema(modelSchema, swaggerPath, modelName, "#", spec, map[string]interface{}{}, map[string]interface{}{}) - output.SourceFile = swaggerPath - return output, nil } func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier string, root interface{}, resolvedDiscriminator map[string]interface{}, resolvedModel map[string]interface{}) *Model { - output := Model{Identifier: identifier} + output := Model{ + Identifier: identifier, + ModelName: modelName, + SourceFile: swaggerPath, + } if _, ok := resolvedModel[modelName]; ok { return &output @@ -149,20 +158,21 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if input.Ref.String() != "" { resolved, err := openapiSpec.ResolveRefWithBase(root, &input.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath}) if err != nil { - log.Fatalf("[ERROR] resolve ref %s from %s: %v", input.Ref.String(), swaggerPath, err) + log.Fatalf("[ERROR] resolve ref %s from %s: %+v", input.Ref.String(), swaggerPath, err) } modelName, refSwaggerPath := SchemaNamePathFromRef(swaggerPath, input.Ref) + refRoot := root if refSwaggerPath != swaggerPath { doc, err := loadSwagger(refSwaggerPath) if err != nil { - log.Fatalf("[ERROR] load swagger %s: %v", refSwaggerPath, err) + log.Fatalf("[ERROR] load swagger %s: %+v", refSwaggerPath, err) } - root = doc.Spec() + refRoot = doc.Spec() } - referenceModel := expandSchema(*resolved, refSwaggerPath, modelName, identifier, root, resolvedDiscriminator, resolvedModel) + referenceModel := expandSchema(*resolved, refSwaggerPath, modelName, identifier, refRoot, resolvedDiscriminator, resolvedModel) if referenceModel.Properties != nil { for k, v := range *referenceModel.Properties { properties[k] = v @@ -257,7 +267,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if _, hasResolvedDiscriminator := resolvedDiscriminator[modelName]; !hasResolvedDiscriminator { allOfTable, err := getAllOfTable(swaggerPath) if err != nil { - log.Fatalf("[ERROR] get variant table %s: %v", swaggerPath, err) + log.Fatalf("[ERROR] get variant table %s: %+v", swaggerPath, err) } varSet, ok := allOfTable[modelName] diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 6003da12..2d3eb73a 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" "runtime" "strings" "sync" @@ -15,7 +14,7 @@ import ( "github.com/ms-henglu/armstrong/coverage" ) -func TestExpand_MediaTranform(t *testing.T) { +func TestExpand_MediaTransform(t *testing.T) { modelName := "Transform" modelSwaggerPath := "https://raw.githubusercontent.com/Azure/azure-rest-api-specs/main/specification/mediaservices/resource-manager/Microsoft.Media/Encoding/stable/2022-07-01/Encoding.json" model, err := coverage.Expand(modelName, modelSwaggerPath) @@ -32,13 +31,15 @@ func TestExpand_MediaTranform(t *testing.T) { // try to expand all PUT and POST models func TestExpandAll(t *testing.T) { - // e.g., AZURE_REST_REPO_DIR="/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/" azureRepoDir := os.Getenv("AZURE_REST_REPO_DIR") if azureRepoDir == "" { t.Skip("AZURE_REST_REPO_DIR is not set") } - t.Logf("azure repo dir: %s", azureRepoDir) + if !strings.HasSuffix(azureRepoDir, "specification/") { + t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/\"") + } + t.Logf("azure repo dir: %s", azureRepoDir) index, err := coverage.GetIndex() if err != nil { t.Fatal(err) @@ -81,46 +82,19 @@ func TestExpandAll(t *testing.T) { go func(i int) { for ref := range refChan { t.Logf("%v ref: %v", i, ref.String()) - swaggerPath := filepath.Join(azureRepoDir, ref.GetURL().Path) - operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: azureRepoDir + "/" + strings.Split(ref.GetURL().Path, "/")[0]}) - if err != nil { - panic(fmt.Errorf("resolve operation %q from %s: %v", ref.String(), swaggerPath, err)) - } - - var modelName string - for _, param := range operation.Parameters { - paramRef := param.Ref - if paramRef.String() != "" { - refParam, err := openapispec.ResolveParameterWithBase(nil, param.Ref, &openapispec.ExpandOptions{RelativeBase: swaggerPath}) - if err != nil { - panic(fmt.Errorf("resolve parameter %q from %s: %v", param.Ref.String(), swaggerPath, err)) - } - // Update the param - param = *refParam - } - if param.In == "body" { - if paramRef.String() != "" { - _, swaggerPath = coverage.SchemaNamePathFromRef(swaggerPath, paramRef) - } - - modelName, swaggerPath = coverage.SchemaNamePathFromRef(swaggerPath, param.Schema.Ref) - break - } - } - - // post may have no model - if operation.Put != nil && modelName == "" { - panic(fmt.Errorf("resolve %s from %s: modelName is empty", ref.String(), swaggerPath)) + model, err := coverage.GetModelInfoFromIndexRef(openapispec.Ref{Ref: *ref}, azureRepoDir) + if err != nil { + panic(fmt.Errorf("get model info from index ref %s: %+v", ref.String(), err)) } - _, err = coverage.Expand(modelName, swaggerPath) + _, err = coverage.Expand(model.ModelName, model.SwaggerPath) if err != nil { - panic(fmt.Errorf("expand %s from %s: %+v", modelName, swaggerPath, err)) + panic(fmt.Errorf("process %s, expand %s from %s: %+v", ref.String(), model.ModelName, model.SwaggerPath, err)) } // clean up - operation = nil + model = nil ref = nil } diff --git a/coverage/index.go b/coverage/index.go index 52f61e47..fbc23767 100644 --- a/coverage/index.go +++ b/coverage/index.go @@ -27,19 +27,19 @@ func GetIndex() (*azidx.Index, error) { resp, err := http.Get(indexFileURL) if err != nil { - return nil, fmt.Errorf("get index file (%v): %v", indexFileURL, err) + return nil, fmt.Errorf("get index file (%v): %+v", indexFileURL, err) } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("read index file: %v", err) + return nil, fmt.Errorf("read index file: %+v", err) } var index azidx.Index if err := json.Unmarshal(b, &index); err != nil { - return nil, fmt.Errorf("unmarshal index file: %v", err) + return nil, fmt.Errorf("unmarshal index file: %+v", err) } indexCache = &index @@ -62,17 +62,29 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) resourceURL := fmt.Sprintf("https://management.azure.com%s?api-version=%s", resourceId, apiVersion) uRL, err := url.Parse(resourceURL) if err != nil { - return nil, fmt.Errorf("parsing URL %s: %v", resourceURL, err) + return nil, fmt.Errorf("parsing URL %s: %+v", resourceURL, err) } ref, err := index.Lookup("PUT", *uRL) if err != nil { return nil, err } - _, swaggerPath := SchemaNamePathFromRef(azureRepoURL, openapispec.Ref{Ref: *ref}) + model, err := GetModelInfoFromIndexRef(openapispec.Ref{Ref: *ref}, azureRepoURL) + if err != nil { + return nil, err + } + if model.ModelName == "" { + return nil, fmt.Errorf("PUT model not found for %s", ref.String()) + } + + return model, nil +} - relativeBase := azureRepoURL + strings.Split(ref.GetURL().Path, "/")[0] - operation, err := openapispec.ResolvePathItemWithBase(nil, openapispec.Ref{Ref: *ref}, &openapispec.ExpandOptions{RelativeBase: relativeBase}) +func GetModelInfoFromIndexRef(ref openapispec.Ref, swaggerRepo string) (*SwaggerModel, error) { + _, swaggerPath := SchemaNamePathFromRef(swaggerRepo, ref) + + relativeBase := swaggerRepo + strings.Split(ref.GetURL().Path, "/")[0] + operation, err := openapispec.ResolvePathItemWithBase(nil, ref, &openapispec.ExpandOptions{RelativeBase: relativeBase}) if err != nil { return nil, err @@ -87,7 +99,7 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) if paramRef.String() != "" { refParam, err := openapispec.ResolveParameterWithBase(nil, param.Ref, &openapispec.ExpandOptions{RelativeBase: swaggerPath}) if err != nil { - return nil, fmt.Errorf("resolve param ref %q: %v", param.Ref.String(), err) + return nil, fmt.Errorf("resolve param ref %q: %+v", param.Ref.String(), err) } // Update the param @@ -95,18 +107,15 @@ func GetModelInfoFromIndex(resourceId, apiVersion string) (*SwaggerModel, error) } if param.In == "body" { if paramRef.String() != "" { - _, swaggerPath = SchemaNamePathFromRef(swaggerPath, paramRef) + modelName, swaggerPath = SchemaNamePathFromRef(swaggerPath, paramRef) } - modelName, swaggerPath = SchemaNamePathFromRef(swaggerPath, param.Schema.Ref) + if param.Schema.Ref.String() != "" { + modelName, swaggerPath = SchemaNamePathFromRef(swaggerPath, param.Schema.Ref) + } break } } - - if modelName == "" { - return nil, fmt.Errorf("PUT model not found for %s:%s", swaggerPath, apiPath) - } - return &SwaggerModel{ ApiPath: apiPath, ModelName: modelName, diff --git a/coverage/report.go b/coverage/report.go index c24fee3e..f9a160c3 100644 --- a/coverage/report.go +++ b/coverage/report.go @@ -26,11 +26,11 @@ func (c *CoverageReport) AddCoverageFromState(resourceId, resourceType string, j swaggerModel, err := GetModelInfoFromIndex(resourceId, apiVersion) if err != nil { - return fmt.Errorf("error find the path for %s from index:%s", resourceId, err) + return fmt.Errorf("error find the path for %s from index: %+v", resourceId, err) } - log.Printf("[INFO] matched API path:%s modelSwawggerPath:%s\n", swaggerModel.ApiPath, swaggerModel.SwaggerPath) + log.Printf("[INFO] matched API path: %s; modelSwawggerPath: %s\n", swaggerModel.ApiPath, swaggerModel.SwaggerPath) resource := ArmResource{ ApiPath: swaggerModel.ApiPath, @@ -40,7 +40,7 @@ func (c *CoverageReport) AddCoverageFromState(resourceId, resourceType string, j if _, ok := c.Coverages[resource]; !ok { expanded, err := Expand(swaggerModel.ModelName, swaggerModel.SwaggerPath) if err != nil { - return fmt.Errorf("error expand model %s property:%s", swaggerModel.ModelName, err) + return fmt.Errorf("error expand model %s property: %+v", swaggerModel.ModelName, err) } c.Coverages[resource] = expanded diff --git a/report/pass_report.go b/report/pass_report.go index 112ba786..58c72518 100644 --- a/report/pass_report.go +++ b/report/pass_report.go @@ -60,10 +60,10 @@ func PassedMarkdownReport(passReport types.PassReport, coverageReport coverage.C [swagger](%[3]v)
-body(%[5]v/%[6]v) +%[5]v(%[6]v/%[7]v)
-%[7]v +%[8]v
@@ -72,7 +72,7 @@ func PassedMarkdownReport(passReport types.PassReport, coverageReport coverage.C --- -`, k.Type, k.ApiPath, v.SourceFile, getStyle(v.IsFullyCovered), v.CoveredCount, v.TotalCount, strings.Join(reportDetail, "\n\n"))) +`, k.Type, k.ApiPath, v.SourceFile, getStyle(v.IsFullyCovered), v.ModelName, v.CoveredCount, v.TotalCount, strings.Join(reportDetail, "\n\n"))) } sort.Strings(coverages) From b6db6a5a65dd9c1d73edb409c88a89cfac5ad2c1 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Wed, 26 Jul 2023 18:00:42 +0800 Subject: [PATCH 05/10] avoid break existing logic --- commands/test.go | 4 ++-- coverage/coverage.go | 2 +- coverage/expand.go | 29 +++++++++++++++-------------- coverage/expand_test.go | 5 +++-- coverage/report.go | 6 ------ report/pass_report.go | 7 +++++++ tf/utils.go | 12 ++++++++++++ 7 files changed, 40 insertions(+), 25 deletions(-) diff --git a/commands/test.go b/commands/test.go index 7046afd1..223f8165 100644 --- a/commands/test.go +++ b/commands/test.go @@ -130,7 +130,7 @@ func (c TestCommand) Execute() int { coverageReport, err := tf.NewCoverageReportFromState(state) if err != nil { - log.Fatalf("[ERROR] error produce coverage report: %+v", err) + log.Printf("[ERROR] error produce coverage report: %+v", err) } log.Printf("[INFO] the coverage report has been produced.") storePassReport(passReport, coverageReport, reportDir, allPassedReportFileName) @@ -158,7 +158,7 @@ func (c TestCommand) Execute() int { passReport := tf.NewPassReport(plan) coverageReport, err := tf.NewCoverageReport(plan) if err != nil { - log.Fatalf("[ERROR] error produce coverage report: %+v", err) + log.Printf("[ERROR] error produce coverage report: %+v", err) } storePassReport(passReport, coverageReport, reportDir, partialPassedReportFileName) diff --git a/coverage/coverage.go b/coverage/coverage.go index 7310d921..add1a025 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -108,7 +108,7 @@ func (m *Model) MarkCovered(root interface{}) { case nil: default: - log.Fatalf("[ERROR] unexpect type %T for json unmarshaled value", value) + log.Printf("[ERROR] unexpect type %T for json unmarshaled value", value) } } diff --git a/coverage/expand.go b/coverage/expand.go index e22eb161..4af075eb 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -3,6 +3,7 @@ package coverage import ( "fmt" "log" + "net/url" "path/filepath" "strings" @@ -158,7 +159,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if input.Ref.String() != "" { resolved, err := openapiSpec.ResolveRefWithBase(root, &input.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath}) if err != nil { - log.Fatalf("[ERROR] resolve ref %s from %s: %+v", input.Ref.String(), swaggerPath, err) + log.Panicf("[ERROR] resolve ref %s from %s: %+v", input.Ref.String(), swaggerPath, err) } modelName, refSwaggerPath := SchemaNamePathFromRef(swaggerPath, input.Ref) @@ -166,7 +167,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if refSwaggerPath != swaggerPath { doc, err := loadSwagger(refSwaggerPath) if err != nil { - log.Fatalf("[ERROR] load swagger %s: %+v", refSwaggerPath, err) + log.Panicf("[ERROR] load swagger %s: %+v", refSwaggerPath, err) } refRoot = doc.Spec() @@ -267,7 +268,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s if _, hasResolvedDiscriminator := resolvedDiscriminator[modelName]; !hasResolvedDiscriminator { allOfTable, err := getAllOfTable(swaggerPath) if err != nil { - log.Fatalf("[ERROR] get variant table %s: %+v", swaggerPath, err) + log.Panicf("[ERROR] get variant table %s: %+v", swaggerPath, err) } varSet, ok := allOfTable[modelName] @@ -309,20 +310,20 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s return &output } -func SchemaNamePathFromRef(swaggerPath string, ref openapiSpec.Ref) (name string, path string) { - url := ref.GetURL() - if url == nil { +func SchemaNamePathFromRef(swaggerPath string, ref openapiSpec.Ref) (schemaName string, schemaPath string) { + refUrl := ref.GetURL() + if refUrl == nil { return "", "" } - path = url.Path - if path == "" { - path = swaggerPath - } else if !filepath.IsAbs(path) { - path = filepath.Join(filepath.Dir(swaggerPath), path) - path = strings.Replace(path, "https:/", "https://", 1) + schemaPath = refUrl.Path + if schemaPath == "" { + schemaPath = swaggerPath + } else { + swaggerPath, _ := filepath.Split(swaggerPath) + schemaPath, _ = url.JoinPath(swaggerPath, schemaPath) } - fragments := strings.Split(url.Fragment, "/") - return fragments[len(fragments)-1], path + fragments := strings.Split(refUrl.Fragment, "/") + return fragments[len(fragments)-1], schemaPath } diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 2d3eb73a..0853e4ba 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -72,10 +72,9 @@ func TestExpandAll(t *testing.T) { } index = nil - t.Logf("refs numbers: %d", len(refMaps)) - refChan := make(chan *jsonreference.Ref) + totalPropCount := 0 var waitGroup sync.WaitGroup waitGroup.Add(runtime.NumCPU()) for i := 0; i < runtime.NumCPU(); i++ { @@ -105,4 +104,6 @@ func TestExpandAll(t *testing.T) { for _, ref := range refMaps { refChan <- ref } + + t.Logf("refs numbers: %d, total prop count: %d", len(refMaps), totalPropCount) } diff --git a/coverage/report.go b/coverage/report.go index f9a160c3..764ba509 100644 --- a/coverage/report.go +++ b/coverage/report.go @@ -3,7 +3,6 @@ package coverage import ( "fmt" "log" - "regexp" "strings" ) @@ -17,12 +16,7 @@ type CoverageReport struct { } func (c *CoverageReport) AddCoverageFromState(resourceId, resourceType string, jsonBody map[string]interface{}) error { - var err error - apiVersion := strings.Split(resourceType, "@")[1] - if !regexp.MustCompile(`^[0-9]{4}-[0-9]{2}-[0-9]{2}$`).MatchString(apiVersion) { - return fmt.Errorf("could not parse apiVersion from resourceType: %s", resourceType) - } swaggerModel, err := GetModelInfoFromIndex(resourceId, apiVersion) if err != nil { diff --git a/report/pass_report.go b/report/pass_report.go index 58c72518..2bd86206 100644 --- a/report/pass_report.go +++ b/report/pass_report.go @@ -25,6 +25,12 @@ func PassedMarkdownReport(passReport types.PassReport, coverageReport coverage.C content := passedReportTemplate content = strings.ReplaceAll(content, "${resource_type}", strings.Join(resourceTypes, "\n")) + content = addCoverageReport(content, coverageReport) + + return content +} + +func addCoverageReport(content string, coverageReport coverage.CoverageReport) string { fullyCoveredPath := make([]string, 0) partiallyCoveredPath := make([]string, 0) for k, v := range coverageReport.Coverages { @@ -77,6 +83,7 @@ func PassedMarkdownReport(passReport types.PassReport, coverageReport coverage.C sort.Strings(coverages) content = strings.ReplaceAll(content, "${coverage_details}", strings.Join(coverages, "\n")) + return content } diff --git a/tf/utils.go b/tf/utils.go index f3676e2d..dd739eb0 100644 --- a/tf/utils.go +++ b/tf/utils.go @@ -147,6 +147,12 @@ func NewPassReport(plan *tfjson.Plan) types.PassReport { } func NewCoverageReportFromState(state *tfjson.State) (coverage.CoverageReport, error) { + defer func() { + if r := recover(); r != nil { + log.Printf("[ERROR] panic when producing coverage report from state: %+v", r) + } + }() + out := coverage.CoverageReport{ Coverages: make(map[coverage.ArmResource]*coverage.Model, 0), } @@ -182,6 +188,12 @@ func NewCoverageReportFromState(state *tfjson.State) (coverage.CoverageReport, e } func NewCoverageReport(plan *tfjson.Plan) (coverage.CoverageReport, error) { + defer func() { + if r := recover(); r != nil { + log.Printf("[ERROR] panic when producing coverage report: %+v", r) + } + }() + out := coverage.CoverageReport{ Coverages: make(map[coverage.ArmResource]*coverage.Model, 0), } From 597d6388d5ae947fa5a2a07fbb4716c2e1d11868 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Fri, 28 Jul 2023 17:30:42 +0800 Subject: [PATCH 06/10] read index zip --- coverage/expand_test.go | 4 ++-- coverage/index.go | 39 +++++++++++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 0853e4ba..a3377820 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -35,8 +35,8 @@ func TestExpandAll(t *testing.T) { if azureRepoDir == "" { t.Skip("AZURE_REST_REPO_DIR is not set") } - if !strings.HasSuffix(azureRepoDir, "specification/") { - t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/\"") + if !strings.HasSuffix(azureRepoDir, "/specification") { + t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification\"") } t.Logf("azure repo dir: %s", azureRepoDir) diff --git a/coverage/index.go b/coverage/index.go index fbc23767..d30c72b5 100644 --- a/coverage/index.go +++ b/coverage/index.go @@ -1,6 +1,8 @@ package coverage import ( + "archive/zip" + "bytes" "encoding/json" "fmt" "io" @@ -14,7 +16,7 @@ import ( ) const ( - indexFileURL = "https://raw.githubusercontent.com/teowa/azure-rest-api-index-file/main/index.json" + indexFileURL = "https://raw.githubusercontent.com/teowa/azure-rest-api-index-file/main/index.json.zip" azureRepoURL = "https://raw.githubusercontent.com/Azure/azure-rest-api-specs/main/specification/" ) @@ -27,18 +29,38 @@ func GetIndex() (*azidx.Index, error) { resp, err := http.Get(indexFileURL) if err != nil { - return nil, fmt.Errorf("get index file (%v): %+v", indexFileURL, err) + return nil, fmt.Errorf("get index file from %v: %+v", indexFileURL, err) } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("read index file: %+v", err) + return nil, fmt.Errorf("download index file zip: %+v", err) + } + + zipReader, err := zip.NewReader(bytes.NewReader(b), int64(len(b))) + if err != nil { + return nil, fmt.Errorf("read index file zip: %+v", err) + } + + var unzippedIndexBytes []byte + for _, zipFile := range zipReader.File { + if strings.EqualFold(zipFile.Name, "index.json") { + unzippedIndexBytes, err = readZipFile(zipFile) + if err != nil { + return nil, fmt.Errorf("unzip index file: %+v", err) + } + break + } + } + + if len(unzippedIndexBytes) == 0 { + return nil, fmt.Errorf("index file not found in zip") } var index azidx.Index - if err := json.Unmarshal(b, &index); err != nil { + if err := json.Unmarshal(unzippedIndexBytes, &index); err != nil { return nil, fmt.Errorf("unmarshal index file: %+v", err) } indexCache = &index @@ -122,3 +144,12 @@ func GetModelInfoFromIndexRef(ref openapispec.Ref, swaggerRepo string) (*Swagger SwaggerPath: swaggerPath, }, nil } + +func readZipFile(zf *zip.File) ([]byte, error) { + f, err := zf.Open() + if err != nil { + return nil, err + } + defer f.Close() + return io.ReadAll(f) +} From cbac7c7208dd4a43b90e25370467d66a622210f2 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Mon, 31 Jul 2023 18:21:30 +0800 Subject: [PATCH 07/10] add more tests --- coverage/coverage.go | 1 + coverage/expand.go | 24 +++--- coverage/expand_test.go | 144 +++++++++++++++++++++++++++++++++-- coverage/testdata/test1.json | 75 ++++++++++++++++++ coverage/testdata/test2.json | 25 ++++++ report/pass_report.go | 14 +++- 6 files changed, 264 insertions(+), 19 deletions(-) create mode 100644 coverage/testdata/test1.json create mode 100644 coverage/testdata/test2.json diff --git a/coverage/coverage.go b/coverage/coverage.go index add1a025..8e131107 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -28,6 +28,7 @@ type Model struct { TotalCount int `json:"TotalCount,omitempty"` Type *string `json:"Type,omitempty"` Variants *map[string]*Model `json:"Variants,omitempty"` + VariantType *string `json:"VariantType,omitempty"` } func (m *Model) MarkCovered(root interface{}) { diff --git a/coverage/expand.go b/coverage/expand.go index 4af075eb..4ff2573a 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -71,27 +71,20 @@ func getAllOfTable(swaggerPath string) (map[string]map[string]interface{}, error } func Expand(modelName, swaggerPath string) (*Model, error) { + if modelName == "" { + return nil, fmt.Errorf("modelName is empty") + } + doc, err := loadSwagger(swaggerPath) if err != nil { return nil, err } - if modelName == "" { - return nil, nil - } - spec := doc.Spec() modelSchema, ok := spec.Definitions[modelName] if !ok { - _, ok := spec.Parameters[modelName] - if ok { - // https://github.com/Azure/azure-rest-api-specs/blob/fef27735a1c8498d970be905bc45b2e4892fc3b0/specification/vmware/resource-manager/Microsoft.AVS/stable/2021-06-01/vmware.json#L251 - log.Printf("[WARN] Parameter %s is used as a model in %s", modelName, swaggerPath) - return &Model{}, nil - } else { - return nil, fmt.Errorf("%s not found in the definition of %s", modelName, swaggerPath) - } + return nil, fmt.Errorf("%s not found in the definition of %s", modelName, swaggerPath) } output := expandSchema(modelSchema, swaggerPath, modelName, "#", spec, map[string]interface{}{}, map[string]interface{}{}) @@ -228,6 +221,12 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s // the model should be a variant if its allOf contains a discriminator if allOf.Discriminator != nil { output.Discriminator = allOf.Discriminator + + variantName := modelName + if variantNameRaw, ok := input.Extensions[msExtensionDiscriminator]; ok && variantNameRaw != nil { + variantName = variantNameRaw.(string) + } + output.VariantType = &variantName } } @@ -289,6 +288,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s } resolved := expandSchema(schema, swaggerPath, variantModelName, identifier+"{"+variantName+"}", root, resolvedDiscriminator, resolvedModel) + resolved.VariantType = &variantName variants[variantName] = resolved if varVarSet, ok := allOfTable[variantModelName]; ok { for v := range varVarSet { diff --git a/coverage/expand_test.go b/coverage/expand_test.go index a3377820..9a2b18bd 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -3,8 +3,10 @@ package coverage_test import ( "encoding/json" "fmt" + "log" "os" "runtime" + "sort" "strings" "sync" "testing" @@ -29,14 +31,87 @@ func TestExpand_MediaTransform(t *testing.T) { t.Logf("expanded model %s", string(out)) } +func TestExpandLocal_basic(t *testing.T) { + swaggerPath := "./testdata/" + pathRef := jsonreference.MustCreateRef("test1.json#/paths/~1path1/put") + swaggerModel, err := coverage.GetModelInfoFromIndexRef(openapispec.Ref{Ref: pathRef}, swaggerPath) + if err != nil { + t.Fatal(err) + } + + if swaggerModel.ModelName != "pet" { + t.Fatalf("expected modelName pet, got %s", swaggerModel.ModelName) + } + + model, err := coverage.Expand(swaggerModel.ModelName, swaggerModel.SwaggerPath) + if err != nil { + t.Fatal(err) + } + + out, err := json.MarshalIndent(model, "", "\t") + if err != nil { + t.Fatal(err) + } + t.Logf("expanded model %s", string(out)) + + if model.Properties == nil { + t.Fatalf("expected properties not nil") + } + + if _, ok := (*model.Properties)["odata.type"]; !ok { + t.Fatalf("expected properties odata.type string") + } + + if model.Variants == nil { + t.Fatalf("expected variants not nil") + } + + if model.Discriminator == nil || *model.Discriminator != "odata.type" { + t.Fatalf("expected discriminator odata.type") + } + + if model.VariantType == nil || *model.VariantType != "animal.pet" { + t.Fatalf("expected variantType animal.pet") + } + + if _, ok := (*model.Variants)["animal.pet.dog"]; !ok { + t.Fatalf("expected variants dog not nil") + } + + if (*model.Variants)["animal.pet.dog"].Properties == nil { + t.Fatalf("expected variants dog properties not nil") + } + + if (*model.Variants)["animal.pet.dog"].VariantType == nil || *(*model.Variants)["animal.pet.dog"].VariantType != "animal.pet.dog" { + t.Fatalf("expected variants dog variantType animal.pet.dog") + } + + if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["odata.type"]; !ok { + t.Fatalf("expected variants dog properties odata.type string") + } + + if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["name"]; !ok { + t.Fatalf("expected variants dog properties name") + } + + if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["is_barking"]; !ok { + t.Fatalf("expected variants dog properties name") + } + +} + // try to expand all PUT and POST models func TestExpandAll(t *testing.T) { + // AZURE_REST_REPO_DIR="/home/test/go/src/github.com/azure/azure-rest-api-specs/specification" TEST_RESULT_FILE="/home/test/res.json" azureRepoDir := os.Getenv("AZURE_REST_REPO_DIR") if azureRepoDir == "" { t.Skip("AZURE_REST_REPO_DIR is not set") } - if !strings.HasSuffix(azureRepoDir, "/specification") { - t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification\"") + if strings.HasSuffix(azureRepoDir, "specification") { + azureRepoDir += "/" + } + if !strings.HasSuffix(azureRepoDir, "specification/") { + t.Fatalf("AZURE_REST_REPO_DIR must specify the specification folder, e.g., AZURE_REST_REPO_DIR=\"/home/test/go/src/github.com/azure/azure-rest-api-specs/specification/\"") } t.Logf("azure repo dir: %s", azureRepoDir) @@ -72,9 +147,11 @@ func TestExpandAll(t *testing.T) { } index = nil + // expand concurrently refChan := make(chan *jsonreference.Ref) - totalPropCount := 0 + counter := sync.Map{} + var waitGroup sync.WaitGroup waitGroup.Add(runtime.NumCPU()) for i := 0; i < runtime.NumCPU(); i++ { @@ -87,11 +164,24 @@ func TestExpandAll(t *testing.T) { panic(fmt.Errorf("get model info from index ref %s: %+v", ref.String(), err)) } - _, err = coverage.Expand(model.ModelName, model.SwaggerPath) + if model.ModelName == "" { + t.Logf("model not found, skip %s", ref.String()) + continue + } + + expanded, err := coverage.Expand(model.ModelName, model.SwaggerPath) if err != nil { + if strings.Contains(err.Error(), "not found in the definition of") { + // https://github.com/Azure/azure-rest-api-specs/blob/f5cb37608399dd19760b9ef985a707294e32fbda/specification/vmware/resource-manager/Microsoft.AVS/stable/2021-06-01/vmware.json#L247 + t.Logf("model %s not found in the definition, skip %s", model.ModelName, ref.String()) + continue + } panic(fmt.Errorf("process %s, expand %s from %s: %+v", ref.String(), model.ModelName, model.SwaggerPath, err)) } + expanded.CountCoverage() + counter.Store(ref.String(), expanded.TotalCount) + // clean up model = nil ref = nil @@ -101,9 +191,53 @@ func TestExpandAll(t *testing.T) { }(i) } + refList := make([]*jsonreference.Ref, 0) for _, ref := range refMaps { + refList = append(refList, ref) + } + sort.Slice(refList, func(i, j int) bool { + return refList[i].String() < refList[j].String() + }) + for _, ref := range refList { refChan <- ref } + close(refChan) - t.Logf("refs numbers: %d, total prop count: %d", len(refMaps), totalPropCount) + waitGroup.Wait() + + type res struct { + AllRef int `json:"all_ref"` + AllProp int `json:"all_prop"` + AvailRef int `json:"avail_ref"` + Paths map[string]int `json:"paths"` + } + + result := res{ + AllRef: len(refMaps), + AllProp: 0, + AvailRef: 0, + Paths: make(map[string]int), + } + counter.Range(func(key, value interface{}) bool { + result.Paths[key.(string)] = value.(int) + result.AvailRef += 1 + result.AllProp += value.(int) + return true + }) + + t.Logf("total refs count: %d, ref with model: %d, total prop count: %d", result.AllRef, result.AvailRef, result.AllProp) + + b, err := json.MarshalIndent(result, "", "\t") + if err != nil { + log.Fatal(err) + } + + testResultFile := os.Getenv("TEST_RESULT_FILE") + if testResultFile == "" { + t.Log(string(b)) + } else { + if err := os.WriteFile(testResultFile, b, 0644); err != nil { + t.Fatal(err) + } + } } diff --git a/coverage/testdata/test1.json b/coverage/testdata/test1.json new file mode 100644 index 00000000..ce8e7b1e --- /dev/null +++ b/coverage/testdata/test1.json @@ -0,0 +1,75 @@ +{ + "swagger": "2.0", + "info": { + "version": "", + "title": "" + }, + "paths": { + "/path1": { + "put": { + "parameters": [ + { + "$ref": "#/parameters/location" + }, + { + "$ref": "./test2.json#/parameters/version" + }, + { + "$ref": "./test2.json#/parameters/input1" + } + ], + "responses": { + "200": { + "description": "OK" + } + } + } + } + }, + "parameters": { + "location": { + "in": "query", + "name": "location", + "type": "string" + } + }, + "definitions": { + "animal": { + "type": "object", + "discriminator": "odata.type", + "properties": { + "odata.type": { + "type": "string" + } + } + }, + "pet": { + "type": "object", + "allOf": [ + { + "$ref": "#/definitions/animal" + } + ], + "properties": { + "name": { + "type": "string" + } + }, + "x-ms-discriminator-value": "animal.pet" + }, + "dog": { + "type": "object", + "allOf": [ + { + "$ref": "#/definitions/pet" + } + ], + "properties": { + "is_barking": { + "type": "boolean" + } + }, + "x-ms-discriminator-value": "animal.pet.dog" + } + } +} \ No newline at end of file diff --git a/coverage/testdata/test2.json b/coverage/testdata/test2.json new file mode 100644 index 00000000..f7c38872 --- /dev/null +++ b/coverage/testdata/test2.json @@ -0,0 +1,25 @@ +{ + "swagger": "2.0", + "paths": { + + }, + "info": { + "version": "", + "title": "" + }, + "parameters": { + "input1": { + "name": "input", + "in": "body", + "required": true, + "schema": { + "$ref": "./test1.json#/definitions/pet" + } + }, + "version": { + "in": "query", + "name": "version", + "type": "string" + } + } +} \ No newline at end of file diff --git a/report/pass_report.go b/report/pass_report.go index 2bd86206..dc6e9cde 100644 --- a/report/pass_report.go +++ b/report/pass_report.go @@ -117,8 +117,13 @@ func getReport(model *coverage.Model) []string { } if v.Variants != nil { - for variantName, variant := range *v.Variants { - variantKey := fmt.Sprintf("%s{%s}", k, variantName) + for variantType, variant := range *v.Variants { + variantType := variantType + if v.VariantType != nil { + variantType = *v.Item.VariantType + } + variantKey := fmt.Sprintf("%s{%s}", k, variantType) + if variant == nil { // reference to self out = append(out, getChildReport(variantKey, v)) @@ -132,7 +137,12 @@ func getReport(model *coverage.Model) []string { if v.Item != nil && v.Item.Variants != nil { for variantType, variant := range *v.Item.Variants { + variantType := variantType + if v.Item.VariantType != nil { + variantType = *v.Item.VariantType + } variantKey := fmt.Sprintf("%s{%s}", k, variantType) + if variant == nil { // reference to self out = append(out, getChildReport(variantKey, v)) From f2cd8e860d41eebf8a1bd02c0ff3e817635bf864 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:24:12 +0800 Subject: [PATCH 08/10] optimize --- coverage/coverage.go | 26 +++-- coverage/coverage_test.go | 203 ++++++++++++++++++++++++++++-------- coverage/expand.go | 15 ++- coverage/expand_test.go | 213 ++++++++++++++++++++++---------------- report/pass_report.go | 37 +++---- 5 files changed, 326 insertions(+), 168 deletions(-) diff --git a/coverage/coverage.go b/coverage/coverage.go index 8e131107..756441b1 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -70,20 +70,30 @@ func (m *Model) MarkCovered(root interface{}) { case map[string]interface{}: isMatchProperty := true if m.Discriminator != nil && m.Variants != nil { + Loop: for k, v := range value { if k == *m.Discriminator { - variant, ok := (*m.Variants)[v.(string)] - if !ok { - log.Printf("[ERROR] unexpected variant %s in %s\n", v.(string), m.Identifier) + if m.ModelName == v.(string) { + break Loop } - if variant == nil { - break + if m.VariantType != nil && *m.VariantType == v.(string) { + break Loop } + if variant, ok := (*m.Variants)[v.(string)]; ok { + isMatchProperty = false + variant.MarkCovered(value) - isMatchProperty = false - (*m.Variants)[v.(string)].MarkCovered(value) + break + } + for _, variant := range *m.Variants { + if variant.VariantType != nil && *variant.VariantType == v.(string) { + isMatchProperty = false + variant.MarkCovered(value) - break + break Loop + } + } + log.Printf("[ERROR] unexpected variant %s in %s\n", v.(string), m.Identifier) } } } diff --git a/coverage/coverage_test.go b/coverage/coverage_test.go index dda8fc88..62365121 100644 --- a/coverage/coverage_test.go +++ b/coverage/coverage_test.go @@ -59,6 +59,105 @@ func TestCoverage_ResourceGroup(t *testing.T) { } } +func TestCoverage_MachineLearningServicesWorkspacesJobs(t *testing.T) { + tc := testCase{ + name: "MachineLearningServicesWorkspacesJobs", + resourceType: "Microsoft.MachineLearningServices/workspaces/jobs", + apiVersion: "2023-06-01-preview", + apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/rg1/providers/Microsoft.MachineLearningServices/workspaces/works1/jobs/job1", + rawRequest: []string{`{ + "properties": { + "description": "string", + "tags": { + "string": "string" + }, + "properties": { + "string": "string" + }, + "displayName": "string", + "experimentName": "string", + "services": { + "string": { + "jobServiceType": "string", + "port": 1, + "endpoint": "string", + "properties": { + "string": "string" + } + } + }, + "computeId": "string", + "jobType": "Pipeline", + "settings": {}, + "inputs": { + "string": { + "description": "string", + "jobInputType": "literal", + "value": "string" + } + }, + "outputs": { + "string": { + "description": "string", + "jobOutputType": "uri_file", + "mode": "Upload", + "uri": "string" + } + } + } +}`, + }, + } + + model, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } + + expected := 11 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) + } +} + +func TestCoverage_MachineLearningServicesWorkspacesDataVersions(t *testing.T) { + tc := testCase{ + name: "MachineLearningServicesWorkspacesDataVersions", + resourceType: "Microsoft.MachineLearningServices/workspaces/data/versions", + apiVersion: "2023-06-01-preview", + apiPath: "/subscriptions/12345678-1234-9876-4563-123456789012/resourceGroups/rg1/providers/Microsoft.MachineLearningServices/workspaces/works1/data/data1/versions/version1", + rawRequest: []string{`{ + "properties": { + "description": "string", + "tags": { + "string": "string" + }, + "properties": { + "string": "string" + }, + "isArchived": false, + "isAnonymous": false, + "dataUri": "string", + "dataType": "mltable", + "referencedUris": [ + "string" + ] + } +}`, + }, + } + + model, err := testCoverage(t, tc) + if err != nil { + t.Fatalf("process coverage: %+v", err) + } + + expected := 8 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) + } +} + func TestCoverage_DeviceSecurityGroup(t *testing.T) { tc := testCase{ name: "DeviceSecurityGroup", @@ -87,13 +186,14 @@ func TestCoverage_DeviceSecurityGroup(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 5 { - t.Fatalf("expected CoveredCount 5, got %d", model.CoveredCount) + expected := 5 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } func TestCoverage_DataMigrationServiceTasks(t *testing.T) { - // TODO: support cross file discriminator reference, e.g., https://github.com/Azure/azure-rest-api-specs/blob/0ab5469dc0d75594f5747493dcfe8774e22d728f/specification/datamigration/resource-manager/Microsoft.DataMigration/stable/2021-06-30/definitions/ServiceTasks.json#L39 + // Do we need to support cross file discriminator reference? Now seems only DataMigration has this. e.g., https://github.com/Azure/azure-rest-api-specs/blob/0ab5469dc0d75594f5747493dcfe8774e22d728f/specification/datamigration/resource-manager/Microsoft.DataMigration/stable/2021-06-30/definitions/ServiceTasks.json#L39 tc := testCase{ name: "DataMigrationServiceTasks", resourceType: "Microsoft.DataMigration/services/serviceTasks@2021-06-30", @@ -146,8 +246,9 @@ func TestCoverage_DataMigrationTasks(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 8 { - t.Fatalf("expected CoveredCount 8, got %d", model.CoveredCount) + expected := 8 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -231,8 +332,9 @@ func TestCoverage_KeyVault(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 13 { - t.Fatalf("expected CoveredCount 13, got %d", model.CoveredCount) + expected := 13 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -299,8 +401,9 @@ func TestCoverage_StorageAccount(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 24 { - t.Fatalf("expected CoveredCount 24, got %d", model.CoveredCount) + expected := 24 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -368,8 +471,9 @@ func TestCoverage_VM(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 20 { - t.Fatalf("expected CoveredCount 20, got %d", model.CoveredCount) + expected := 20 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -406,8 +510,9 @@ func TestCoverage_VNet(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 4 { - t.Fatalf("expected CoveredCount 4, got %d", model.CoveredCount) + expected := 4 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -731,8 +836,9 @@ func TestCoverage_DataCollectionRule(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 65 { - t.Fatalf("expected CoveredCount 65, got %d", model.CoveredCount) + expected := 65 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -758,8 +864,9 @@ func TestCoverage_WebSite(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 3 { - t.Fatalf("expected CoveredCount 3, got %d", model.CoveredCount) + expected := 3 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } } @@ -848,8 +955,9 @@ func TestCoverage_AKS(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 33 { - t.Fatalf("expected TotalCount 33, got %d", model.CoveredCount) + expected := 33 + if model.CoveredCount != expected { + t.Fatalf("expected TotalCount %d, got %d", expected, model.CoveredCount) } } @@ -947,8 +1055,9 @@ func TestCoverage_CosmosDB(t *testing.T) { t.Fatalf("process coverage: %v", err) } - if model.CoveredCount != 33 { - t.Fatalf("expected CoveredCount 33, got %d", model.CoveredCount) + expected := 33 + if model.CoveredCount != expected { + t.Fatalf("expected CoveredCount %d, got %d", expected, model.CoveredCount) } if model.Properties == nil { @@ -1170,10 +1279,15 @@ func TestCoverage_DataFactoryPipelines(t *testing.T) { }, } - _, err := testCoverage(t, tc) + model, err := testCoverage(t, tc) if err != nil { t.Fatalf("process coverage: %+v", err) } + + expected := 11 + if model.CoveredCount != expected { + t.Fatalf("expected TotalCount %d, got %d", expected, model.CoveredCount) + } } func TestCoverage_DataFactoryLinkedServices(t *testing.T) { @@ -1201,8 +1315,9 @@ func TestCoverage_DataFactoryLinkedServices(t *testing.T) { t.Fatalf("process coverage: %+v", err) } - if model.CoveredCount != 2 { - t.Fatalf("expected TotalCount 2, got %d", model.CoveredCount) + expected := 2 + if model.CoveredCount != expected { + t.Fatalf("expected TotalCount %d, got %d", expected, model.CoveredCount) } if model.Properties == nil { @@ -1233,40 +1348,44 @@ func TestCoverage_DataFactoryLinkedServices(t *testing.T) { t.Fatalf("expected properties variants, got none") } - if v, ok := (*(*model.Properties)["properties"].Variants)["AzureStorage"]; !ok || v == nil { - t.Fatalf("expected properties variant AzureStorage, got none") + if v, ok := (*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"]; !ok || v == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService, got none") + } + + if v := (*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].VariantType; v == nil || *v != "AzureStorage" { + t.Fatalf("expected properties variant AzureStorageLinkedService variant type AzureStorage") } - if (*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties == nil { - t.Fatalf("expected properties variant AzureStorage properties, got none") + if (*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService properties, got none") } - if v, ok := (*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["type"]; !ok || v == nil { - t.Fatalf("expected properties variant AzureStorage type property, got none") + if v, ok := (*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["type"]; !ok || v == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService type property, got none") } - if !(*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["type"].IsAnyCovered { - t.Fatalf("expected properties variant AzureStorage type IsAnyCovered true, got false") + if !(*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["type"].IsAnyCovered { + t.Fatalf("expected properties variant AzureStorageLinkedService type IsAnyCovered true, got false") } - if v, ok := (*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["typeProperties"]; !ok || v == nil { - t.Fatalf("expected properties variant AzureStorage typeProperties property, got none") + if v, ok := (*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["typeProperties"]; !ok || v == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService typeProperties property, got none") } - if !(*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["typeProperties"].IsAnyCovered { - t.Fatalf("expected properties variant AzureStorage typeProperties IsAnyCovered true, got false") + if !(*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["typeProperties"].IsAnyCovered { + t.Fatalf("expected properties variant AzureStorageLinkedService typeProperties IsAnyCovered true, got false") } - if (*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["typeProperties"].Properties == nil { - t.Fatalf("expected properties variant AzureStorage typeProperties properties, got none") + if (*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["typeProperties"].Properties == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService typeProperties properties, got none") } - if v, ok := (*(*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["typeProperties"].Properties)["connectionString"]; !ok || v == nil { - t.Fatalf("expected properties variant AzureStorage typeProperties connectionString property, got none") + if v, ok := (*(*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["typeProperties"].Properties)["connectionString"]; !ok || v == nil { + t.Fatalf("expected properties variant AzureStorageLinkedService typeProperties connectionString property, got none") } - if !(*(*(*(*model.Properties)["properties"].Variants)["AzureStorage"].Properties)["typeProperties"].Properties)["connectionString"].IsAnyCovered { - t.Fatalf("expected properties variant AzureStorage typeProperties connectionString IsAnyCovered true, got false") + if !(*(*(*(*model.Properties)["properties"].Variants)["AzureStorageLinkedService"].Properties)["typeProperties"].Properties)["connectionString"].IsAnyCovered { + t.Fatalf("expected properties variant AzureStorageLinkedService typeProperties connectionString IsAnyCovered true, got false") } } diff --git a/coverage/expand.go b/coverage/expand.go index 4ff2573a..fc14031e 100644 --- a/coverage/expand.go +++ b/coverage/expand.go @@ -17,7 +17,7 @@ const msExtensionDiscriminator = "x-ms-discriminator-value" var ( // {swaggerPath: doc Object} - swaggerCache, _ = lru.New[string, *loads.Document](20) + swaggerCache, _ = lru.New[string, *loads.Document](30) // {swaggerPath: {parentModelName: {childModelName: nil}}} allOfTableCache, _ = lru.New[string, map[string]map[string]interface{}](10) @@ -206,7 +206,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s // expand properties for k, v := range input.Properties { - properties[k] = expandSchema(v, swaggerPath, fmt.Sprintf("%s.%s", modelName, k), identifier+"."+k, root, resolvedDiscriminator, resolvedModel) + properties[k] = expandSchema(v, swaggerPath, fmt.Sprintf("%s.%s", modelName, k), fmt.Sprintf("%s.%s", identifier, k), root, resolvedDiscriminator, resolvedModel) } // expand composition @@ -256,7 +256,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s // expand items if input.Items != nil { - item := expandSchema(*input.Items.Schema, swaggerPath, fmt.Sprintf("%s[]", modelName), identifier+"[]", root, resolvedDiscriminator, resolvedModel) + item := expandSchema(*input.Items.Schema, swaggerPath, fmt.Sprintf("%s[]", modelName), fmt.Sprintf("%s[]", identifier), root, resolvedDiscriminator, resolvedModel) output.Item = item } @@ -273,9 +273,7 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s varSet, ok := allOfTable[modelName] if ok { resolvedDiscriminator[modelName] = nil - variants := map[string]*Model{ - modelName: nil, - } + variants := map[string]*Model{} // level order traverse to find all variants for len(varSet) > 0 { @@ -287,9 +285,10 @@ func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier s variantName = variantNameRaw.(string) } - resolved := expandSchema(schema, swaggerPath, variantModelName, identifier+"{"+variantName+"}", root, resolvedDiscriminator, resolvedModel) + resolved := expandSchema(schema, swaggerPath, variantModelName, fmt.Sprintf("%s{%s}", identifier, variantName), root, resolvedDiscriminator, resolvedModel) resolved.VariantType = &variantName - variants[variantName] = resolved + // in case of https://github.com/Azure/azure-rest-api-specs/issues/25104, use modelName as key + variants[variantModelName] = resolved if varVarSet, ok := allOfTable[variantModelName]; ok { for v := range varVarSet { tempVarSet[v] = nil diff --git a/coverage/expand_test.go b/coverage/expand_test.go index 9a2b18bd..5cbca522 100644 --- a/coverage/expand_test.go +++ b/coverage/expand_test.go @@ -5,11 +5,13 @@ import ( "fmt" "log" "os" + "path/filepath" "runtime" "sort" "strings" "sync" "testing" + "time" "github.com/go-openapi/jsonreference" openapispec "github.com/go-openapi/spec" @@ -31,7 +33,7 @@ func TestExpand_MediaTransform(t *testing.T) { t.Logf("expanded model %s", string(out)) } -func TestExpandLocal_basic(t *testing.T) { +func TestExpand_referToChildVariant(t *testing.T) { swaggerPath := "./testdata/" pathRef := jsonreference.MustCreateRef("test1.json#/paths/~1path1/put") swaggerModel, err := coverage.GetModelInfoFromIndexRef(openapispec.Ref{Ref: pathRef}, swaggerPath) @@ -74,35 +76,35 @@ func TestExpandLocal_basic(t *testing.T) { t.Fatalf("expected variantType animal.pet") } - if _, ok := (*model.Variants)["animal.pet.dog"]; !ok { + if _, ok := (*model.Variants)["dog"]; !ok { t.Fatalf("expected variants dog not nil") } - if (*model.Variants)["animal.pet.dog"].Properties == nil { + if (*model.Variants)["dog"].Properties == nil { t.Fatalf("expected variants dog properties not nil") } - if (*model.Variants)["animal.pet.dog"].VariantType == nil || *(*model.Variants)["animal.pet.dog"].VariantType != "animal.pet.dog" { + if (*model.Variants)["dog"].VariantType == nil || *(*model.Variants)["dog"].VariantType != "animal.pet.dog" { t.Fatalf("expected variants dog variantType animal.pet.dog") } - if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["odata.type"]; !ok { + if _, ok := (*(*model.Variants)["dog"].Properties)["odata.type"]; !ok { t.Fatalf("expected variants dog properties odata.type string") } - if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["name"]; !ok { + if _, ok := (*(*model.Variants)["dog"].Properties)["name"]; !ok { t.Fatalf("expected variants dog properties name") } - if _, ok := (*(*model.Variants)["animal.pet.dog"].Properties)["is_barking"]; !ok { + if _, ok := (*(*model.Variants)["dog"].Properties)["is_barking"]; !ok { t.Fatalf("expected variants dog properties name") } } -// try to expand all PUT and POST models +// try to expand all PUT and POST models twice, and ensure result is the same +// AZURE_REST_REPO_DIR="/home/test/go/src/github.com/azure/azure-rest-api-specs/specification" TEST_RESULT_FILE="/home/test/" func TestExpandAll(t *testing.T) { - // AZURE_REST_REPO_DIR="/home/test/go/src/github.com/azure/azure-rest-api-specs/specification" TEST_RESULT_FILE="/home/test/res.json" azureRepoDir := os.Getenv("AZURE_REST_REPO_DIR") if azureRepoDir == "" { t.Skip("AZURE_REST_REPO_DIR is not set") @@ -115,12 +117,95 @@ func TestExpandAll(t *testing.T) { } t.Logf("azure repo dir: %s", azureRepoDir) + + testResultFilePath := os.Getenv("TEST_RESULT_PATH") + + res := testExpandAll(t, azureRepoDir, resultFile(testResultFilePath)) + res2 := testExpandAll(t, azureRepoDir, resultFile(testResultFilePath)) + if res.AllRef != res2.AllRef || res.AllProp != res2.AllProp || res.AvailRef != res2.AvailRef { + t.Fatalf("result not equal, res1: %+v, res2: %+v", res, res2) + } +} + +func resultFile(testResultFilePath string) string { + return filepath.Join(testResultFilePath, fmt.Sprintf("res_%s.json", time.Now().Format(time.RFC3339))) +} + +type result struct { + AllRef int `json:"all_ref"` + AllProp int `json:"all_prop"` + AvailRef int `json:"avail_ref"` + Paths map[string]int `json:"paths"` +} + +func testExpandAll(t *testing.T, azureRepoDir, testResultPath string) result { + refList := getRefList(t) + for _, ref := range refList { + t.Logf("ref: %s", ref.String()) + } + + // expand concurrently + res := result{ + AllRef: len(refList), + AllProp: 0, + AvailRef: 0, + Paths: make(map[string]int), + } + refChan := make(chan jsonreference.Ref) + lock := sync.RWMutex{} + + var waitGroup sync.WaitGroup + waitGroup.Add(runtime.NumCPU()) + for i := 0; i < runtime.NumCPU(); i++ { + go func(i int) { + for ref := range refChan { + t.Logf("%v ref: %v", i, ref.String()) + count := expandRef(ref, azureRepoDir, t) + + lock.Lock() + if count != 0 { + res.Paths[ref.String()] = count + res.AllProp += count + res.AvailRef += 1 + } + lock.Unlock() + } + waitGroup.Done() + }(i) + } + + for _, ref := range refList { + refChan <- ref + } + close(refChan) + + waitGroup.Wait() + + t.Logf("total refs count: %d, ref with model: %d, total prop count: %d", res.AllRef, res.AvailRef, res.AllProp) + + b, err := json.MarshalIndent(res, "", "\t") + if err != nil { + log.Fatal(err) + } + + if testResultPath == "" { + t.Log(string(b)) + } else { + if err := os.WriteFile(testResultPath, b, 0644); err != nil { + t.Fatal(err) + } + } + + return res +} + +func getRefList(t *testing.T) []jsonreference.Ref { index, err := coverage.GetIndex() if err != nil { t.Fatal(err) } - refMaps := make(map[string]*jsonreference.Ref, 0) + refMaps := make(map[string]jsonreference.Ref, 0) for resourceProvider, versionRaw := range index.ResourceProviders { for version, methodRaw := range versionRaw { for operationKind, resourceTypeRaw := range methodRaw { @@ -133,111 +218,55 @@ func TestExpandAll(t *testing.T) { } for action, operationRefs := range operationInfo.Actions { for pathPattern, ref := range operationRefs { - t.Logf("%s %s %s %s %s %s", resourceProvider, version, operationKind, resourceType, action, pathPattern) - refMaps[ref.String()] = &ref + t.Logf("%s %s %s %s %s %s: %s", resourceProvider, version, operationKind, resourceType, action, pathPattern, ref.String()) + refMaps[ref.String()] = ref } } for pathPattern, ref := range operationInfo.OperationRefs { - t.Logf("%s %s %s %s %s", resourceProvider, version, operationKind, resourceType, pathPattern) - refMaps[ref.String()] = &ref + t.Logf("%s %s %s %s %s: %s", resourceProvider, version, operationKind, resourceType, pathPattern, ref.String()) + refMaps[ref.String()] = ref } } } } } - index = nil - - // expand concurrently - refChan := make(chan *jsonreference.Ref) - - counter := sync.Map{} - - var waitGroup sync.WaitGroup - waitGroup.Add(runtime.NumCPU()) - for i := 0; i < runtime.NumCPU(); i++ { - go func(i int) { - for ref := range refChan { - t.Logf("%v ref: %v", i, ref.String()) - - model, err := coverage.GetModelInfoFromIndexRef(openapispec.Ref{Ref: *ref}, azureRepoDir) - if err != nil { - panic(fmt.Errorf("get model info from index ref %s: %+v", ref.String(), err)) - } - - if model.ModelName == "" { - t.Logf("model not found, skip %s", ref.String()) - continue - } - - expanded, err := coverage.Expand(model.ModelName, model.SwaggerPath) - if err != nil { - if strings.Contains(err.Error(), "not found in the definition of") { - // https://github.com/Azure/azure-rest-api-specs/blob/f5cb37608399dd19760b9ef985a707294e32fbda/specification/vmware/resource-manager/Microsoft.AVS/stable/2021-06-01/vmware.json#L247 - t.Logf("model %s not found in the definition, skip %s", model.ModelName, ref.String()) - continue - } - panic(fmt.Errorf("process %s, expand %s from %s: %+v", ref.String(), model.ModelName, model.SwaggerPath, err)) - } - - expanded.CountCoverage() - counter.Store(ref.String(), expanded.TotalCount) - - // clean up - model = nil - ref = nil - } - - waitGroup.Done() - }(i) - } - refList := make([]*jsonreference.Ref, 0) + refList := make([]jsonreference.Ref, 0) for _, ref := range refMaps { refList = append(refList, ref) } sort.Slice(refList, func(i, j int) bool { return refList[i].String() < refList[j].String() }) - for _, ref := range refList { - refChan <- ref + + for i, ref := range refList { + t.Logf("refList ref %d: %s", i, ref.String()) } - close(refChan) - waitGroup.Wait() + return refList +} - type res struct { - AllRef int `json:"all_ref"` - AllProp int `json:"all_prop"` - AvailRef int `json:"avail_ref"` - Paths map[string]int `json:"paths"` +func expandRef(ref jsonreference.Ref, azureRepoDir string, t *testing.T) int { + model, err := coverage.GetModelInfoFromIndexRef(openapispec.Ref{Ref: ref}, azureRepoDir) + if err != nil { + panic(fmt.Errorf("get model info from index ref %s: %+v", ref.String(), err)) } - result := res{ - AllRef: len(refMaps), - AllProp: 0, - AvailRef: 0, - Paths: make(map[string]int), + if model.ModelName == "" { + t.Logf("model not found, skip %s", ref.String()) + return 0 } - counter.Range(func(key, value interface{}) bool { - result.Paths[key.(string)] = value.(int) - result.AvailRef += 1 - result.AllProp += value.(int) - return true - }) - - t.Logf("total refs count: %d, ref with model: %d, total prop count: %d", result.AllRef, result.AvailRef, result.AllProp) - b, err := json.MarshalIndent(result, "", "\t") + expanded, err := coverage.Expand(model.ModelName, model.SwaggerPath) if err != nil { - log.Fatal(err) - } - - testResultFile := os.Getenv("TEST_RESULT_FILE") - if testResultFile == "" { - t.Log(string(b)) - } else { - if err := os.WriteFile(testResultFile, b, 0644); err != nil { - t.Fatal(err) + if strings.Contains(err.Error(), "not found in the definition of") { + // https://github.com/Azure/azure-rest-api-specs/blob/f5cb37608399dd19760b9ef985a707294e32fbda/specification/vmware/resource-manager/Microsoft.AVS/stable/2021-06-01/vmware.json#L247 + t.Logf("model %s not found in the definition, skip %s", model.ModelName, ref.String()) + return 0 } + panic(fmt.Errorf("process %s, expand %s from %s: %+v", ref.String(), model.ModelName, model.SwaggerPath, err)) } + + expanded.CountCoverage() + return expanded.TotalCount } diff --git a/report/pass_report.go b/report/pass_report.go index dc6e9cde..713d9f32 100644 --- a/report/pass_report.go +++ b/report/pass_report.go @@ -117,38 +117,39 @@ func getReport(model *coverage.Model) []string { } if v.Variants != nil { + variantType := v.ModelName + if v.VariantType != nil { + variantType = *v.VariantType + } + variantKey := fmt.Sprintf("%s{%s}", k, variantType) + + out = append(out, getChildReport(variantKey, v)) + for variantType, variant := range *v.Variants { variantType := variantType - if v.VariantType != nil { - variantType = *v.Item.VariantType + if variant.VariantType != nil { + variantType = *variant.VariantType } variantKey := fmt.Sprintf("%s{%s}", k, variantType) - - if variant == nil { - // reference to self - out = append(out, getChildReport(variantKey, v)) - continue - } - out = append(out, getChildReport(variantKey, variant)) } continue } if v.Item != nil && v.Item.Variants != nil { + variantType := v.Item.ModelName + if v.Item.VariantType != nil { + variantType = *v.Item.VariantType + } + variantKey := fmt.Sprintf("%s{%s}", k, variantType) + out = append(out, getChildReport(variantKey, v)) + for variantType, variant := range *v.Item.Variants { variantType := variantType - if v.Item.VariantType != nil { - variantType = *v.Item.VariantType + if variant.VariantType != nil { + variantType = *variant.VariantType } variantKey := fmt.Sprintf("%s{%s}", k, variantType) - - if variant == nil { - // reference to self - out = append(out, getChildReport(variantKey, v)) - continue - } - out = append(out, getChildReport(variantKey, variant)) } continue From 02e9b9decb5cd1ed316f177a0ecedb38ee589427 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:57:46 +0800 Subject: [PATCH 09/10] fix --- coverage/coverage.go | 4 ++-- coverage/testdata/test1.json | 4 ++-- coverage/testdata/test2.json | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/coverage/coverage.go b/coverage/coverage.go index 756441b1..be7d317a 100644 --- a/coverage/coverage.go +++ b/coverage/coverage.go @@ -74,10 +74,10 @@ func (m *Model) MarkCovered(root interface{}) { for k, v := range value { if k == *m.Discriminator { if m.ModelName == v.(string) { - break Loop + break } if m.VariantType != nil && *m.VariantType == v.(string) { - break Loop + break } if variant, ok := (*m.Variants)[v.(string)]; ok { isMatchProperty = false diff --git a/coverage/testdata/test1.json b/coverage/testdata/test1.json index ce8e7b1e..d77061c3 100644 --- a/coverage/testdata/test1.json +++ b/coverage/testdata/test1.json @@ -55,7 +55,7 @@ "type": "string" } }, - "x-ms-discriminator-value": "animal.pet" + "x-ms-discriminator-value": "animal.pet" }, "dog": { "type": "object", @@ -72,4 +72,4 @@ "x-ms-discriminator-value": "animal.pet.dog" } } -} \ No newline at end of file +} diff --git a/coverage/testdata/test2.json b/coverage/testdata/test2.json index f7c38872..e92a9093 100644 --- a/coverage/testdata/test2.json +++ b/coverage/testdata/test2.json @@ -1,8 +1,6 @@ { "swagger": "2.0", - "paths": { - - }, + "paths": {}, "info": { "version": "", "title": "" @@ -22,4 +20,4 @@ "type": "string" } } -} \ No newline at end of file +} From 110817b8f5b6d660cf08e4bd1d2e06860fbe84e7 Mon Sep 17 00:00:00 2001 From: teowa <104055472+teowa@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:11:25 +0800 Subject: [PATCH 10/10] fix lint --- resource/data_source_test.go | 3 +-- resource/resource_test.go | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/resource/data_source_test.go b/resource/data_source_test.go index 35e6590b..09589dd5 100644 --- a/resource/data_source_test.go +++ b/resource/data_source_test.go @@ -15,8 +15,7 @@ func TestDataSource_NewDataSourceFromExample(t *testing.T) { } if r == nil { t.Fatal("expect valid resource, but got nil") - } - if r.ApiVersion != "2020-06-01" { + } else if r.ApiVersion != "2020-06-01" { t.Fatalf("expect api-version 2020-06-01, but got %s", r.ApiVersion) } diff --git a/resource/resource_test.go b/resource/resource_test.go index 8b784851..600b70bb 100644 --- a/resource/resource_test.go +++ b/resource/resource_test.go @@ -15,8 +15,7 @@ func TestResource_NewResourceFromExample(t *testing.T) { } if r == nil { t.Fatal("expect valid resource, but got nil") - } - if r.ApiVersion != "2020-06-01" { + } else if r.ApiVersion != "2020-06-01" { t.Fatalf("expect api-version 2020-06-01, but got %s", r.ApiVersion) }