diff --git a/pkg/stub/stub.go b/pkg/stub/stub.go
index 3cc837b0..c3b886d3 100644
--- a/pkg/stub/stub.go
+++ b/pkg/stub/stub.go
@@ -137,7 +137,9 @@ type PostUpdateContainerInterface interface {
 
 // Stub is the interface the stub provides for the plugin implementation.
 type Stub interface {
-	// Run the plugin. Starts the plugin then waits for an error or the plugin to stop
+	// Run starts the plugin then waits for the plugin service to exit, either due to a
+	// critical error or an explicit call to Stop(). Once Run() returns, the plugin can be
+	// restarted by calling Run() or Start() again.
 	Run(context.Context) error
 	// Start the plugin.
 	Start(context.Context) error
@@ -255,7 +257,6 @@ type stub struct {
 	rpcs       *ttrpc.Server
 	rpcc       *ttrpc.Client
 	runtime    api.RuntimeService
-	closeOnce  sync.Once
 	started    bool
 	doneC      chan struct{}
 	srvErrC    chan error
@@ -288,7 +289,6 @@ func New(p interface{}, opts ...Option) (Stub, error) {
 		idx:        os.Getenv(api.PluginIdxEnvVar),
 		socketPath: api.DefaultSocketPath,
 		dialer:     func(p string) (stdnet.Conn, error) { return stdnet.Dial("unix", p) },
-		doneC:      make(chan struct{}),
 	}
 
 	for _, o := range opts {
@@ -316,10 +316,10 @@ func (stub *stub) Start(ctx context.Context) (retErr error) {
 	stub.Lock()
 	defer stub.Unlock()
 
-	if stub.started {
+	if stub.isStarted() {
 		return fmt.Errorf("stub already started")
 	}
-	stub.started = true
+	stub.doneC = make(chan struct{})
 
 	err := stub.connect()
 	if err != nil {
@@ -401,6 +401,7 @@ func (stub *stub) Start(ctx context.Context) (retErr error) {
 
 	log.Infof(ctx, "Started plugin %s...", stub.Name())
 
+	stub.started = true
 	return nil
 }
 
@@ -413,24 +414,43 @@ func (stub *stub) Stop() {
 	stub.close()
 }
 
+// IsStarted returns true if the plugin has been started either by Start() or by Run().
+func (stub *stub) IsStarted() bool {
+	stub.Lock()
+	defer stub.Unlock()
+	return stub.isStarted()
+}
+
+// isStarted reports whether the plugin is started. The caller must hold the lock.
+func (stub *stub) isStarted() bool {
+	return stub.started
+}
+
+// close tears down the stub's connections and resets the stub so that a new
+// NRI connection can be initiated. The caller must hold the lock.
 func (stub *stub) close() {
-	stub.closeOnce.Do(func() {
-		if stub.rpcl != nil {
-			stub.rpcl.Close()
-		}
-		if stub.rpcs != nil {
-			stub.rpcs.Close()
-		}
-		if stub.rpcc != nil {
-			stub.rpcc.Close()
-		}
-		if stub.rpcm != nil {
-			stub.rpcm.Close()
-		}
-		if stub.srvErrC != nil {
-			<-stub.doneC
-		}
-	})
+	if !stub.isStarted() {
+		return
+	}
+
+	if stub.rpcl != nil {
+		stub.rpcl.Close()
+	}
+	if stub.rpcs != nil {
+		stub.rpcs.Close()
+	}
+	if stub.rpcc != nil {
+		stub.rpcc.Close()
+	}
+	if stub.rpcm != nil {
+		stub.rpcm.Close()
+	}
+	if stub.srvErrC != nil {
+		<-stub.doneC
+	}
+
+	stub.started = false
+	stub.conn = nil
 }
 
 // Run the plugin. Start event processing then wait for an error or getting stopped.
@@ -449,14 +469,18 @@ func (stub *stub) Run(ctx context.Context) error {
 	return err
 }
 
-// Wait for the plugin to stop.
+// Wait for the plugin to stop. It returns immediately if the plugin has not
+// been started by Start() or Run().
 func (stub *stub) Wait() {
 	stub.Lock()
-	if stub.srvErrC == nil {
-		return
-	}
+	// Snapshot doneC under the lock: Start() replaces it on every restart.
+	started := stub.isStarted()
+	doneC := stub.doneC
 	stub.Unlock()
-	<-stub.doneC
+
+	if started {
+		<-doneC
+	}
 }
 
 // Name returns the full indexed name of the plugin.
@@ -518,7 +542,9 @@ func (stub *stub) register(ctx context.Context) error {
 
 // Handle a lost connection.
 func (stub *stub) connClosed() {
+	stub.Lock()
 	stub.close()
+	stub.Unlock()
 	if stub.onClose != nil {
 		stub.onClose()
 		return