diff --git a/.changelog/20550.txt b/.changelog/20550.txt
new file mode 100644
index 00000000000..4e282fc8eaf
--- /dev/null
+++ b/.changelog/20550.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug where concurrent mount and unmount operations could unstage volumes needed by another allocation
+```
diff --git a/client/pluginmanager/csimanager/volume.go b/client/pluginmanager/csimanager/volume.go
index 6396dfbad1a..18251afcb1e 100644
--- a/client/pluginmanager/csimanager/volume.go
+++ b/client/pluginmanager/csimanager/volume.go
@@ -253,6 +253,10 @@ func (v *volumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume,
 	logger := v.logger.With("volume_id", vol.ID, "alloc_id", alloc.ID)
 	ctx = hclog.WithContext(ctx, logger)
 
+	// Claim before we stage/publish to prevent interleaved Unmount for another
+	// alloc from unstaging between stage/publish steps below
+	v.usageTracker.Claim(alloc.ID, vol.ID, vol.Namespace, usage)
+
 	if v.requiresStaging {
 		err = v.stageVolume(ctx, vol, usage, publishContext)
 	}
@@ -261,10 +265,6 @@ func (v *volumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume,
 		mountInfo, err = v.publishVolume(ctx, vol, alloc, usage, publishContext)
 	}
 
-	if err == nil {
-		v.usageTracker.Claim(alloc.ID, vol.ID, vol.Namespace, usage)
-	}
-
 	event := structs.NewNodeEvent().
 		SetSubsystem(structs.NodeEventSubsystemStorage).
 		SetMessage("Mount volume").
@@ -274,6 +274,7 @@ func (v *volumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume,
 	} else {
 		event.AddDetail("success", "false")
 		event.AddDetail("error", err.Error())
+		v.usageTracker.Free(alloc.ID, vol.ID, vol.Namespace, usage)
 	}
 
 	v.eventer(event)
diff --git a/client/pluginmanager/csimanager/volume_test.go b/client/pluginmanager/csimanager/volume_test.go
index 1b8fd4a697c..1138d0e1f91 100644
--- a/client/pluginmanager/csimanager/volume_test.go
+++ b/client/pluginmanager/csimanager/volume_test.go
@@ -9,7 +9,9 @@ import (
 	"os"
 	"runtime"
 	"testing"
+	"time"
 
+	"github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/nomad/ci"
 	"github.com/hashicorp/nomad/helper/mount"
 	"github.com/hashicorp/nomad/helper/testlog"
@@ -17,6 +19,7 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/csi"
 	csifake "github.com/hashicorp/nomad/plugins/csi/fake"
+	"github.com/shoenig/test/must"
 	"github.com/stretchr/testify/require"
 )
 
@@ -526,3 +529,71 @@ func TestVolumeManager_MountVolumeEvents(t *testing.T) {
 	require.Equal(t, "vol", e.Details["volume_id"])
 	require.Equal(t, "true", e.Details["success"])
 }
+
+// TestVolumeManager_InterleavedStaging tests that a volume cannot be unstaged
+// if another alloc has staged but not yet published
+func TestVolumeManager_InterleavedStaging(t *testing.T) {
+	ci.Parallel(t)
+
+	tmpPath := t.TempDir()
+	csiFake := &csifake.Client{}
+
+	logger := testlog.HCLogger(t)
+	ctx := hclog.WithContext(context.Background(), logger)
+
+	manager := newVolumeManager(logger,
+		func(e *structs.NodeEvent) {}, csiFake,
+		tmpPath, tmpPath, true, "i-example")
+
+	alloc0, alloc1 := mock.Alloc(), mock.Alloc()
+	vol := &structs.CSIVolume{ID: "vol", Namespace: "ns"}
+	usage := &UsageOptions{
+		AccessMode:     structs.CSIVolumeAccessModeMultiNodeMultiWriter,
+		AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem,
+	}
+	pubCtx := map[string]string{}
+
+	// first alloc has previously claimed the volume
+	manager.usageTracker.Claim(alloc0.ID, vol.ID, vol.Namespace, usage)
+
+	alloc0WaitCh := make(chan struct{})
+	alloc1WaitCh := make(chan struct{})
+
+	// this goroutine simulates MountVolume, but with control over interleaving
+	// by waiting for the other alloc to check if should unstage before trying
+	// to publish
+	manager.usageTracker.Claim(alloc1.ID, vol.ID, vol.Namespace, usage)
+	must.NoError(t, manager.stageVolume(ctx, vol, usage, pubCtx))
+
+	go func() {
+		defer close(alloc1WaitCh)
+		<-alloc0WaitCh
+		_, err := manager.publishVolume(ctx, vol, alloc1, usage, pubCtx)
+		must.NoError(t, err)
+	}()
+
+	must.NoError(t, manager.UnmountVolume(ctx, vol.Namespace, vol.ID, "foo", alloc0.ID, usage))
+	close(alloc0WaitCh)
+
+	testTimeoutCtx, cancel := context.WithTimeout(context.TODO(), time.Second)
+	t.Cleanup(cancel)
+
+	select {
+	case <-alloc1WaitCh:
+	case <-testTimeoutCtx.Done():
+		t.Fatal("test timed out")
+	}
+
+	key := volumeUsageKey{
+		id:        vol.ID,
+		ns:        vol.Namespace,
+		usageOpts: *usage,
+	}
+
+	manager.usageTracker.stateMu.Lock()
+	t.Cleanup(manager.usageTracker.stateMu.Unlock)
+	must.Eq(t, []string{alloc1.ID}, manager.usageTracker.state[key])
+
+	must.Eq(t, 1, csiFake.NodeUnpublishVolumeCallCount, must.Sprint("expected 1 unpublish call"))
+	must.Eq(t, 0, csiFake.NodeUnstageVolumeCallCount, must.Sprint("expected no unstage call"))
+}
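
For readers following the race itself, below is a minimal, self-contained Go sketch of the ordering the patch establishes. It is not Nomad's real `volumeUsageTracker` API; the `tracker`, `Claim`, and `Free` names and signatures are simplified stand-ins invented for illustration. The point is that because a mounting alloc now registers its claim before staging, a concurrent unmount for another alloc sees that claim when deciding whether to unstage, rather than racing into the window between the stage and publish steps; a failed mount then releases the claim again, mirroring the `usageTracker.Free` call added in the error branch above.

```go
package main

import (
	"fmt"
	"sync"
)

// tracker is an illustrative stand-in for the CSI manager's usage tracker:
// it maps a volume ID to the set of alloc IDs currently claiming it.
type tracker struct {
	mu     sync.Mutex
	claims map[string]map[string]struct{}
}

func newTracker() *tracker {
	return &tracker{claims: map[string]map[string]struct{}{}}
}

// Claim records that allocID is using volID.
func (t *tracker) Claim(allocID, volID string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.claims[volID] == nil {
		t.claims[volID] = map[string]struct{}{}
	}
	t.claims[volID][allocID] = struct{}{}
}

// Free drops allocID's claim and reports whether other claims remain;
// the caller only unstages the volume when this returns false.
func (t *tracker) Free(allocID, volID string) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.claims[volID], allocID)
	return len(t.claims[volID]) > 0
}

func main() {
	tr := newTracker()

	// alloc0 mounted the volume at some earlier point.
	tr.Claim("alloc0", "vol")

	// alloc1 begins mounting: with the fix, it claims *before* stage/publish.
	// Under the old order (claim only after a successful publish), the unmount
	// below could observe zero remaining claims and unstage the shared volume
	// while alloc1 sits between its stage and publish calls.
	tr.Claim("alloc1", "vol")
	fmt.Println("alloc1: staged, not yet published")

	// A concurrent unmount for alloc0 interleaves here.
	if othersRemain := tr.Free("alloc0", "vol"); othersRemain {
		fmt.Println("alloc0 unmount: volume still claimed, skip unstage")
	} else {
		fmt.Println("alloc0 unmount: no claims left, unstage")
	}

	fmt.Println("alloc1: published")
}
```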