From 154b728b06fad7e852155224fda4075ff3c00cb8 Mon Sep 17 00:00:00 2001 From: Eric Lin <38420555+Ezzahhh@users.noreply.github.com> Date: Sun, 29 Dec 2024 23:42:56 +1000 Subject: [PATCH] feat: terminate ssm tunnel Signed-off-by: Eric Lin <38420555+Ezzahhh@users.noreply.github.com> --- internal/provider/ephemeral_ssm.go | 30 ++++++++++++++++++++++++++++-- internal/ssm/tunnel.go | 12 ++++++++++-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/internal/provider/ephemeral_ssm.go b/internal/provider/ephemeral_ssm.go index 7bd2330..d9a8049 100644 --- a/internal/provider/ephemeral_ssm.go +++ b/internal/provider/ephemeral_ssm.go @@ -5,6 +5,9 @@ import ( "fmt" "strconv" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + aws_ssm "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/dfns/terraform-provider-tunnel/internal/ssm" "github.com/hashicorp/terraform-plugin-framework/ephemeral" "github.com/hashicorp/terraform-plugin-framework/ephemeral/schema" @@ -94,7 +97,7 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp data.LocalHost = types.StringValue("localhost") data.LocalPort = types.Int64Value(int64(localPort)) - cmd, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{ + forkResult, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{ SSMRegion: data.SSMRegion.ValueString(), SSMInstance: data.SSMInstance.ValueString(), TargetHost: data.TargetHost.ValueString(), @@ -108,7 +111,9 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp // Save data into Terraform state resp.Diagnostics.Append(resp.Result.Set(ctx, &data)...) - resp.Private.SetKey(ctx, "tunnel_pid", []byte(strconv.Itoa(cmd.Process.Pid))) + resp.Private.SetKey(ctx, "tunnel_pid", []byte(strconv.Itoa(forkResult.Command.Process.Pid))) + resp.Private.SetKey(ctx, "session_id", []byte(forkResult.Session.SessionId)) + resp.Private.SetKey(ctx, "ssm_region", []byte(data.SSMRegion.ValueString())) } func (d *SSMEphemeral) Close(ctx context.Context, req ephemeral.CloseRequest, resp *ephemeral.CloseResponse) { @@ -129,4 +134,25 @@ func (d *SSMEphemeral) Close(ctx context.Context, req ephemeral.CloseRequest, re resp.Diagnostics.AddError("Failed to terminate tunnel process", fmt.Sprintf("Error: %s", err)) return } + + sessionID, _ := req.Private.GetKey(ctx, "session_id") + ssmRegion, _ := req.Private.GetKey(ctx, "ssm_region") + if len(sessionID) > 0 { + awsCfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + resp.Diagnostics.AddError("Failed to load AWS config", fmt.Sprintf("Error: %s", err)) + return + } + awsCfg.Region = string(ssmRegion) + + ssmClient := aws_ssm.NewFromConfig(awsCfg) + + _, err = ssmClient.TerminateSession(ctx, &aws_ssm.TerminateSessionInput{ + SessionId: aws.String(string(sessionID)), + }) + if err != nil { + resp.Diagnostics.AddError("Failed to terminate SSM session", fmt.Sprintf("Error: %s", err)) + return + } + } } diff --git a/internal/ssm/tunnel.go b/internal/ssm/tunnel.go index c9083cb..e174df9 100644 --- a/internal/ssm/tunnel.go +++ b/internal/ssm/tunnel.go @@ -17,6 +17,11 @@ import ( ps "github.com/shirou/gopsutil/v4/process" ) +type ForkRemoteResult struct { + Command *exec.Cmd + Session SessionParams +} + func GetEndpoint(ctx context.Context, region string) (string, error) { resolver := ssm.NewDefaultEndpointResolverV2() endpoint, err := resolver.ResolveEndpoint(ctx, ssm.EndpointParameters{ @@ -58,7 +63,7 @@ func WatchProcess(pid string) (err error) { return nil } -func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*exec.Cmd, error) { +func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*ForkRemoteResult, error) { // First we start a session using AWS SDK // see https://github.com/aws/aws-cli/blob/master/awscli/customizations/sessionmanager.py#L104 sessionParams, err := StartTunnelSession(ctx, cfg) @@ -97,7 +102,10 @@ func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*exec.Cmd, error) return nil, err } - return cmd, nil + return &ForkRemoteResult{ + Command: cmd, + Session: sessionParams, + }, nil } func StartRemoteTunnel(ctx context.Context, cfg TunnelConfig, parentPid string) (err error) {