Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardfeng-db committed Oct 22, 2024
1 parent d9d7384 commit cabad22
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 30 deletions.
4 changes: 2 additions & 2 deletions internal/providers/pluginfw/pluginfw.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ type DatabricksProviderPluginFramework struct {
var _ provider.Provider = (*DatabricksProviderPluginFramework)(nil)

func (p *DatabricksProviderPluginFramework) Resources(ctx context.Context) []func() resource.Resource {
return getPluginFrameworkResourcesToRegister()
return getPluginFrameworkResourcesToRegister(ctx)
}

func (p *DatabricksProviderPluginFramework) DataSources(ctx context.Context) []func() datasource.DataSource {
return getPluginFrameworkDataSourcesToRegister()
return getPluginFrameworkDataSourcesToRegister(ctx)
}

func (p *DatabricksProviderPluginFramework) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
Expand Down
48 changes: 28 additions & 20 deletions internal/providers/pluginfw/pluginfw_rollout_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,34 @@ var onboardedDataSources = []func() datasource.DataSource{
}

// GetUseSdkV2DataSources is a helper function to get name of resources that should use SDK V2 instead of plugin framework
func getUseSdkV2Resources() []string {
func getUseSdkV2Resources(ctx context.Context) []string {
useSdkV2 := os.Getenv("USE_SDK_V2_RESOURCES")
if useSdkV2 == "" {
return []string{}
useSdkV2Ctx := ctx.Value("USE_SDK_V2_RESOURCES")
combinedNames := ""
if useSdkV2 != "" && useSdkV2Ctx != "" {
combinedNames = useSdkV2 + "," + useSdkV2Ctx.(string)
} else {
combinedNames = useSdkV2 + useSdkV2Ctx.(string)
}
return strings.Split(useSdkV2, ",")
return strings.Split(combinedNames, ",")
}

// GetUseSdkV2DataSources is a helper function to get name of data sources that should use SDK V2 instead of plugin framework
func getUseSdkV2DataSources() []string {
func getUseSdkV2DataSources(ctx context.Context) []string {
useSdkV2 := os.Getenv("USE_SDK_V2_DATA_SOURCES")
if useSdkV2 == "" {
return []string{}
useSdkV2Ctx := ctx.Value("USE_SDK_V2_DATA_SOURCES")
combinedNames := ""
if useSdkV2 != "" && useSdkV2Ctx != "" {
combinedNames = useSdkV2 + "," + useSdkV2Ctx.(string)
} else {
combinedNames = useSdkV2 + useSdkV2Ctx.(string)
}
return strings.Split(useSdkV2, ",")
return strings.Split(combinedNames, ",")
}

// Helper function to check if a resource should use be in SDK V2 instead of plugin framework
func shouldUseSdkV2Resource(resourceName string) bool {
useSdkV2Resources := getUseSdkV2Resources()
func shouldUseSdkV2Resource(ctx context.Context, resourceName string) bool {
useSdkV2Resources := getUseSdkV2Resources(ctx)
for _, sdkV2Resource := range useSdkV2Resources {
if resourceName == sdkV2Resource {
return true
Expand All @@ -74,8 +82,8 @@ func shouldUseSdkV2Resource(resourceName string) bool {
}

// Helper function to check if a data source should use be in SDK V2 instead of plugin framework
func shouldUseSdkV2DataSource(dataSourceName string) bool {
sdkV2DataSources := getUseSdkV2DataSources()
func shouldUseSdkV2DataSource(ctx context.Context, dataSourceName string) bool {
sdkV2DataSources := getUseSdkV2DataSources(ctx)
for _, sdkV2DataSource := range sdkV2DataSources {
if dataSourceName == sdkV2DataSource {
return true
Expand All @@ -85,13 +93,13 @@ func shouldUseSdkV2DataSource(dataSourceName string) bool {
}

// getPluginFrameworkResourcesToRegister is a helper function to get the list of resources that are migrated away from sdkv2 to plugin framework
func getPluginFrameworkResourcesToRegister() []func() resource.Resource {
func getPluginFrameworkResourcesToRegister(ctx context.Context) []func() resource.Resource {
var resources []func() resource.Resource

// Loop through the map and add resources if they're not specifically marked to use the SDK V2
for _, resourceFunc := range migratedResources {
name := getResourceName(resourceFunc)
if !shouldUseSdkV2Resource(name) {
if !shouldUseSdkV2Resource(ctx, name) {
resources = append(resources, resourceFunc)
}
}
Expand All @@ -100,13 +108,13 @@ func getPluginFrameworkResourcesToRegister() []func() resource.Resource {
}

// getPluginFrameworkDataSourcesToRegister is a helper function to get the list of data sources that are migrated away from sdkv2 to plugin framework
func getPluginFrameworkDataSourcesToRegister() []func() datasource.DataSource {
func getPluginFrameworkDataSourcesToRegister(ctx context.Context) []func() datasource.DataSource {
var dataSources []func() datasource.DataSource

// Loop through the map and add data sources if they're not specifically marked to use the SDK V2
for _, dataSourceFunc := range migratedDataSources {
name := getDataSourceName(dataSourceFunc)
if !shouldUseSdkV2DataSource(name) {
if !shouldUseSdkV2DataSource(ctx, name) {
dataSources = append(dataSources, dataSourceFunc)
}
}
Expand All @@ -127,23 +135,23 @@ func getDataSourceName(dataSourceFunc func() datasource.DataSource) string {
}

// GetSdkV2ResourcesToRemove is a helper function to get the list of resources that are migrated away from sdkv2 to plugin framework
func GetSdkV2ResourcesToRemove() []string {
func GetSdkV2ResourcesToRemove(ctx context.Context) []string {
resourcesToRemove := []string{}
for _, resourceFunc := range migratedResources {
name := getResourceName(resourceFunc)
if !shouldUseSdkV2Resource(name) {
if !shouldUseSdkV2Resource(ctx, name) {
resourcesToRemove = append(resourcesToRemove, name)
}
}
return resourcesToRemove
}

// GetSdkV2DataSourcesToRemove is a helper function to get the list of data sources that are migrated away from sdkv2 to plugin framework
func GetSdkV2DataSourcesToRemove() []string {
func GetSdkV2DataSourcesToRemove(ctx context.Context) []string {
dataSourcesToRemove := []string{}
for _, dataSourceFunc := range migratedDataSources {
name := getDataSourceName(dataSourceFunc)
if !shouldUseSdkV2DataSource(name) {
if !shouldUseSdkV2DataSource(ctx, name) {
dataSourcesToRemove = append(dataSourcesToRemove, name)
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func GetProviderServer(ctx context.Context, options ...ServerOption) (tfprotov6.
}
sdkPluginProvider := serverOptions.sdkV2Provider
if sdkPluginProvider == nil {
sdkPluginProvider = sdkv2.DatabricksProvider()
sdkPluginProvider = sdkv2.DatabricksProvider(ctx)
}
pluginFrameworkProvider := serverOptions.pluginFrameworkProvider
if pluginFrameworkProvider == nil {
pluginFrameworkProvider = pluginfw.GetDatabricksProviderPluginFramework()
}

upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
context.Background(),
ctx,
sdkPluginProvider.GRPCProvider,
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/providers/providers_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (pf providerFixture) configureProviderAndReturnClient_SDKv2(t *testing.T) (
for k, v := range pf.env {
t.Setenv(k, v)
}
p := sdkv2.DatabricksProvider()
p := sdkv2.DatabricksProvider(context.Background())
ctx := context.Background()
diags := p.Configure(ctx, terraform.NewResourceConfigRaw(pf.rawConfigSDKv2()))
if len(diags) > 0 {
Expand Down
6 changes: 3 additions & 3 deletions internal/providers/sdkv2/sdkv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func init() {
}

// DatabricksProvider returns the entire terraform provider object
func DatabricksProvider() *schema.Provider {
func DatabricksProvider(ctx context.Context) *schema.Provider {
dataSourceMap := map[string]*schema.Resource{ // must be in alphabetical order
"databricks_aws_crossaccount_policy": aws.DataAwsCrossaccountPolicy().ToResource(),
"databricks_aws_assume_role_policy": aws.DataAwsAssumeRolePolicy().ToResource(),
Expand Down Expand Up @@ -228,11 +228,11 @@ func DatabricksProvider() *schema.Provider {
}

// Remove the resources and data sources that are being migrated to plugin framework
for _, dataSourceToRemove := range pluginfw.GetSdkV2DataSourcesToRemove() {
for _, dataSourceToRemove := range pluginfw.GetSdkV2DataSourcesToRemove(ctx) {
delete(dataSourceMap, dataSourceToRemove)
}

for _, resourceToRemove := range pluginfw.GetSdkV2ResourcesToRemove() {
for _, resourceToRemove := range pluginfw.GetSdkV2ResourcesToRemove(ctx) {
delete(resourceMap, resourceToRemove)
}

Expand Down
3 changes: 2 additions & 1 deletion internal/providers/sdkv2/tests/coverage_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"context"
"fmt"
"io"
"io/fs"
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestCoverageReport(t *testing.T) {
files, err := recursiveChildren("..")
assert.NoError(t, err)

p := sdkv2.DatabricksProvider()
p := sdkv2.DatabricksProvider(context.Background())
var cr CoverageReport
var longestResourceName, longestFieldName int

Expand Down
3 changes: 2 additions & 1 deletion internal/providers/sdkv2/tests/generate_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"context"
"fmt"
"go/ast"
"go/parser"
Expand Down Expand Up @@ -234,7 +235,7 @@ func TestGenerateTestCodeStubs(t *testing.T) {
t.Logf("Got %d unit tests in total. %v",
len(funcs), resourceTestStub{})
t.Skip()
p := sdkv2.DatabricksProvider()
p := sdkv2.DatabricksProvider(context.Background())
for name, resource := range p.ResourcesMap {
if name != "databricks_group_instance_profile" {
continue
Expand Down

0 comments on commit cabad22

Please sign in to comment.