Skip to content

Commit

Permalink
[iamserver] Restart iamserver after certificate renew
Browse files Browse the repository at this point in the history
Signed-off-by: Mykola Kobets <[email protected]>
  • Loading branch information
mykola-kobets-epam committed Oct 17, 2024
1 parent 85f7bd2 commit 88d9d89
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 122 deletions.
99 changes: 74 additions & 25 deletions src/iamserver/iamserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
{
LOG_DBG() << "IAM Server init";

mConfig = config;
mConfig = config;
mCertLoader = &certLoader;
mCryptoProvider = &cryptoProvider;

aos::Error err;
aos::NodeInfo nodeInfo;
Expand All @@ -118,8 +120,6 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
}

try {
std::shared_ptr<grpc::ServerCredentials> publicOpt, protectedOpt;

if (!provisioningMode) {
aos::iam::certhandler::CertInfo certInfo;

Expand All @@ -128,16 +128,21 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
return AOS_ERROR_WRAP(err);
}

publicOpt = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider);
protectedOpt = aos::common::utils::GetMTLSServerCredentials(
err = certHandler.SubscribeCertChanged(aos::String(mConfig.mCertStorage.c_str()), *this);
if (!err.IsNone()) {
return AOS_ERROR_WRAP(err);
}

mPublicCred = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider);
mProtectedCred = aos::common::utils::GetMTLSServerCredentials(
certInfo, mConfig.mCACert.c_str(), certLoader, cryptoProvider);
} else {
publicOpt = grpc::InsecureServerCredentials();
protectedOpt = grpc::InsecureServerCredentials();
mPublicCred = grpc::InsecureServerCredentials();
mProtectedCred = grpc::InsecureServerCredentials();
}

CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), publicOpt);
CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), protectedOpt);
Start();

} catch (const aos::common::utils::AosException& e) {
return e.GetError();
} catch (const std::exception& e) {
Expand Down Expand Up @@ -206,22 +211,7 @@ void IAMServer::OnNodeRemoved(const aos::String& id)

IAMServer::~IAMServer()
{
LOG_DBG() << "IAM Server shutdown";

mNodeController.Close();

if (mPublicServer) {
mPublicServer->Shutdown();
mPublicServer->Wait();
}

if (mProtectedServer) {
mProtectedServer->Shutdown();
mProtectedServer->Wait();
}

mPublicMessageHandler.Close();
mProtectedMessageHandler.Close();
Shutdown();
}

/***********************************************************************************************************************
Expand All @@ -242,6 +232,65 @@ aos::Error IAMServer::SubjectsChanged(const aos::Array<aos::StaticString<aos::cS
return aos::ErrorEnum::eNone;
}

void IAMServer::OnCertChanged(const aos::iam::certhandler::CertInfo& info)
{
mPublicCred = aos::common::utils::GetTLSServerCredentials(info, *mCertLoader, *mCryptoProvider);
mProtectedCred
= aos::common::utils::GetMTLSServerCredentials(info, mConfig.mCACert.c_str(), *mCertLoader, *mCryptoProvider);

// postpone restart so it didn't block ApplyCert
mCertChangedResult = std::async(std::launch::async, [this]() {
sleep(1);
Shutdown();
Start();
});
}

void IAMServer::Start()
{
if (mIsStarted) {
return;
}

LOG_DBG() << "IAM Server start";

mNodeController.Start();

mPublicMessageHandler.Start();
mProtectedMessageHandler.Start();

CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), mPublicCred);
CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), mProtectedCred);

mIsStarted = true;
}

void IAMServer::Shutdown()
{
if (!mIsStarted) {
return;
}

LOG_DBG() << "IAM Server shutdown";

mNodeController.Close();

mPublicMessageHandler.Close();
mProtectedMessageHandler.Close();

if (mPublicServer) {
mPublicServer->Shutdown();
mPublicServer->Wait();
}

if (mProtectedServer) {
mProtectedServer->Shutdown();
mProtectedServer->Wait();
}

mIsStarted = false;
}

void IAMServer::CreatePublicServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials)
{
LOG_DBG() << "Process create public server: URL=" << addr.c_str();
Expand Down
27 changes: 21 additions & 6 deletions src/iamserver/iamserver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
*/
class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf,
public aos::iam::identhandler::SubjectsObserverItf,
public aos::iam::provisionmanager::ProvisionManagerCallbackItf {
public aos::iam::provisionmanager::ProvisionManagerCallbackItf,
private aos::iam::certhandler::CertReceiverItf {
public:
/**
* Constructor.
Expand Down Expand Up @@ -115,15 +116,29 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf,
// identhandler::SubjectsObserverItf interface
aos::Error SubjectsChanged(const aos::Array<aos::StaticString<aos::cSubjectIDLen>>& messages) override;

// certhandler::CertReceiverItf interface
void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override;

// lifecycle routines
void Start();
void Shutdown();

// creating routines
void CreatePublicServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials);
void CreateProtectedServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials);

Config mConfig;
NodeController mNodeController;
PublicMessageHandler mPublicMessageHandler;
ProtectedMessageHandler mProtectedMessageHandler;
std::unique_ptr<grpc::Server> mPublicServer, mProtectedServer;
Config mConfig;
aos::cryptoutils::CertLoader* mCertLoader;
aos::crypto::x509::ProviderItf* mCryptoProvider;

NodeController mNodeController;
PublicMessageHandler mPublicMessageHandler;
ProtectedMessageHandler mProtectedMessageHandler;
std::unique_ptr<grpc::Server> mPublicServer, mProtectedServer;
std::shared_ptr<grpc::ServerCredentials> mPublicCred, mProtectedCred;

bool mIsStarted = false;
std::future<void> mCertChangedResult;
};

#endif
24 changes: 24 additions & 0 deletions src/iamserver/nodecontroller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,24 @@ void NodeStreamHandler::SetNodeID(const std::string& nodeID)
* Public
**********************************************************************************************************************/

NodeController::NodeController()
{
Start();
}

void NodeController::Start()
{
std::lock_guard lock {mMutex};

mIsClosed = false;
}

void NodeController::Close()
{
std::lock_guard lock {mMutex};

mIsClosed = true;

// Call Close method explicitly to avoid hanging on shutdown.
// HandleRegisterNodeStream method references handler so destructor is not called here.
for (auto& handler : mHandlers) {
Expand All @@ -355,6 +369,16 @@ void NodeController::Close()
grpc::Status NodeController::HandleRegisterNodeStream(const std::vector<aos::NodeStatus>& allowedStatuses,
NodeServerReaderWriter* stream, grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager)
{
{
std::lock_guard lock {mMutex};

if (mIsClosed) {
LOG_DBG() << "Node controller closed, cancel node registration.";

return grpc::Status::CANCELLED;
}
}

auto handler = std::make_shared<NodeStreamHandler>(allowedStatuses, stream, context, nodeManager);
StoreHandler(handler);

Expand Down
11 changes: 11 additions & 0 deletions src/iamserver/nodecontroller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ using NodeStreamHandlerPtr = std::shared_ptr<NodeStreamHandler>;
*/
class NodeController {
public:
/**
* Constructor.
*/
NodeController();

/**
* Starts node controller.
*/
void Start();

/**
* Closes all stream handlers.
*/
Expand Down Expand Up @@ -209,6 +219,7 @@ class NodeController {
void StoreHandler(NodeStreamHandlerPtr handler);
void RemoveHandler(NodeStreamHandlerPtr handler);

bool mIsClosed = false;
std::mutex mMutex;
std::vector<NodeStreamHandlerPtr> mHandlers;
};
Expand Down
2 changes: 1 addition & 1 deletion src/iamserver/protectedmessagehandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
*/
class ProtectedMessageHandler :
// public services
private PublicMessageHandler,
public PublicMessageHandler,
// protected services
private iamproto::IAMNodesService::Service,
private iamproto::IAMProvisioningService::Service,
Expand Down
57 changes: 56 additions & 1 deletion src/iamserver/publicmessagehandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,28 @@ aos::Error PublicMessageHandler::SubjectsChanged(const aos::Array<aos::StaticStr
return aos::ErrorEnum::eNone;
}

void PublicMessageHandler::Start()
{
mNodeChangedController.Start();
mSubjectsChangedController.Start();
}

void PublicMessageHandler::Close()
{
LOG_DBG() << "Close message handler: handler=public";

mNodeChangedController.Close();
mSubjectsChangedController.Close();

{
std::lock_guard lock {mCertWritersLock};

for (auto& certWriter : mCertWriters) {
certWriter->Close();
}

mCertWriters.clear();
}
}

/***********************************************************************************************************************
Expand Down Expand Up @@ -152,7 +168,7 @@ grpc::Status PublicMessageHandler::GetNodeInfo([[maybe_unused]] grpc::ServerCont
}

grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext* context,
const iamproto::GetCertRequest* request, iamproto::GetCertResponse* response)
const iamproto::GetCertRequest* request, iamproto::CertInfo* response)
{
LOG_DBG() << "Process get cert request: type=" << request->type().c_str()
<< ", serial=" << request->serial().c_str();
Expand Down Expand Up @@ -185,6 +201,45 @@ grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext*
return grpc::Status::OK;
}

grpc::Status PublicMessageHandler::SubscribeCertChanged([[maybe_unused]] grpc::ServerContext* context,
const iamanager::v5::SubscribeCertChangedRequest* request, grpc::ServerWriter<iamanager::v5::CertInfo>* writer)
{
LOG_DBG() << "Process subscribe cert changed: type=" << request->type().c_str();

auto certWriter = std::make_shared<CertWriter>(request->type());

{
std::lock_guard lock {mCertWritersLock};

mCertWriters.push_back(certWriter);
}

auto err = GetProvisionManager()->SubscribeCertChanged(request->type().c_str(), *certWriter);
if (!err.IsNone()) {
LOG_ERR() << "Failed to subscribe cert changed: " << err;

return utils::ConvertAosErrorToGrpcStatus(err);
}

auto status = certWriter->HandleStream(context, writer);

err = GetProvisionManager()->UnsubscribeCertChanged(*certWriter);
if (!err.IsNone()) {
LOG_ERR() << "Failed to unsubscribe cert changed: " << err;

return utils::ConvertAosErrorToGrpcStatus(err);
}

{
std::lock_guard lock {mCertWritersLock};

auto iter = std::remove(mCertWriters.begin(), mCertWriters.end(), certWriter);
mCertWriters.erase(iter, mCertWriters.end());
}

return status;
}

/***********************************************************************************************************************
* IAMPublicIdentityService implementation
**********************************************************************************************************************/
Expand Down
Loading

0 comments on commit 88d9d89

Please sign in to comment.