diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index e2ff8ace..7fee3630 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -22,7 +22,7 @@ jobs: - name: Prepare run: | sudo apt install lcov libsofthsm2 libsystemd-dev cppcheck -y - pip install conan + python -m pip install conan - name: Build and test run: | diff --git a/.github/workflows/check_format.yaml b/.github/workflows/check_format.yaml index 5da2cbde..f86d7070 100644 --- a/.github/workflows/check_format.yaml +++ b/.github/workflows/check_format.yaml @@ -24,6 +24,6 @@ jobs: - name: Run cmake-format style check run: | python -m pip install --upgrade pip - pip install cmake_format + python -m pip install cmake_format find . \( \( -not -path '*/build/*' \) -name '*.cmake' -or -name 'CMakeLists.txt' \) \ -exec cmake-format --check {} +; diff --git a/cmake/AosCoreAPI.cmake b/cmake/AosCoreAPI.cmake index a36bbca2..c0c2a24a 100644 --- a/cmake/AosCoreAPI.cmake +++ b/cmake/AosCoreAPI.cmake @@ -12,8 +12,8 @@ include(FetchContent) FetchContent_Declare( aoscoreapi - GIT_REPOSITORY https://github.com/aosedge/aos_core_api.git - GIT_TAG develop + GIT_REPOSITORY https://github.com/mykola-kobets-epam/aos_core_api.git + GIT_TAG renew-cert-fix GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) diff --git a/cmake/AosCoreCommon.cmake b/cmake/AosCoreCommon.cmake index fd20d6e0..d48eda23 100644 --- a/cmake/AosCoreCommon.cmake +++ b/cmake/AosCoreCommon.cmake @@ -12,8 +12,8 @@ set(aoscorecommon_build_dir ${CMAKE_CURRENT_BINARY_DIR}/aoscorecommon) ExternalProject_Add( aoscorecommon PREFIX ${aoscorecommon_build_dir} - GIT_REPOSITORY https://github.com/aosedge/aos_core_common_cpp.git - GIT_TAG develop + GIT_REPOSITORY https://github.com/mykola-kobets-epam/aos_core_common_cpp.git + GIT_TAG renew-cert-fix GIT_PROGRESS TRUE GIT_SHALLOW TRUE CMAKE_ARGS -Daoscore_build_dir=${aoscore_build_dir} diff --git a/cmake/AosCoreLib.cmake b/cmake/AosCoreLib.cmake index 0b02a8d8..01fe73b3 100644 --- a/cmake/AosCoreLib.cmake +++ b/cmake/AosCoreLib.cmake @@ -12,8 +12,8 @@ set(aoscore_build_dir ${CMAKE_CURRENT_BINARY_DIR}/aoscore) ExternalProject_Add( aoscore PREFIX ${aoscore_build_dir} - GIT_REPOSITORY https://github.com/aosedge/aos_core_lib_cpp.git - GIT_TAG develop + GIT_REPOSITORY https://github.com/mykola-kobets-epam/aos_core_lib_cpp.git + GIT_TAG renew-cert-fix GIT_PROGRESS TRUE GIT_SHALLOW TRUE CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${aoscore_build_dir} diff --git a/src/iamclient/iamclient.cpp b/src/iamclient/iamclient.cpp index 5bc354c9..0cd52892 100644 --- a/src/iamclient/iamclient.cpp +++ b/src/iamclient/iamclient.cpp @@ -26,6 +26,19 @@ aos::Error IAMClient::Init(const Config& config, aos::iam::identhandler::IdentHa aos::crypto::x509::ProviderItf& cryptoProvider, aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, bool provisioningMode) { + mIdentHandler = identHandler; + mNodeInfoProvider = &nodeInfoProvider; + mCertLoader = &certLoader; + mCryptoProvider = &cryptoProvider; + mProvisionManager = &provisionManager; + + mStartProvisioningCmdArgs = config.mStartProvisioningCmdArgs; + mDiskEncryptionCmdArgs = config.mDiskEncryptionCmdArgs; + mFinishProvisioningCmdArgs = config.mFinishProvisioningCmdArgs; + mDeprovisionCmdArgs = config.mDeprovisionCmdArgs; + mReconnectInterval = config.mNodeReconnectInterval; + mCACert = config.mCACert; + if (provisioningMode) { mCredentialList.push_back(grpc::InsecureChannelCredentials()); if (!config.mCACert.empty()) { @@ -43,21 +56,19 @@ aos::Error IAMClient::Init(const Config& config, aos::iam::identhandler::IdentHa return AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument); } + err = provisionManager.SubscribeCertChanged(aos::String(config.mCertStorage.c_str()), *this); + if (!err.IsNone()) { + LOG_ERR() << "Subscribe certificate receiver failed: error=" << err.Message(); + + return AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument); + } + mCredentialList.push_back( aos::common::utils::GetMTLSClientCredentials(certInfo, config.mCACert.c_str(), certLoader, cryptoProvider)); + mServerURL = config.mMainIAMProtectedServerURL; } - mIdentHandler = identHandler; - mNodeInfoProvider = &nodeInfoProvider; - mProvisionManager = &provisionManager; - - mStartProvisioningCmdArgs = config.mStartProvisioningCmdArgs; - mDiskEncryptionCmdArgs = config.mDiskEncryptionCmdArgs; - mFinishProvisioningCmdArgs = config.mFinishProvisioningCmdArgs; - mDeprovisionCmdArgs = config.mDeprovisionCmdArgs; - mReconnectInterval = config.mNodeReconnectInterval; - mConnectionThread = std::thread(&IAMClient::ConnectionLoop, this); return aos::ErrorEnum::eNone; @@ -71,6 +82,8 @@ IAMClient::~IAMClient() mShutdown = true; mShutdownCV.notify_all(); + mProvisionManager->UnsubscribeCertChanged(*this); + if (mRegisterNodeCtx) { mRegisterNodeCtx->TryCancel(); } @@ -85,6 +98,17 @@ IAMClient::~IAMClient() * Private **********************************************************************************************************************/ +void IAMClient::OnCertChanged(const aos::iam::certhandler::CertInfo& info) +{ + std::unique_lock lock {mShutdownLock}; + + mCredentialList.clear(); + mCredentialList.push_back( + aos::common::utils::GetMTLSClientCredentials(info, mCACert.c_str(), *mCertLoader, *mCryptoProvider)); + + mCredentialListUpdated = true; +} + std::unique_ptr IAMClient::CreateClientContext() { return std::make_unique(); @@ -135,6 +159,8 @@ bool IAMClient::RegisterNode(const std::string& url) LOG_DBG() << "Connection established"; + mCredentialListUpdated = false; + return true; } @@ -196,6 +222,18 @@ void IAMClient::HandleIncomingMessages() noexcept if (!ok) { break; } + + { + std::unique_lock lock {mShutdownLock}; + + if (mCredentialListUpdated) { + LOG_DBG() << "Credential list updated: closing connection"; + + mRegisterNodeCtx->TryCancel(); + + break; + } + } } } catch (const std::exception& e) { diff --git a/src/iamclient/iamclient.hpp b/src/iamclient/iamclient.hpp index 3fbf9f56..894258cd 100644 --- a/src/iamclient/iamclient.hpp +++ b/src/iamclient/iamclient.hpp @@ -32,7 +32,7 @@ using PublicNodeServiceStubPtr = std::unique_ptr>; @@ -89,15 +91,20 @@ class IAMClient { aos::iam::identhandler::IdentHandlerItf* mIdentHandler = nullptr; aos::iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; + aos::cryptoutils::CertLoaderItf* mCertLoader = nullptr; + aos::crypto::x509::ProviderItf* mCryptoProvider = nullptr; aos::iam::nodeinfoprovider::NodeInfoProviderItf* mNodeInfoProvider = nullptr; - std::vector mStartProvisioningCmdArgs; - std::vector mDiskEncryptionCmdArgs; - std::vector mFinishProvisioningCmdArgs; - std::vector mDeprovisionCmdArgs; - aos::common::utils::Duration mReconnectInterval; - std::string mServerURL; std::vector> mCredentialList; + bool mCredentialListUpdated = false; + + std::vector mStartProvisioningCmdArgs; + std::vector mDiskEncryptionCmdArgs; + std::vector mFinishProvisioningCmdArgs; + std::vector mDeprovisionCmdArgs; + aos::common::utils::Duration mReconnectInterval; + std::string mServerURL; + std::string mCACert; std::unique_ptr mRegisterNodeCtx; StreamPtr mStream; diff --git a/src/iamserver/iamserver.cpp b/src/iamserver/iamserver.cpp index 92c9c3b6..8b7ba04f 100644 --- a/src/iamserver/iamserver.cpp +++ b/src/iamserver/iamserver.cpp @@ -92,7 +92,9 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand { LOG_DBG() << "IAM Server init"; - mConfig = config; + mConfig = config; + mCertLoader = &certLoader; + mCryptoProvider = &cryptoProvider; aos::Error err; aos::NodeInfo nodeInfo; @@ -118,8 +120,6 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand } try { - std::shared_ptr publicOpt, protectedOpt; - if (!provisioningMode) { aos::iam::certhandler::CertInfo certInfo; @@ -128,16 +128,21 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand return AOS_ERROR_WRAP(err); } - publicOpt = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider); - protectedOpt = aos::common::utils::GetMTLSServerCredentials( + err = certHandler.SubscribeCertChanged(aos::String(mConfig.mCertStorage.c_str()), *this); + if (!err.IsNone()) { + return AOS_ERROR_WRAP(err); + } + + mPublicCred = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider); + mProtectedCred = aos::common::utils::GetMTLSServerCredentials( certInfo, mConfig.mCACert.c_str(), certLoader, cryptoProvider); } else { - publicOpt = grpc::InsecureServerCredentials(); - protectedOpt = grpc::InsecureServerCredentials(); + mPublicCred = grpc::InsecureServerCredentials(); + mProtectedCred = grpc::InsecureServerCredentials(); } - CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), publicOpt); - CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), protectedOpt); + Start(); + } catch (const aos::common::utils::AosException& e) { return e.GetError(); } catch (const std::exception& e) { @@ -206,22 +211,7 @@ void IAMServer::OnNodeRemoved(const aos::String& id) IAMServer::~IAMServer() { - LOG_DBG() << "IAM Server shutdown"; - - mNodeController.Close(); - - if (mPublicServer) { - mPublicServer->Shutdown(); - mPublicServer->Wait(); - } - - if (mProtectedServer) { - mProtectedServer->Shutdown(); - mProtectedServer->Wait(); - } - - mPublicMessageHandler.Close(); - mProtectedMessageHandler.Close(); + Shutdown(); } /*********************************************************************************************************************** @@ -242,6 +232,65 @@ aos::Error IAMServer::SubjectsChanged(const aos::ArrayShutdown(); + mPublicServer->Wait(); + } + + if (mProtectedServer) { + mProtectedServer->Shutdown(); + mProtectedServer->Wait(); + } + + mIsStarted = false; +} + void IAMServer::CreatePublicServer(const std::string& addr, const std::shared_ptr& credentials) { LOG_DBG() << "Process create public server: URL=" << addr.c_str(); diff --git a/src/iamserver/iamserver.hpp b/src/iamserver/iamserver.hpp index d7b6b81e..38206361 100644 --- a/src/iamserver/iamserver.hpp +++ b/src/iamserver/iamserver.hpp @@ -32,7 +32,8 @@ */ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, public aos::iam::identhandler::SubjectsObserverItf, - public aos::iam::provisionmanager::ProvisionManagerCallbackItf { + public aos::iam::provisionmanager::ProvisionManagerCallbackItf, + private aos::iam::certhandler::CertReceiverItf { public: /** * Constructor. @@ -115,15 +116,29 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, // identhandler::SubjectsObserverItf interface aos::Error SubjectsChanged(const aos::Array>& messages) override; + // certhandler::CertReceiverItf interface + void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override; + + // lifecycle routines + void Start(); + void Shutdown(); + // creating routines void CreatePublicServer(const std::string& addr, const std::shared_ptr& credentials); void CreateProtectedServer(const std::string& addr, const std::shared_ptr& credentials); - Config mConfig; - NodeController mNodeController; - PublicMessageHandler mPublicMessageHandler; - ProtectedMessageHandler mProtectedMessageHandler; - std::unique_ptr mPublicServer, mProtectedServer; + Config mConfig; + aos::cryptoutils::CertLoader* mCertLoader; + aos::crypto::x509::ProviderItf* mCryptoProvider; + + NodeController mNodeController; + PublicMessageHandler mPublicMessageHandler; + ProtectedMessageHandler mProtectedMessageHandler; + std::unique_ptr mPublicServer, mProtectedServer; + std::shared_ptr mPublicCred, mProtectedCred; + + bool mIsStarted = false; + std::future mCertChangedResult; }; #endif diff --git a/src/iamserver/nodecontroller.cpp b/src/iamserver/nodecontroller.cpp index 7c5663d3..a5b5becd 100644 --- a/src/iamserver/nodecontroller.cpp +++ b/src/iamserver/nodecontroller.cpp @@ -339,10 +339,24 @@ void NodeStreamHandler::SetNodeID(const std::string& nodeID) * Public **********************************************************************************************************************/ +NodeController::NodeController() +{ + Start(); +} + +void NodeController::Start() +{ + std::lock_guard lock {mMutex}; + + mIsClosed = false; +} + void NodeController::Close() { std::lock_guard lock {mMutex}; + mIsClosed = true; + // Call Close method explicitly to avoid hanging on shutdown. // HandleRegisterNodeStream method references handler so destructor is not called here. for (auto& handler : mHandlers) { @@ -355,6 +369,16 @@ void NodeController::Close() grpc::Status NodeController::HandleRegisterNodeStream(const std::vector& allowedStatuses, NodeServerReaderWriter* stream, grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager) { + { + std::lock_guard lock {mMutex}; + + if (mIsClosed) { + LOG_DBG() << "Node controller closed, cancel node registration."; + + return grpc::Status::CANCELLED; + } + } + auto handler = std::make_shared(allowedStatuses, stream, context, nodeManager); StoreHandler(handler); diff --git a/src/iamserver/nodecontroller.hpp b/src/iamserver/nodecontroller.hpp index ba2b1ac8..576e0cc4 100644 --- a/src/iamserver/nodecontroller.hpp +++ b/src/iamserver/nodecontroller.hpp @@ -178,6 +178,16 @@ using NodeStreamHandlerPtr = std::shared_ptr; */ class NodeController { public: + /** + * Constructor. + */ + NodeController(); + + /** + * Starts node controller. + */ + void Start(); + /** * Closes all stream handlers. */ @@ -209,6 +219,7 @@ class NodeController { void StoreHandler(NodeStreamHandlerPtr handler); void RemoveHandler(NodeStreamHandlerPtr handler); + bool mIsClosed = false; std::mutex mMutex; std::vector mHandlers; }; diff --git a/src/iamserver/protectedmessagehandler.hpp b/src/iamserver/protectedmessagehandler.hpp index 946d970d..e0998c93 100644 --- a/src/iamserver/protectedmessagehandler.hpp +++ b/src/iamserver/protectedmessagehandler.hpp @@ -31,7 +31,7 @@ */ class ProtectedMessageHandler : // public services - private PublicMessageHandler, + public PublicMessageHandler, // protected services private iamproto::IAMNodesService::Service, private iamproto::IAMProvisioningService::Service, diff --git a/src/iamserver/publicmessagehandler.cpp b/src/iamserver/publicmessagehandler.cpp index f893d982..14337683 100644 --- a/src/iamserver/publicmessagehandler.cpp +++ b/src/iamserver/publicmessagehandler.cpp @@ -90,12 +90,28 @@ aos::Error PublicMessageHandler::SubjectsChanged(const aos::ArrayClose(); + } + + mCertWriters.clear(); + } } /*********************************************************************************************************************** @@ -152,7 +168,7 @@ grpc::Status PublicMessageHandler::GetNodeInfo([[maybe_unused]] grpc::ServerCont } grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext* context, - const iamproto::GetCertRequest* request, iamproto::GetCertResponse* response) + const iamproto::GetCertRequest* request, iamproto::CertInfo* response) { LOG_DBG() << "Process get cert request: type=" << request->type().c_str() << ", serial=" << request->serial().c_str(); @@ -185,6 +201,45 @@ grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext* return grpc::Status::OK; } +grpc::Status PublicMessageHandler::SubscribeCertChanged([[maybe_unused]] grpc::ServerContext* context, + const iamanager::v5::SubscribeCertChangedRequest* request, grpc::ServerWriter* writer) +{ + LOG_DBG() << "Process subscribe cert changed: type=" << request->type().c_str(); + + auto certWriter = std::make_shared(request->type()); + + { + std::lock_guard lock {mCertWritersLock}; + + mCertWriters.push_back(certWriter); + } + + auto err = GetProvisionManager()->SubscribeCertChanged(request->type().c_str(), *certWriter); + if (!err.IsNone()) { + LOG_ERR() << "Failed to subscribe cert changed: " << err; + + return utils::ConvertAosErrorToGrpcStatus(err); + } + + auto status = certWriter->HandleStream(context, writer); + + err = GetProvisionManager()->UnsubscribeCertChanged(*certWriter); + if (!err.IsNone()) { + LOG_ERR() << "Failed to unsubscribe cert changed: " << err; + + return utils::ConvertAosErrorToGrpcStatus(err); + } + + { + std::lock_guard lock {mCertWritersLock}; + + auto iter = std::remove(mCertWriters.begin(), mCertWriters.end(), certWriter); + mCertWriters.erase(iter, mCertWriters.end()); + } + + return status; +} + /*********************************************************************************************************************** * IAMPublicIdentityService implementation **********************************************************************************************************************/ diff --git a/src/iamserver/publicmessagehandler.hpp b/src/iamserver/publicmessagehandler.hpp index d8a90fd6..6da46d82 100644 --- a/src/iamserver/publicmessagehandler.hpp +++ b/src/iamserver/publicmessagehandler.hpp @@ -26,6 +26,7 @@ #include #include "nodecontroller.hpp" +#include "streamwriter.hpp" /** * Public message handler. Responsible for handling public IAM services. @@ -88,90 +89,16 @@ class PublicMessageHandler : aos::Error SubjectsChanged(const aos::Array>& messages) override; /** - * Closes public message handler. + * Start public message handler. */ - void Close(); + void Start(); -protected: /** - * Server writer controller handles server writer streams. + * Closes public message handler. */ - template - class ServerWriterController { - public: - /** - * Closes all streams. - */ - void Close() - { - mIsRunning = false; - - { - std::unique_lock lock {mMutex}; - - mLastMessage.reset(); - } - - mCV.notify_all(); - } - - /** - * Writes notification message to all streams. - * - * @param message notification message. - */ - void WriteToStreams(const T& message) - { - { - std::unique_lock lock {mMutex}; - - ++mNotificationID; - - mLastMessage = message; - } - - mCV.notify_all(); - } - - /** - * Handles stream. Blocks the caller until the stream is closed. - * - * @param context server context. - * @param writer server writer. - * @return grpc::Status. - */ - grpc::Status HandleStream(grpc::ServerContext* context, grpc::ServerWriter* writer) - { - uint32_t lastNotificationID = 0; - - while (mIsRunning && !context->IsCancelled()) { - std::shared_lock lock {mMutex}; - - if (mCV.wait_for(lock, cWaitTimeout, [this, lastNotificationID] { - return mNotificationID != lastNotificationID && mLastMessage.has_value(); - })) { - // got notification, send it to the client - if (!writer->Write(*mLastMessage)) { - break; - } - - lastNotificationID = mNotificationID; - } - } - - return grpc::Status::OK; - } - - private: - static constexpr auto cWaitTimeout = std::chrono::seconds(10); - - std::atomic_bool mIsRunning = true; - std::condition_variable_any mCV; - std::shared_mutex mMutex; - std::atomic_uint32_t mNotificationID = 0; - std::optional mLastMessage; - }; + void Close(); +protected: aos::iam::identhandler::IdentHandlerItf* GetIdentHandler() { return mIdentHandler; } aos::iam::permhandler::PermHandlerItf* GetPermHandler() { return mPermHandler; } aos::iam::nodeinfoprovider::NodeInfoProviderItf* GetNodeInfoProvider() { return mNodeInfoProvider; } @@ -189,8 +116,11 @@ class PublicMessageHandler : // IAMPublicService interface grpc::Status GetNodeInfo( grpc::ServerContext* context, const google::protobuf::Empty* request, iamproto::NodeInfo* response) override; - grpc::Status GetCert(grpc::ServerContext* context, const iamproto::GetCertRequest* request, - iamproto::GetCertResponse* response) override; + grpc::Status GetCert( + grpc::ServerContext* context, const iamproto::GetCertRequest* request, iamproto::CertInfo* response) override; + grpc::Status SubscribeCertChanged(grpc::ServerContext* context, + const iamanager::v5::SubscribeCertChangedRequest* request, + grpc::ServerWriter* writer) override; // IAMPublicIdentityService interface grpc::Status GetSystemInfo( @@ -223,9 +153,12 @@ class PublicMessageHandler : aos::iam::nodemanager::NodeManagerItf* mNodeManager = nullptr; aos::iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; NodeController* mNodeController = nullptr; - ServerWriterController mNodeChangedController; - ServerWriterController mSubjectsChangedController; + StreamWriter mNodeChangedController; + StreamWriter mSubjectsChangedController; aos::NodeInfo mNodeInfo; + + std::vector> mCertWriters; + std::mutex mCertWritersLock; }; #endif diff --git a/src/iamserver/streamwriter.hpp b/src/iamserver/streamwriter.hpp new file mode 100644 index 00000000..4354ff7a --- /dev/null +++ b/src/iamserver/streamwriter.hpp @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2024 Renesas Electronics Corporation. + * Copyright (C) 2024 EPAM Systems, Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef STREAMWRITER_HPP_ +#define STREAMWRITER_HPP_ + +#include + +/** + * Controls writes to streams. + */ +template +class StreamWriter { +public: + /** + * Closes all streams. + */ + void Start() + { + std::unique_lock lock {mMutex}; + + mIsRunning = true; + mNotificationID = 0; + } + + /** + * Closes all streams. + */ + void Close() + { + { + std::unique_lock lock {mMutex}; + + mIsRunning = false; + mLastMessage.reset(); + } + + mCV.notify_all(); + } + + /** + * Writes notification message to all streams. + * + * @param message notification message. + */ + void WriteToStreams(const T& message) + { + { + std::unique_lock lock {mMutex}; + + ++mNotificationID; + mLastMessage = message; + } + + mCV.notify_all(); + } + + /** + * Handles stream. Blocks the caller until the stream is closed. + * + * @param context server context. + * @param writer server writer. + * @return grpc::Status. + */ + grpc::Status HandleStream(grpc::ServerContext* context, grpc::ServerWriter* writer) + { + uint32_t lastNotificationID = 0; + + while (mIsRunning && !context->IsCancelled()) { + std::shared_lock lock {mMutex}; + + bool res = mCV.wait_for(lock, cWaitTimeout, [this, lastNotificationID] { + return (mNotificationID != lastNotificationID && mLastMessage.has_value()) || !mIsRunning; + }); + + if (!mIsRunning) { + break; + } + + if (res) { + // got notification, send it to the client + if (!writer->Write(*mLastMessage)) { + break; + } + + lastNotificationID = mNotificationID; + } + } + + return grpc::Status::OK; + } + +private: + static constexpr auto cWaitTimeout = std::chrono::seconds(10); + + std::atomic_bool mIsRunning = true; + std::condition_variable_any mCV; + std::shared_mutex mMutex; + uint32_t mNotificationID = 0; + std::optional mLastMessage; +}; + +/** + * Sends certificate updates to GRPC streams. + */ +class CertWriter : public StreamWriter, public aos::iam::certhandler::CertReceiverItf { +public: + /** + * CertWriter constructor. + * + * @param certType certificate type. + */ + explicit CertWriter(const std::string& certType) + : mCertType(certType) + { + } + +private: + void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override + { + iamanager::v5::CertInfo grpcCertInfo; + + grpcCertInfo.set_type(mCertType); + grpcCertInfo.set_key_url(info.mKeyURL.CStr()); + grpcCertInfo.set_cert_url(info.mCertURL.CStr()); + + WriteToStreams(grpcCertInfo); + } + + std::string mCertType; +}; + +#endif diff --git a/tests/iamserver/publicmessagehandler_test.cpp b/tests/iamserver/publicmessagehandler_test.cpp index ae55474f..65165da2 100644 --- a/tests/iamserver/publicmessagehandler_test.cpp +++ b/tests/iamserver/publicmessagehandler_test.cpp @@ -158,9 +158,9 @@ TEST_F(PublicMessageHandlerTest, GetCertSucceeds) auto clientStub = CreateClientStub(); ASSERT_NE(clientStub, nullptr) << "Failed to create client stub"; - grpc::ClientContext context; - iamproto::GetCertRequest request; - iamproto::GetCertResponse response; + grpc::ClientContext context; + iamproto::GetCertRequest request; + iamproto::CertInfo response; request.set_issuer("test-issuer"); request.set_serial("58bdb46d06865f7f"); @@ -193,9 +193,9 @@ TEST_F(PublicMessageHandlerTest, GetCertFails) auto clientStub = CreateClientStub(); ASSERT_NE(clientStub, nullptr) << "Failed to create client stub"; - grpc::ClientContext context; - iamproto::GetCertRequest request; - iamproto::GetCertResponse response; + grpc::ClientContext context; + iamproto::GetCertRequest request; + iamproto::CertInfo response; request.set_issuer("test-issuer"); request.set_serial("58bdb46d06865f7f"); @@ -218,6 +218,72 @@ TEST_F(PublicMessageHandlerTest, GetCertFails) ASSERT_FALSE(status.ok()); } +TEST_F(PublicMessageHandlerTest, SubscribeCertChangedSucceeds) +{ + auto clientStub = CreateClientStub(); + ASSERT_NE(clientStub, nullptr) << "Failed to create client stub"; + + grpc::ClientContext context; + iamproto::SubscribeCertChangedRequest request; + iamanager::v5::CertInfo response; + + request.set_type("test-type"); + + aos::iam::certhandler::CertInfo certInfo; + certInfo.mKeyURL = "test-key-url"; + certInfo.mCertURL = "test-cert-url"; + + EXPECT_CALL(mProvisionManager, SubscribeCertChanged) + .WillOnce(Invoke([&certInfo](const aos::String&, aos::iam::certhandler::CertReceiverItf& receiver) { + receiver.OnCertChanged(certInfo); + + return aos::ErrorEnum::eNone; + })); + + auto reader = clientStub->SubscribeCertChanged(&context, request); + + ASSERT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.type(), request.type()); + EXPECT_EQ(response.key_url(), certInfo.mKeyURL.CStr()); + EXPECT_EQ(response.cert_url(), certInfo.mCertURL.CStr()); + + context.TryCancel(); + + auto status = reader->Finish(); + + ASSERT_EQ(status.error_code(), grpc::StatusCode::CANCELLED) + << "Stream finish should return CANCELLED code: code = " << status.error_code() + << ", message = " << status.error_message(); +} + +TEST_F(PublicMessageHandlerTest, SubscribeCertChangedFailed) +{ + auto clientStub = CreateClientStub(); + ASSERT_NE(clientStub, nullptr) << "Failed to create client stub"; + + grpc::ClientContext context; + iamproto::SubscribeCertChangedRequest request; + iamanager::v5::CertInfo response; + + request.set_type("test-type"); + + EXPECT_CALL(mProvisionManager, SubscribeCertChanged) + .WillOnce(Invoke( + [](const aos::String&, aos::iam::certhandler::CertReceiverItf&) { return aos::ErrorEnum::eFailed; })); + + auto reader = clientStub->SubscribeCertChanged(&context, request); + + ASSERT_FALSE(reader->Read(&response)); + + context.TryCancel(); + + auto status = reader->Finish(); + + ASSERT_EQ(status.error_code(), grpc::StatusCode::CANCELLED) + << "Stream finish should return CANCELLED code: code = " << status.error_code() + << ", message = " << status.error_message(); +} + /*********************************************************************************************************************** * IAMPublicIdentityService tests **********************************************************************************************************************/ diff --git a/tests/include/mocks/certhandlermock.hpp b/tests/include/mocks/certhandlermock.hpp index 15ac32ff..2cb14bcb 100644 --- a/tests/include/mocks/certhandlermock.hpp +++ b/tests/include/mocks/certhandlermock.hpp @@ -27,6 +27,10 @@ class CertHandlerItfMock : public aos::iam::certhandler::CertHandlerItf { MOCK_METHOD(aos::Error, GetCertificate, (const aos::String&, const aos::Array&, const aos::Array&, aos::iam::certhandler::CertInfo&), (override)); + MOCK_METHOD(aos::Error, SubscribeCertChanged, + (const aos::String& certType, aos::iam::certhandler::CertReceiverItf& certReceiver), (override)); + MOCK_METHOD( + aos::Error, UnsubscribeCertChanged, (aos::iam::certhandler::CertReceiverItf & certReceiver), (override)); MOCK_METHOD(aos::Error, CreateSelfSignedCert, (const aos::String&, const aos::String&), (override)); MOCK_METHOD(aos::RetWithError, GetModuleConfig, (const aos::String&), (const, override)); diff --git a/tests/include/mocks/provisionmanagermock.hpp b/tests/include/mocks/provisionmanagermock.hpp index 123e3a2d..7c35bfa0 100644 --- a/tests/include/mocks/provisionmanagermock.hpp +++ b/tests/include/mocks/provisionmanagermock.hpp @@ -28,6 +28,9 @@ class ProvisionManagerMock : public ProvisionManagerItf { (const String& certType, const Array& issuer, const Array& serial, certhandler::CertInfo& resCert), (override)); + MOCK_METHOD( + Error, SubscribeCertChanged, (const String& certType, certhandler::CertReceiverItf& certReceiver), (override)); + MOCK_METHOD(Error, UnsubscribeCertChanged, (certhandler::CertReceiverItf & certReceiver), (override)); MOCK_METHOD(Error, FinishProvisioning, (const String& password), (override)); MOCK_METHOD(Error, Deprovision, (const String& password), (override)); };