Skip to content

Commit

Permalink
[iamserver] Restart iamserver after certificate renew
Browse files Browse the repository at this point in the history
Signed-off-by: Mykola Kobets <[email protected]>
  • Loading branch information
mykola-kobets-epam committed Oct 4, 2024
1 parent e8b6e7f commit 9446a13
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 125 deletions.
97 changes: 72 additions & 25 deletions src/iamserver/iamserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
{
LOG_DBG() << "IAM Server init";

mConfig = config;
mConfig = config;
mCertLoader = &certLoader;
mCryptoProvider = &cryptoProvider;

aos::Error err;
aos::NodeInfo nodeInfo;
Expand All @@ -118,8 +120,6 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
}

try {
std::shared_ptr<grpc::ServerCredentials> publicOpt, protectedOpt;

if (!provisioningMode) {
aos::iam::certhandler::CertInfo certInfo;

Expand All @@ -128,16 +128,21 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand
return AOS_ERROR_WRAP(err);
}

publicOpt = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider);
protectedOpt = aos::common::utils::GetMTLSServerCredentials(
err = certHandler.SubscribeCertReceiver(aos::String(mConfig.mCertStorage.c_str()), *this);
if (!err.IsNone()) {
return AOS_ERROR_WRAP(err);
}

mPublicCred = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider);
mProtectedCred = aos::common::utils::GetMTLSServerCredentials(
certInfo, mConfig.mCACert.c_str(), certLoader, cryptoProvider);
} else {
publicOpt = grpc::InsecureServerCredentials();
protectedOpt = grpc::InsecureServerCredentials();
mPublicCred = grpc::InsecureServerCredentials();
mProtectedCred = grpc::InsecureServerCredentials();
}

CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), publicOpt);
CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), protectedOpt);
Start();

} catch (const aos::common::utils::AosException& e) {
return e.GetError();
} catch (const std::exception& e) {
Expand Down Expand Up @@ -206,22 +211,7 @@ void IAMServer::OnNodeRemoved(const aos::String& id)

IAMServer::~IAMServer()
{
LOG_DBG() << "IAM Server shutdown";

mNodeController.Close();

if (mPublicServer) {
mPublicServer->Shutdown();
mPublicServer->Wait();
}

if (mProtectedServer) {
mProtectedServer->Shutdown();
mProtectedServer->Wait();
}

mPublicMessageHandler.Close();
mProtectedMessageHandler.Close();
Shutdown();
}

/***********************************************************************************************************************
Expand All @@ -242,6 +232,63 @@ aos::Error IAMServer::SubjectsChanged(const aos::Array<aos::StaticString<aos::cS
return aos::ErrorEnum::eNone;
}

void IAMServer::OnCertChanged(const aos::iam::certhandler::CertInfo& info)
{
mPublicCred = aos::common::utils::GetTLSServerCredentials(info, *mCertLoader, *mCryptoProvider);
mProtectedCred
= aos::common::utils::GetMTLSServerCredentials(info, mConfig.mCACert.c_str(), *mCertLoader, *mCryptoProvider);

// postpone restart so it didn't block ApplyCert
mCertChangedResult = std::async(std::launch::async, [this]() {
sleep(1);
Shutdown();
Start();
});
}

void IAMServer::Start()
{
if (mIsStarted) {
return;
}

LOG_DBG() << "IAM Server start";

mPublicMessageHandler.Start();
mProtectedMessageHandler.Start();

CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), mPublicCred);
CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), mProtectedCred);

mIsStarted = true;
}

void IAMServer::Shutdown()
{
if (!mIsStarted) {
return;
}

LOG_DBG() << "IAM Server shutdown";

mNodeController.Close();

if (mPublicServer) {
mPublicServer->Shutdown();
mPublicServer->Wait();
}

if (mProtectedServer) {
mProtectedServer->Shutdown();
mProtectedServer->Wait();
}

mPublicMessageHandler.Close();
mProtectedMessageHandler.Close();

mIsStarted = false;
}

void IAMServer::CreatePublicServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials)
{
LOG_DBG() << "Process create public server: URL=" << addr.c_str();
Expand Down
27 changes: 21 additions & 6 deletions src/iamserver/iamserver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
*/
class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf,
public aos::iam::identhandler::SubjectsObserverItf,
public aos::iam::provisionmanager::ProvisionManagerCallbackItf {
public aos::iam::provisionmanager::ProvisionManagerCallbackItf,
private aos::iam::certhandler::CertReceiverItf {
public:
/**
* Constructor.
Expand Down Expand Up @@ -115,15 +116,29 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf,
// identhandler::SubjectsObserverItf interface
aos::Error SubjectsChanged(const aos::Array<aos::StaticString<aos::cSubjectIDLen>>& messages) override;

// certhandler::CertReceiverItf interface
void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override;

// lifecycle routines
void Start();
void Shutdown();

// creating routines
void CreatePublicServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials);
void CreateProtectedServer(const std::string& addr, const std::shared_ptr<grpc::ServerCredentials>& credentials);

Config mConfig;
NodeController mNodeController;
PublicMessageHandler mPublicMessageHandler;
ProtectedMessageHandler mProtectedMessageHandler;
std::unique_ptr<grpc::Server> mPublicServer, mProtectedServer;
Config mConfig;
aos::cryptoutils::CertLoader* mCertLoader;
aos::crypto::x509::ProviderItf* mCryptoProvider;

NodeController mNodeController;
PublicMessageHandler mPublicMessageHandler;
ProtectedMessageHandler mProtectedMessageHandler;
std::unique_ptr<grpc::Server> mPublicServer, mProtectedServer;
std::shared_ptr<grpc::ServerCredentials> mPublicCred, mProtectedCred;

bool mIsStarted = false;
std::future<void> mCertChangedResult;
};

#endif
2 changes: 1 addition & 1 deletion src/iamserver/protectedmessagehandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
*/
class ProtectedMessageHandler :
// public services
private PublicMessageHandler,
public PublicMessageHandler,
// protected services
private iamproto::IAMNodesService::Service,
private iamproto::IAMProvisioningService::Service,
Expand Down
40 changes: 39 additions & 1 deletion src/iamserver/publicmessagehandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,28 @@ aos::Error PublicMessageHandler::SubjectsChanged(const aos::Array<aos::StaticStr
return aos::ErrorEnum::eNone;
}

void PublicMessageHandler::Start()
{
mNodeChangedController.Start();
mSubjectsChangedController.Start();
}

void PublicMessageHandler::Close()
{
LOG_DBG() << "Close message handler: handler=public";

mNodeChangedController.Close();
mSubjectsChangedController.Close();

{
std::lock_guard lock {mCertWritersLock};

for (auto& certWriter : mCertWriters) {
certWriter->Close();
}

mCertWriters.clear();
}
}

/***********************************************************************************************************************
Expand Down Expand Up @@ -152,7 +168,7 @@ grpc::Status PublicMessageHandler::GetNodeInfo([[maybe_unused]] grpc::ServerCont
}

grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext* context,
const iamproto::GetCertRequest* request, iamproto::GetCertResponse* response)
const iamproto::GetCertRequest* request, iamproto::CertInfo* response)
{
LOG_DBG() << "Process get cert request: type=" << request->type().c_str()
<< ", serial=" << request->serial().c_str();
Expand Down Expand Up @@ -185,6 +201,28 @@ grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext*
return grpc::Status::OK;
}

grpc::Status PublicMessageHandler::SubscribeCertChanged([[maybe_unused]] grpc::ServerContext* context,
const iamanager::v5::SubscribeCertChangedRequest* request, grpc::ServerWriter<iamanager::v5::CertInfo>* writer)
{
LOG_DBG() << "Process subscribe cert changed: type=" << request->type().c_str();

auto certWriter = std::make_shared<CertWriter>(request->type());
{
std::lock_guard lock {mCertWritersLock};

mCertWriters.push_back(certWriter);
}

auto err = GetProvisionManager()->SubscribeCertReceiver(request->type().c_str(), *certWriter);
if (!err.IsNone()) {
LOG_ERR() << "Failed to get cert: " << err;

return utils::ConvertAosErrorToGrpcStatus(err);
}

return certWriter->HandleStream(context, writer);
}

/***********************************************************************************************************************
* IAMPublicIdentityService implementation
**********************************************************************************************************************/
Expand Down
99 changes: 16 additions & 83 deletions src/iamserver/publicmessagehandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <iamanager/version.grpc.pb.h>

#include "nodecontroller.hpp"
#include "streamwriter.hpp"

/**
* Public message handler. Responsible for handling public IAM services.
Expand Down Expand Up @@ -88,90 +89,16 @@ class PublicMessageHandler :
aos::Error SubjectsChanged(const aos::Array<aos::StaticString<aos::cSubjectIDLen>>& messages) override;

/**
* Closes public message handler.
* Start public message handler.
*/
void Close();
void Start();

protected:
/**
* Server writer controller handles server writer streams.
* Closes public message handler.
*/
template <typename T>
class ServerWriterController {
public:
/**
* Closes all streams.
*/
void Close()
{
mIsRunning = false;

{
std::unique_lock lock {mMutex};

mLastMessage.reset();
}

mCV.notify_all();
}

/**
* Writes notification message to all streams.
*
* @param message notification message.
*/
void WriteToStreams(const T& message)
{
{
std::unique_lock lock {mMutex};

++mNotificationID;

mLastMessage = message;
}

mCV.notify_all();
}

/**
* Handles stream. Blocks the caller until the stream is closed.
*
* @param context server context.
* @param writer server writer.
* @return grpc::Status.
*/
grpc::Status HandleStream(grpc::ServerContext* context, grpc::ServerWriter<T>* writer)
{
uint32_t lastNotificationID = 0;

while (mIsRunning && !context->IsCancelled()) {
std::shared_lock lock {mMutex};

if (mCV.wait_for(lock, cWaitTimeout, [this, lastNotificationID] {
return mNotificationID != lastNotificationID && mLastMessage.has_value();
})) {
// got notification, send it to the client
if (!writer->Write(*mLastMessage)) {
break;
}

lastNotificationID = mNotificationID;
}
}

return grpc::Status::OK;
}

private:
static constexpr auto cWaitTimeout = std::chrono::seconds(10);

std::atomic_bool mIsRunning = true;
std::condition_variable_any mCV;
std::shared_mutex mMutex;
std::atomic_uint32_t mNotificationID = 0;
std::optional<T> mLastMessage;
};
void Close();

protected:
aos::iam::identhandler::IdentHandlerItf* GetIdentHandler() { return mIdentHandler; }
aos::iam::permhandler::PermHandlerItf* GetPermHandler() { return mPermHandler; }
aos::iam::nodeinfoprovider::NodeInfoProviderItf* GetNodeInfoProvider() { return mNodeInfoProvider; }
Expand All @@ -189,8 +116,11 @@ class PublicMessageHandler :
// IAMPublicService interface
grpc::Status GetNodeInfo(
grpc::ServerContext* context, const google::protobuf::Empty* request, iamproto::NodeInfo* response) override;
grpc::Status GetCert(grpc::ServerContext* context, const iamproto::GetCertRequest* request,
iamproto::GetCertResponse* response) override;
grpc::Status GetCert(
grpc::ServerContext* context, const iamproto::GetCertRequest* request, iamproto::CertInfo* response) override;
grpc::Status SubscribeCertChanged(grpc::ServerContext* context,
const iamanager::v5::SubscribeCertChangedRequest* request,
grpc::ServerWriter<iamanager::v5::CertInfo>* writer) override;

// IAMPublicIdentityService interface
grpc::Status GetSystemInfo(
Expand Down Expand Up @@ -223,9 +153,12 @@ class PublicMessageHandler :
aos::iam::nodemanager::NodeManagerItf* mNodeManager = nullptr;
aos::iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr;
NodeController* mNodeController = nullptr;
ServerWriterController<iamproto::NodeInfo> mNodeChangedController;
ServerWriterController<iamproto::Subjects> mSubjectsChangedController;
StreamWriter<iamproto::NodeInfo> mNodeChangedController;
StreamWriter<iamproto::Subjects> mSubjectsChangedController;
aos::NodeInfo mNodeInfo;

std::vector<std::shared_ptr<CertWriter>> mCertWriters;
std::mutex mCertWritersLock;
};

#endif
Loading

0 comments on commit 9446a13

Please sign in to comment.