From 4a69845e0b2ffa44a96b80a792db5a9b44e98742 Mon Sep 17 00:00:00 2001 From: Richard Peters Date: Fri, 20 Sep 2024 07:44:20 +0200 Subject: [PATCH] fix: services.echo_console: parsing tokens of nested messages and arrays, formatting of messages, and properly close upon losing connection (#712) * fix(services.echo_console): parsing tokens of nested messages and arrays, formatting of messages, and properly close upon losing connection * Add hexadecimal integers, fix nested vectors and messages * services/echo_console/Console: Remove duplication --- services/echo_console/Console.cpp | 189 +++++++++++++++--------------- services/echo_console/Console.hpp | 10 +- services/echo_console/Main.cpp | 31 ++--- 3 files changed, 115 insertions(+), 115 deletions(-) diff --git a/services/echo_console/Console.cpp b/services/echo_console/Console.cpp index b9f943d07..92fcba122 100644 --- a/services/echo_console/Console.cpp +++ b/services/echo_console/Console.cpp @@ -1,5 +1,5 @@ #include "services/echo_console/Console.hpp" -#include "infra/stream/ByteOutputStream.hpp" +#include "infra/stream/StdVectorOutputStream.hpp" #include "infra/stream/StringInputStream.hpp" #include "services/tracer/GlobalTracer.hpp" #include @@ -299,14 +299,34 @@ namespace application ++parseIndex; } - while (parseIndex != line.size() && std::isdigit(line[parseIndex])) - ++parseIndex; + int32_t value = 0; + if (parseIndex != line.size() && line.substr(parseIndex, 2) == "0x") + { + tokenStart += 2; + parseIndex += 2; + while (parseIndex != line.size() && std::isxdigit(line[parseIndex])) + ++parseIndex; - std::string integer = line.substr(tokenStart, parseIndex - tokenStart); + std::string integer = line.substr(tokenStart, parseIndex - tokenStart); - int32_t value = 0; - for (std::size_t index = sign ? 1 : 0; index < integer.size(); ++index) - value = value * 10 + integer[index] - '0'; + for (std::size_t index = sign ? 1 : 0; index < integer.size(); ++index) + { + if (std::isdigit(integer[index])) + value = value * 16 + integer[index] - '0'; + else + value = value * 16 + std::tolower(integer[index]) - 'a' + 10; + } + } + else + { + while (parseIndex != line.size() && std::isdigit(line[parseIndex])) + ++parseIndex; + + std::string integer = line.substr(tokenStart, parseIndex - tokenStart); + + for (std::size_t index = sign ? 1 : 0; index < integer.size(); ++index) + value = value * 10 + integer[index] - '0'; + } if (sign) value *= -1; @@ -330,17 +350,22 @@ namespace application return ConsoleToken::String(tokenStart, identifier); } - Console::Console(EchoRoot& root) + Console::Console(EchoRoot& root, bool stopOnNetworkClose) : root(root) - , eventDispatcherThread([this]() + , eventDispatcherThread([this, stopOnNetworkClose]() { - RunEventDispatcher(); + RunEventDispatcher(stopOnNetworkClose); }) {} void Console::Run() { - while (!quit) + { + std::unique_lock lock(mutex); + started = true; + } + + while (!quit && !stoppedEventDispatcher) { std::string line; std::getline(std::cin, line); @@ -364,7 +389,7 @@ namespace application condition.notify_all(); }); - while (!processDone) + while (!processDone && !stoppedEventDispatcher) condition.wait(lock); } } @@ -592,14 +617,21 @@ namespace application std::cout << "Received method call " << methodId << " for unknown service " << serviceId << std::endl; } - void Console::RunEventDispatcher() + void Console::RunEventDispatcher(bool stopOnNetworkClose) { try { - network.Run(); + network.ExecuteUntil([this, stopOnNetworkClose]() + { + std::unique_lock lock(mutex); + return !network.NetworkActivity() && stopOnNetworkClose && started; + }); } catch (Quit&) {} + + std::unique_lock lock(mutex); + stoppedEventDispatcher = true; } void Console::ListInterfaces() @@ -745,7 +777,7 @@ namespace application MethodInvocation methodInvocation(line); auto [service, method] = SearchMethod(methodInvocation); - infra::ByteOutputStream::WithStorage<4096> stream; + infra::StdVectorOutputStream::WithStorage stream; infra::ProtoFormatter formatter(stream); formatter.PutVarInt(service->serviceId); @@ -754,8 +786,7 @@ namespace application methodInvocation.EncodeParameters(method.parameter, line.size(), formatter); } - auto range = infra::ReinterpretCastMemoryRange(stream.Writer().Processed()); - GetObserver().Send(std::string(range.begin(), range.end())); + GetObserver().Send(infra::ByteRangeAsStdString(infra::MakeRange(stream.Storage()))); } catch (ConsoleExceptions::SyntaxError& error) { @@ -767,11 +798,11 @@ namespace application } catch (ConsoleExceptions::MissingParameter& error) { - services::GlobalTracer().Trace() << "Missing parameter at index " << error.index << " (contents after that position is " << line.substr(error.index) << ")\n"; + services::GlobalTracer().Trace() << "Missing parameter at index " << error.index << " of type " << error.missingType << " (contents after that position is " << line.substr(error.index) << ")\n"; } catch (ConsoleExceptions::IncorrectType& error) { - services::GlobalTracer().Trace() << "Incorrect type at index " << error.index << " (contents after that position is " << line.substr(error.index) << ")\n"; + services::GlobalTracer().Trace() << "Incorrect type at index " << error.index << " expected type " << error.correctType << " (contents after that position is " << line.substr(error.index) << ")\n"; } catch (ConsoleExceptions::MethodNotFound& error) { @@ -942,25 +973,22 @@ namespace application if (!currentToken.Is()) throw ConsoleExceptions::SyntaxError{ IndexOf(currentToken) }; - currentToken = tokenizer.Token(); return result; } - std::vector Console::MethodInvocation::ProcessArray() + Console::MessageTokens Console::MethodInvocation::ProcessArray() { - std::vector result; + Console::MessageTokens result; while (true) { - Console::MessageTokens message; while (!currentToken.Is() && !currentToken.Is() && !currentToken.Is()) { - message.tokens.push_back(CreateMessageTokenValue()); + result.tokens.push_back(CreateMessageTokenValue()); currentToken = tokenizer.Token(); } - result.push_back(message); if (!currentToken.Is()) break; @@ -969,7 +997,6 @@ namespace application if (!currentToken.Is()) throw ConsoleExceptions::SyntaxError{ IndexOf(currentToken) }; - currentToken = tokenizer.Token(); return result; } @@ -981,7 +1008,7 @@ namespace application for (auto field : message.fields) { if (tokens.empty()) - throw ConsoleExceptions::MissingParameter{ valueIndex }; + throw ConsoleExceptions::MissingParameter{ valueIndex, field->protoType }; EncodeField(*field, tokens.front().first, tokens.front().second, formatter); tokens.erase(tokens.begin()); } @@ -1002,7 +1029,7 @@ namespace application void VisitInt64(const EchoFieldInt64& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1010,7 +1037,7 @@ namespace application void VisitUint64(const EchoFieldUint64& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1018,7 +1045,7 @@ namespace application void VisitInt32(const EchoFieldInt32& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1026,7 +1053,7 @@ namespace application void VisitFixed32(const EchoFieldFixed32& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutFixed32Field(static_cast(value.Get()), field.number); } @@ -1034,7 +1061,7 @@ namespace application void VisitFixed64(const EchoFieldFixed64& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutFixed64Field(static_cast(value.Get()), field.number); } @@ -1042,7 +1069,7 @@ namespace application void VisitBool(const EchoFieldBool& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "bool" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1050,7 +1077,7 @@ namespace application void VisitString(const EchoFieldString& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "string" }; formatter.PutStringField(infra::BoundedConstString(value.Get().data(), value.Get().size()), field.number); } @@ -1058,7 +1085,7 @@ namespace application void VisitUnboundedString(const EchoFieldUnboundedString& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "string" }; formatter.PutStringField(infra::BoundedConstString(value.Get().data(), value.Get().size()), field.number); } @@ -1066,7 +1093,7 @@ namespace application void VisitEnum(const EchoFieldEnum& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1074,7 +1101,7 @@ namespace application void VisitSFixed32(const EchoFieldSFixed32& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutFixed32Field(static_cast(value.Get()), field.number); } @@ -1082,7 +1109,7 @@ namespace application void VisitSFixed64(const EchoFieldSFixed64& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutFixed64Field(static_cast(value.Get()), field.number); } @@ -1090,55 +1117,28 @@ namespace application void VisitMessage(const EchoFieldMessage& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, field.protoType }; - methodInvocation.EncodeMessage(*field.message, value.Get(), valueIndex, formatter); + infra::StdVectorOutputStream::WithStorage stream; + infra::ProtoFormatter messageFormatter(stream); + methodInvocation.EncodeMessage(*field.message, value.Get(), valueIndex, messageFormatter); + formatter.PutLengthDelimitedField(infra::MakeRange(stream.Storage()), field.number); } void VisitBytes(const EchoFieldBytes& field) override { - if (!value.Is>()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; - std::vector bytes; - for (auto& messageTokens : value.Get>()) - { - if (messageTokens.tokens.size() < 1) - throw ConsoleExceptions::MissingParameter{ valueIndex }; - if (messageTokens.tokens.size() > 1) - throw ConsoleExceptions::TooManyParameters{ messageTokens.tokens[1].second }; - if (!messageTokens.tokens.front().first.Is()) - throw ConsoleExceptions::IncorrectType{ messageTokens.tokens[0].second }; - - bytes.push_back(static_cast(messageTokens.tokens.front().first.Get())); - } - - formatter.PutBytesField(infra::MakeRange(bytes), field.number); + PutVector(field.number); } void VisitUnboundedBytes(const EchoFieldUnboundedBytes& field) override { - if (!value.Is>()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; - std::vector bytes; - for (auto& messageTokens : value.Get>()) - { - if (messageTokens.tokens.size() < 1) - throw ConsoleExceptions::MissingParameter{ valueIndex }; - if (messageTokens.tokens.size() > 1) - throw ConsoleExceptions::TooManyParameters{ messageTokens.tokens[1].second }; - if (!messageTokens.tokens.front().first.Is()) - throw ConsoleExceptions::IncorrectType{ messageTokens.tokens[0].second }; - - bytes.push_back(static_cast(messageTokens.tokens.front().first.Get())); - } - - formatter.PutBytesField(infra::MakeRange(bytes), field.number); + PutVector(field.number); } void VisitUint32(const EchoFieldUint32& field) override { if (!value.Is()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, "integer" }; formatter.PutVarIntField(value.Get(), field.number); } @@ -1154,39 +1154,40 @@ namespace application void VisitRepeated(const EchoFieldRepeated& field) override { - if (!value.Is>()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + PutRepeated(field.protoType, field.type); + } - for (auto& messageTokens : value.Get>()) + void VisitUnboundedRepeated(const EchoFieldUnboundedRepeated& field) override + { + PutRepeated(field.protoType, field.type); + } + + private: + void PutVector(int fieldNumber) + { + if (!value.Is()) + throw ConsoleExceptions::IncorrectType{ valueIndex, "vector of integers" }; + std::vector bytes; + for (auto& messageToken : value.Get().tokens) { - if (messageTokens.tokens.size() < 1) - throw ConsoleExceptions::MissingParameter{ valueIndex }; - if (messageTokens.tokens.size() > 1) - throw ConsoleExceptions::TooManyParameters{ messageTokens.tokens[1].second }; - if (!messageTokens.tokens.front().first.Is()) - throw ConsoleExceptions::IncorrectType{ messageTokens.tokens.front().second }; + if (!messageToken.first.Is()) + throw ConsoleExceptions::IncorrectType{ messageToken.second, "integer" }; - EncodeFieldVisitor visitor(messageTokens, valueIndex, formatter, methodInvocation); - field.type->Accept(visitor); + bytes.push_back(static_cast(messageToken.first.Get())); } + + formatter.PutBytesField(infra::MakeRange(bytes), fieldNumber); } - void VisitUnboundedRepeated(const EchoFieldUnboundedRepeated& field) override + void PutRepeated(const std::string& fieldProtoType, std::shared_ptr fieldType) { if (!value.Is>()) - throw ConsoleExceptions::IncorrectType{ valueIndex }; + throw ConsoleExceptions::IncorrectType{ valueIndex, fieldProtoType }; for (auto& messageTokens : value.Get>()) { - if (messageTokens.tokens.size() < 1) - throw ConsoleExceptions::MissingParameter{ valueIndex }; - if (messageTokens.tokens.size() > 1) - throw ConsoleExceptions::TooManyParameters{ messageTokens.tokens[1].second }; - if (!messageTokens.tokens.front().first.Is()) - throw ConsoleExceptions::IncorrectType{ messageTokens.tokens.front().second }; - EncodeFieldVisitor visitor(messageTokens, valueIndex, formatter, methodInvocation); - field.type->Accept(visitor); + fieldType->Accept(visitor); } } diff --git a/services/echo_console/Console.hpp b/services/echo_console/Console.hpp index c91e276b8..9a6425deb 100644 --- a/services/echo_console/Console.hpp +++ b/services/echo_console/Console.hpp @@ -186,7 +186,7 @@ namespace application : public infra::Subject { public: - explicit Console(EchoRoot& root); + explicit Console(EchoRoot& root, bool stopOnNetworkClose); void Run(); services::ConnectionFactory& ConnectionFactory(); @@ -219,7 +219,7 @@ namespace application void ProcessParameterTokens(); std::pair CreateMessageTokenValue(); MessageTokens::MessageTokenValue ProcessMessage(); - std::vector ProcessArray(); + Console::MessageTokens ProcessArray(); void EncodeMessage(const EchoMessage& message, const MessageTokens& messageTokens, std::size_t valueIndex, infra::ProtoFormatter& formatter); void EncodeField(const EchoField& field, const MessageTokens::MessageTokenValue& value, std::size_t valueIndex, infra::ProtoFormatter& formatter); @@ -235,7 +235,7 @@ namespace application void PrintField(infra::Variant& fieldData, const EchoField& field, infra::ProtoParser& parser); void MethodNotFound(const EchoService& service, uint32_t methodId) const; void ServiceNotFound(uint32_t serviceId, uint32_t methodId) const; - void RunEventDispatcher(); + void RunEventDispatcher(bool stopOnNetworkClose); void ListInterfaces(); void ListFields(const EchoMessage& message); void Process(const std::string& line) const; @@ -246,7 +246,9 @@ namespace application main_::NetworkAdapter network; hal::TimerServiceGeneric timerService{ infra::systemTimerServiceId }; std::thread eventDispatcherThread; + bool started = false; bool quit = false; + bool stoppedEventDispatcher = false; std::mutex mutex; std::condition_variable condition; bool processDone = false; @@ -273,11 +275,13 @@ namespace application struct MissingParameter { std::size_t index; + std::string missingType; }; struct IncorrectType { std::size_t index; + std::string correctType; }; } } diff --git a/services/echo_console/Main.cpp b/services/echo_console/Main.cpp index 883346f4e..1a61b8178 100644 --- a/services/echo_console/Main.cpp +++ b/services/echo_console/Main.cpp @@ -185,7 +185,8 @@ ConsoleClientTcp::ConsoleClientTcp(services::ConnectionFactoryWithNameResolver& ConsoleClientTcp::~ConsoleClientTcp() { - consoleClientConnection->services::ConnectionObserver::Subject().AbortAndDestroy(); + if (!!consoleClientConnection) + consoleClientConnection->services::ConnectionObserver::Subject().AbortAndDestroy(); } infra::BoundedConstString ConsoleClientTcp::Hostname() const @@ -317,7 +318,8 @@ int main(int argc, char* argv[], const char* env[]) std::cout << "Loaded " << path << std::endl; } - application::Console console(root); + bool serialConnectionRequested = get(target).substr(0, 3) == "COM" || get(target).substr(0, 4) == "/dev"; + application::Console console(root, !serialConnectionRequested); services::ConnectionFactoryWithNameResolverImpl::WithStorage<4> connectionFactory(console.ConnectionFactory(), console.NameResolver()); infra::Optional consoleClientTcp; infra::Optional consoleClientWebSocket; @@ -325,24 +327,17 @@ int main(int argc, char* argv[], const char* env[]) infra::Optional> bufferedUart; infra::Optional consoleClientUart; - auto construct = [&]() + if (serialConnectionRequested) { - if (get(target).substr(0, 3) == "COM" || get(target).substr(0, 4) == "/dev") - { - uart.Emplace(get(target)); - bufferedUart.Emplace(*uart); - consoleClientUart.Emplace(console, *bufferedUart); - } - else if (services::SchemeFromUrl(infra::BoundedConstString(get(target))) == "ws") - consoleClientWebSocket.Emplace(connectionFactory, console, get(target), randomDataGenerator, tracer); - else - consoleClientTcp.Emplace(connectionFactory, console, get(target), tracer); - }; + uart.Emplace(get(target)); + bufferedUart.Emplace(*uart); + consoleClientUart.Emplace(console, *bufferedUart); + } + else if (services::SchemeFromUrl(infra::BoundedConstString(get(target))) == "ws") + consoleClientWebSocket.Emplace(connectionFactory, console, get(target), randomDataGenerator, tracer); + else + consoleClientTcp.Emplace(connectionFactory, console, get(target), tracer); - infra::EventDispatcher::Instance().Schedule([&construct]() - { - construct(); - }); console.Run(); } catch (const args::Help&)