From e8a929856189938d88a6cdf3dd756a23e56c7a55 Mon Sep 17 00:00:00 2001 From: Mykhailo Lohvynenko Date: Mon, 30 Sep 2024 14:03:50 +0300 Subject: [PATCH] [unitstatushandler] Handle unit subjects events Signed-off-by: Mykhailo Lohvynenko --- unitstatushandler/softwaremanager.go | 10 +- unitstatushandler/unitstatushandler.go | 74 +++++++++--- .../unitstatushandler_internal_test.go | 54 ++++++--- unitstatushandler/unitstatushandler_test.go | 108 +++++++++++++++--- 4 files changed, 194 insertions(+), 52 deletions(-) diff --git a/unitstatushandler/softwaremanager.go b/unitstatushandler/softwaremanager.go index aeefa93e..a1f33f8b 100644 --- a/unitstatushandler/softwaremanager.go +++ b/unitstatushandler/softwaremanager.go @@ -83,7 +83,7 @@ type softwareManager struct { statusChannel chan cmserver.UpdateSOTAStatus - nodeManager NodeManager + unitManager UnitManager unitConfigUpdater UnitConfigUpdater downloader softwareDownloader statusHandler softwareStatusHandler @@ -114,7 +114,7 @@ type softwareManager struct { * Interface **********************************************************************************************************************/ -func newSoftwareManager(statusHandler softwareStatusHandler, downloader softwareDownloader, nodeManager NodeManager, +func newSoftwareManager(statusHandler softwareStatusHandler, downloader softwareDownloader, unitManager UnitManager, unitConfigUpdater UnitConfigUpdater, softwareUpdater SoftwareUpdater, instanceRunner InstanceRunner, storage Storage, defaultTTL time.Duration, ) (manager *softwareManager, err error) { @@ -122,7 +122,7 @@ func newSoftwareManager(statusHandler softwareStatusHandler, downloader software statusChannel: make(chan cmserver.UpdateSOTAStatus, 1), downloader: downloader, statusHandler: statusHandler, - nodeManager: nodeManager, + unitManager: unitManager, unitConfigUpdater: unitConfigUpdater, softwareUpdater: softwareUpdater, instanceRunner: instanceRunner, @@ -1499,7 +1499,7 @@ func (manager *softwareManager) updateNodes() (nodesErr error) { if nodeStatus.Status == cloudprotocol.NodeStatusPaused { log.WithField("nodeID", nodeStatus.NodeID).Debug("Pause node") - if err := manager.nodeManager.PauseNode(nodeStatus.NodeID); err != nil && nodesErr == nil { + if err := manager.unitManager.PauseNode(nodeStatus.NodeID); err != nil && nodesErr == nil { log.WithField("nodeID", nodeStatus.NodeID).Errorf("Can't pause node: %v", err) nodesErr = aoserrors.Wrap(err) @@ -1509,7 +1509,7 @@ func (manager *softwareManager) updateNodes() (nodesErr error) { if nodeStatus.Status == cloudprotocol.NodeStatusProvisioned { log.WithField("nodeID", nodeStatus.NodeID).Debug("Resume node") - if err := manager.nodeManager.ResumeNode(nodeStatus.NodeID); err != nil && nodesErr == nil { + if err := manager.unitManager.ResumeNode(nodeStatus.NodeID); err != nil && nodesErr == nil { log.WithField("nodeID", nodeStatus.NodeID).Errorf("Can't resume node: %v", err) nodesErr = aoserrors.Wrap(err) diff --git a/unitstatushandler/unitstatushandler.go b/unitstatushandler/unitstatushandler.go index a6835f1d..b08e7e10 100644 --- a/unitstatushandler/unitstatushandler.go +++ b/unitstatushandler/unitstatushandler.go @@ -41,13 +41,15 @@ import ( * Types **********************************************************************************************************************/ -// NodeManager manages nodes. -type NodeManager interface { +// UnitManager manages unit. +type UnitManager interface { GetAllNodeIDs() ([]string, error) GetNodeInfo(nodeID string) (cloudprotocol.NodeInfo, error) SubscribeNodeInfoChange() <-chan cloudprotocol.NodeInfo PauseNode(nodeID string) error ResumeNode(nodeID string) error + GetUnitSubjects() ([]string, error) + SubscribeUnitSubjectsChanged() <-chan []string } // Downloader downloads packages. @@ -128,7 +130,7 @@ type LayerStatus struct { type Instance struct { sync.Mutex - nodeManager NodeManager + unitManager UnitManager statusSender StatusSender statusMutex sync.Mutex @@ -140,9 +142,10 @@ type Instance struct { firmwareManager *firmwareManager softwareManager *softwareManager - newComponentsChannel <-chan []cloudprotocol.ComponentStatus - nodeChangedChannel <-chan cloudprotocol.NodeInfo - systemQuotaAlertChannel <-chan cloudprotocol.SystemQuotaAlert + newComponentsChannel <-chan []cloudprotocol.ComponentStatus + nodeChangedChannel <-chan cloudprotocol.NodeInfo + unitSubjectsChangedChannel <-chan []string + systemQuotaAlertChannel <-chan cloudprotocol.SystemQuotaAlert initDone bool isConnected bool @@ -155,7 +158,7 @@ type Instance struct { // New creates new unit status handler instance. func New( cfg *config.Config, - nodeManager NodeManager, + unitManager UnitManager, unitConfigUpdater UnitConfigUpdater, firmwareUpdater FirmwareUpdater, softwareUpdater SoftwareUpdater, @@ -168,12 +171,13 @@ func New( log.Debug("Create unit status handler") instance = &Instance{ - nodeManager: nodeManager, - statusSender: statusSender, - sendStatusPeriod: cfg.UnitStatusSendTimeout.Duration, - newComponentsChannel: firmwareUpdater.NewComponentsChannel(), - nodeChangedChannel: nodeManager.SubscribeNodeInfoChange(), - systemQuotaAlertChannel: systemQuotaAlertProvider.GetSystemQuoteAlertChannel(), + unitManager: unitManager, + statusSender: statusSender, + sendStatusPeriod: cfg.UnitStatusSendTimeout.Duration, + newComponentsChannel: firmwareUpdater.NewComponentsChannel(), + nodeChangedChannel: unitManager.SubscribeNodeInfoChange(), + unitSubjectsChangedChannel: unitManager.SubscribeUnitSubjectsChanged(), + systemQuotaAlertChannel: systemQuotaAlertProvider.GetSystemQuoteAlertChannel(), } instance.resetUnitStatus() @@ -185,7 +189,7 @@ func New( return nil, aoserrors.Wrap(err) } - if instance.softwareManager, err = newSoftwareManager(instance, groupDownloader, nodeManager, unitConfigUpdater, + if instance.softwareManager, err = newSoftwareManager(instance, groupDownloader, unitManager, unitConfigUpdater, softwareUpdater, instanceRunner, storage, cfg.SMController.UpdateTTL.Duration); err != nil { return nil, aoserrors.Wrap(err) } @@ -448,6 +452,19 @@ func (instance *Instance) initNodesStatus() error { return nil } +func (instance *Instance) initUnitSubjects() error { + subjects, err := instance.unitManager.GetUnitSubjects() + if err != nil { + log.Errorf("Can't get unit subjects: %v", err) + } + + if len(subjects) > 0 { + instance.setSubjects(subjects) + } + + return nil +} + func (instance *Instance) initCurrentStatus() error { if err := instance.initUnitConfigStatus(); err != nil { return err @@ -473,6 +490,10 @@ func (instance *Instance) initCurrentStatus() error { return err } + if err := instance.initUnitSubjects(); err != nil { + return err + } + return nil } @@ -636,6 +657,20 @@ func (instance *Instance) updateNodeInfo(nodeInfo cloudprotocol.NodeInfo) { instance.statusChanged() } +func (instance *Instance) setSubjects(subjects []string) { + instance.statusMutex.Lock() + defer instance.statusMutex.Unlock() + + log.WithField("subjects", subjects).Debug("Set subjects") + + instance.unitStatus.UnitSubjects = subjects +} + +func (instance *Instance) updateSubjects(subjects []string) { + instance.setSubjects(subjects) + instance.statusChanged() +} + func (instance *Instance) statusChanged() { if instance.statusTimer != nil { return @@ -686,7 +721,7 @@ func (instance *Instance) sendCurrentStatus(deltaStatus bool) { } func (instance *Instance) getAllNodesInfo() ([]cloudprotocol.NodeInfo, error) { - nodeIDs, err := instance.nodeManager.GetAllNodeIDs() + nodeIDs, err := instance.unitManager.GetAllNodeIDs() if err != nil { return nil, aoserrors.Wrap(err) } @@ -694,7 +729,7 @@ func (instance *Instance) getAllNodesInfo() ([]cloudprotocol.NodeInfo, error) { nodesInfo := make([]cloudprotocol.NodeInfo, 0, len(nodeIDs)) for _, nodeID := range nodeIDs { - nodeInfo, err := instance.nodeManager.GetNodeInfo(nodeID) + nodeInfo, err := instance.unitManager.GetNodeInfo(nodeID) if err != nil { log.WithField("nodeID", nodeID).Errorf("Can't get node info: %s", err) continue @@ -744,6 +779,13 @@ func (instance *Instance) handleChannels() { log.Errorf("Can't perform rebalancing: %v", err) } + case subjects, ok := <-instance.unitSubjectsChangedChannel: + if !ok { + return + } + + instance.updateSubjects(subjects) + case systemQuotaAlert, ok := <-instance.systemQuotaAlertChannel: if !ok { return diff --git a/unitstatushandler/unitstatushandler_internal_test.go b/unitstatushandler/unitstatushandler_internal_test.go index f3c744da..2c8a4a12 100644 --- a/unitstatushandler/unitstatushandler_internal_test.go +++ b/unitstatushandler/unitstatushandler_internal_test.go @@ -51,9 +51,11 @@ const waitStatusTimeout = 5 * time.Second * Types **********************************************************************************************************************/ -type TestNodeManager struct { +type TestUnitManager struct { nodesInfo map[string]*cloudprotocol.NodeInfo nodeInfoChannel chan cloudprotocol.NodeInfo + subjectsChannel chan []string + currentSubjects []string } type TestSender struct { @@ -1276,10 +1278,11 @@ func TestSoftwareManager(t *testing.T) { }, } - nodeManager := NewTestNodeManager([]cloudprotocol.NodeInfo{ + unitManager := NewTestUnitManager([]cloudprotocol.NodeInfo{ {NodeID: "node1", NodeType: "type1", Status: cloudprotocol.NodeStatusProvisioned}, {NodeID: "node2", NodeType: "type2", Status: cloudprotocol.NodeStatusProvisioned}, - }) + }, + nil) unitConfigUpdater := NewTestUnitConfigUpdater(cloudprotocol.UnitConfigStatus{}) softwareUpdater := NewTestSoftwareUpdater(nil, nil) instanceRunner := NewTestInstanceRunner() @@ -1304,7 +1307,7 @@ func TestSoftwareManager(t *testing.T) { // Create software manager - softwareManager, err := newSoftwareManager(newTestStatusHandler(), softwareDownloader, nodeManager, + softwareManager, err := newSoftwareManager(newTestStatusHandler(), softwareDownloader, unitManager, unitConfigUpdater, softwareUpdater, instanceRunner, testStorage, 30*time.Second) if err != nil { t.Errorf("Can't create software manager: %s", err) @@ -1362,7 +1365,7 @@ func TestSoftwareManager(t *testing.T) { if item.desiredStatus != nil && item.desiredStatus.Nodes != nil { for _, nodeStatus := range item.desiredStatus.Nodes { - nodeInfo, err := nodeManager.GetNodeInfo(nodeStatus.NodeID) + nodeInfo, err := unitManager.GetNodeInfo(nodeStatus.NodeID) if err != nil { t.Errorf("Get node info error: %v", err) } @@ -1641,10 +1644,10 @@ func TestSyncExecutor(t *testing.T) { **********************************************************************************************************************/ /*********************************************************************************************************************** - * TestNodeManager + * TestUnitManager **********************************************************************************************************************/ -func NewTestNodeManager(nodesInfo []cloudprotocol.NodeInfo) *TestNodeManager { +func NewTestUnitManager(nodesInfo []cloudprotocol.NodeInfo, subjects []string) *TestUnitManager { nodesInfoMap := make(map[string]*cloudprotocol.NodeInfo) for _, nodeInfo := range nodesInfo { @@ -1652,12 +1655,13 @@ func NewTestNodeManager(nodesInfo []cloudprotocol.NodeInfo) *TestNodeManager { *nodesInfoMap[nodeInfo.NodeID] = nodeInfo } - return &TestNodeManager{ - nodesInfo: nodesInfoMap, + return &TestUnitManager{ + nodesInfo: nodesInfoMap, + currentSubjects: subjects, } } -func (manager *TestNodeManager) GetAllNodeIDs() ([]string, error) { +func (manager *TestUnitManager) GetAllNodeIDs() ([]string, error) { nodeIDs := make([]string, 0, len(manager.nodesInfo)) for nodeID := range manager.nodesInfo { @@ -1667,7 +1671,7 @@ func (manager *TestNodeManager) GetAllNodeIDs() ([]string, error) { return nodeIDs, nil } -func (manager *TestNodeManager) GetNodeInfo(nodeID string) (cloudprotocol.NodeInfo, error) { +func (manager *TestUnitManager) GetNodeInfo(nodeID string) (cloudprotocol.NodeInfo, error) { nodeInfo, ok := manager.nodesInfo[nodeID] if !ok { return cloudprotocol.NodeInfo{}, aoserrors.New("node not found") @@ -1676,13 +1680,13 @@ func (manager *TestNodeManager) GetNodeInfo(nodeID string) (cloudprotocol.NodeIn return *nodeInfo, nil } -func (manager *TestNodeManager) SubscribeNodeInfoChange() <-chan cloudprotocol.NodeInfo { +func (manager *TestUnitManager) SubscribeNodeInfoChange() <-chan cloudprotocol.NodeInfo { manager.nodeInfoChannel = make(chan cloudprotocol.NodeInfo, 1) return manager.nodeInfoChannel } -func (manager *TestNodeManager) NodeInfoChanged(nodeInfo cloudprotocol.NodeInfo) { +func (manager *TestUnitManager) NodeInfoChanged(nodeInfo cloudprotocol.NodeInfo) { if _, ok := manager.nodesInfo[nodeInfo.NodeID]; !ok { manager.nodesInfo[nodeInfo.NodeID] = &cloudprotocol.NodeInfo{} } @@ -1694,7 +1698,7 @@ func (manager *TestNodeManager) NodeInfoChanged(nodeInfo cloudprotocol.NodeInfo) } } -func (manager *TestNodeManager) GetAllNodesInfo() []cloudprotocol.NodeInfo { +func (manager *TestUnitManager) GetAllNodesInfo() []cloudprotocol.NodeInfo { nodesInfo := make([]cloudprotocol.NodeInfo, 0, len(manager.nodesInfo)) for _, nodeInfo := range manager.nodesInfo { @@ -1704,7 +1708,7 @@ func (manager *TestNodeManager) GetAllNodesInfo() []cloudprotocol.NodeInfo { return nodesInfo } -func (manager *TestNodeManager) PauseNode(nodeID string) error { +func (manager *TestUnitManager) PauseNode(nodeID string) error { if _, ok := manager.nodesInfo[nodeID]; !ok { return aoserrors.New("node not found") } @@ -1714,7 +1718,7 @@ func (manager *TestNodeManager) PauseNode(nodeID string) error { return nil } -func (manager *TestNodeManager) ResumeNode(nodeID string) error { +func (manager *TestUnitManager) ResumeNode(nodeID string) error { if _, ok := manager.nodesInfo[nodeID]; !ok { return aoserrors.New("node not found") } @@ -1724,6 +1728,24 @@ func (manager *TestNodeManager) ResumeNode(nodeID string) error { return nil } +func (manager *TestUnitManager) GetUnitSubjects() (subjects []string, err error) { + return manager.currentSubjects, nil +} + +func (manager *TestUnitManager) SubscribeUnitSubjectsChanged() <-chan []string { + manager.subjectsChannel = make(chan []string, 1) + + return manager.subjectsChannel +} + +func (manager *TestUnitManager) SubjectsChanged(subjects []string) { + manager.currentSubjects = subjects + + if manager.subjectsChannel != nil { + manager.subjectsChannel <- subjects + } +} + /*********************************************************************************************************************** * TestSender **********************************************************************************************************************/ diff --git a/unitstatushandler/unitstatushandler_test.go b/unitstatushandler/unitstatushandler_test.go index 90c929c1..8cad8d52 100644 --- a/unitstatushandler/unitstatushandler_test.go +++ b/unitstatushandler/unitstatushandler_test.go @@ -50,8 +50,10 @@ var cfg = &config.Config{UnitStatusSendTimeout: aostypes.Duration{Duration: 3 * **********************************************************************************************************************/ func TestSendInitialStatus(t *testing.T) { + initialSubjects := []string{"initialSubject1"} + expectedUnitStatus := cloudprotocol.UnitStatus{ - UnitSubjects: []string{"subject1"}, + UnitSubjects: initialSubjects, UnitConfig: []cloudprotocol.UnitConfigStatus{ {Version: "1.0.0", Status: cloudprotocol.InstalledStatus}, }, @@ -110,8 +112,9 @@ func TestSendInitialStatus(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, fotaUpdater, sotaUpdater, - instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, + cfg, unitstatushandler.NewTestUnitManager(nil, initialSubjects), + unitConfigUpdater, fotaUpdater, sotaUpdater, instanceRunner, + unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { t.Fatalf("Can't create unit status handler: %s", err) @@ -157,7 +160,8 @@ func TestUpdateUnitConfig(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, fotaUpdater, sotaUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, fotaUpdater, sotaUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -252,7 +256,7 @@ func TestUpdateComponents(t *testing.T) { instanceRunner := unitstatushandler.NewTestInstanceRunner() sender := unitstatushandler.NewTestSender() - statusHandler, err := unitstatushandler.New(cfg, unitstatushandler.NewTestNodeManager(nil), + statusHandler, err := unitstatushandler.New(cfg, unitstatushandler.NewTestUnitManager(nil, nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -344,7 +348,8 @@ func TestUpdateLayers(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -508,7 +513,8 @@ func TestUpdateServices(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -669,7 +675,8 @@ func TestRunInstances(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -791,7 +798,8 @@ func TestRevertServices(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -905,7 +913,8 @@ func TestUpdateInstancesStatus(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -1018,7 +1027,8 @@ func TestUpdateCachedSOTA(t *testing.T) { downloader := unitstatushandler.NewTestDownloader() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, downloader, unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -1141,7 +1151,8 @@ func TestNewComponents(t *testing.T) { sender := unitstatushandler.NewTestSender() statusHandler, err := unitstatushandler.New( - cfg, unitstatushandler.NewTestNodeManager(nil), unitConfigUpdater, firmwareUpdater, softwareUpdater, + cfg, unitstatushandler.NewTestUnitManager(nil, nil), + unitConfigUpdater, firmwareUpdater, softwareUpdater, instanceRunner, unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) if err != nil { @@ -1200,7 +1211,7 @@ func TestNewComponents(t *testing.T) { } if err = compareUnitStatus(receivedUnitStatus, expectedUnitStatus); err != nil { - t.Errorf("Wrong unit status received: %v, expected: %v", receivedUnitStatus, expectedUnitStatus) + t.Errorf("Wrong unit status received: %v, expected: %v, err: %v", receivedUnitStatus, expectedUnitStatus, err) } } @@ -1211,10 +1222,11 @@ func TestNodeInfoChanged(t *testing.T) { softwareUpdater := unitstatushandler.NewTestSoftwareUpdater(nil, nil) instanceRunner := unitstatushandler.NewTestInstanceRunner() sender := unitstatushandler.NewTestSender() - nodeInfoProvider := unitstatushandler.NewTestNodeManager([]cloudprotocol.NodeInfo{ + nodeInfoProvider := unitstatushandler.NewTestUnitManager([]cloudprotocol.NodeInfo{ {NodeID: "node1", NodeType: "type1", Status: cloudprotocol.NodeStatusProvisioned}, {NodeID: "node2", NodeType: "type2", Status: cloudprotocol.NodeStatusProvisioned}, - }) + }, + nil) statusHandler, err := unitstatushandler.New( cfg, nodeInfoProvider, unitConfigUpdater, firmwareUpdater, softwareUpdater, @@ -1274,6 +1286,65 @@ func TestNodeInfoChanged(t *testing.T) { } } +func TestSubjectsChanged(t *testing.T) { + initialSubjects := []string{"initial1", "initial2"} + + unitConfigUpdater := unitstatushandler.NewTestUnitConfigUpdater( + cloudprotocol.UnitConfigStatus{Version: "1.0.0", Status: cloudprotocol.InstalledStatus}) + unitManager := unitstatushandler.NewTestUnitManager(nil, initialSubjects) + sender := unitstatushandler.NewTestSender() + + statusHandler, err := unitstatushandler.New( + cfg, unitManager, unitConfigUpdater, + unitstatushandler.NewTestFirmwareUpdater(nil), unitstatushandler.NewTestSoftwareUpdater(nil, nil), + unitstatushandler.NewTestInstanceRunner(), unitstatushandler.NewTestDownloader(), unitstatushandler.NewTestStorage(), + sender, unitstatushandler.NewTestSystemQuotaAlertProvider()) + if err != nil { + t.Fatalf("Can't create unit status handler: %v", err) + } + defer statusHandler.Close() + + sender.Consumer.CloudConnected() + + go handleUpdateStatus(statusHandler) + + if err := statusHandler.ProcessRunStatus(nil); err != nil { + t.Fatalf("Can't process run status: %v", err) + } + + receivedUnitStatus, err := sender.WaitForStatus(waitStatusTimeout) + if err != nil { + t.Fatalf("Can't receive unit status: %v", err) + } + + expectedUnitStatus := cloudprotocol.UnitStatus{ + UnitConfig: []cloudprotocol.UnitConfigStatus{unitConfigUpdater.UnitConfigStatus}, + UnitSubjects: initialSubjects, + } + + if err = compareUnitStatus(receivedUnitStatus, expectedUnitStatus); err != nil { + t.Errorf("Wrong unit status received: %v, expected: %v", receivedUnitStatus, expectedUnitStatus) + } + + newSubjects := []string{"subject1", "subject2", "subject3"} + + unitManager.SubjectsChanged(newSubjects) + + receivedUnitStatus, err = sender.WaitForStatus(waitStatusTimeout) + if err != nil { + t.Fatalf("Can't receive unit status: %v", err) + } + + expectedUnitStatus = cloudprotocol.UnitStatus{ + UnitSubjects: newSubjects, + IsDeltaInfo: true, + } + + if err = compareUnitStatus(receivedUnitStatus, expectedUnitStatus); err != nil { + t.Errorf("Wrong unit status received: %v, expected: %v", receivedUnitStatus, expectedUnitStatus) + } +} + /*********************************************************************************************************************** * Private **********************************************************************************************************************/ @@ -1362,6 +1433,13 @@ func compareUnitStatus(status1, status2 cloudprotocol.UnitStatus) (err error) { return aoserrors.Wrap(err) } + if err = compareStatus(len(status1.UnitSubjects), len(status2.UnitSubjects), + func(index1, index2 int) (result bool) { + return reflect.DeepEqual(status1.UnitSubjects[index1], status2.UnitSubjects[index2]) + }); err != nil { + return aoserrors.Wrap(err) + } + if status1.IsDeltaInfo != status2.IsDeltaInfo { return aoserrors.New("IsDeltaInfo mismatch") }