From a23cbdab63cd1aacc37d02d5bdd069142275101b Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Thu, 4 Jan 2024 12:47:43 +0100 Subject: [PATCH 1/6] Bump `fmt`. --- 3rdparty/fmt | 2 +- src/core/table.h | 7 +++++++ src/util/fmt.h | 8 ++++++++ src/util/helpers.h | 14 ++++++++++++++ src/util/result.h | 9 +++++++++ 5 files changed, 39 insertions(+), 1 deletion(-) diff --git a/3rdparty/fmt b/3rdparty/fmt index 215f21a0..e69e5f97 160000 --- a/3rdparty/fmt +++ b/3rdparty/fmt @@ -1 +1 @@ -Subproject commit 215f21a0382d325efa66df53fbfbfddb020a2234 +Subproject commit e69e5f977d458f2650bb346dadf2ad30c5320281 diff --git a/src/core/table.h b/src/core/table.h index 77dd3aae..9775fd6d 100644 --- a/src/core/table.h +++ b/src/core/table.h @@ -552,3 +552,10 @@ inline auto ValueVectorCompare = [](const std::vector& a, const std::vect extern std::pair stringToValue(const std::string& str, value::Type type); } // namespace zeek::agent + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const zeek::agent::value::Type& t, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", to_string(t)); + } +}; diff --git a/src/util/fmt.h b/src/util/fmt.h index fb6ef9fb..0811d727 100644 --- a/src/util/fmt.h +++ b/src/util/fmt.h @@ -8,6 +8,7 @@ #include #include +#include namespace zeek::agent { @@ -33,3 +34,10 @@ std::string to_string(const T& t) { inline std::string to_string(const std::string& s) { return s; } } // namespace zeek::agent + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const nlohmann::json& json, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", json.dump()); + } +}; diff --git a/src/util/helpers.h b/src/util/helpers.h index 66967638..d0863ff3 100644 --- a/src/util/helpers.h +++ b/src/util/helpers.h @@ -392,3 +392,17 @@ inline std::ostream& operator<<(std::ostream& out, const zeek::agent::Interval& } } // namespace std::chrono + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const zeek::agent::Time& t, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", zeek::agent::to_string(t)); + } +}; + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const zeek::agent::Interval& i, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", zeek::agent::to_string(i)); + } +}; diff --git a/src/util/result.h b/src/util/result.h index 5a5a85df..c7ead099 100644 --- a/src/util/result.h +++ b/src/util/result.h @@ -5,6 +5,8 @@ #pragma once +#include "util/fmt.h" + #include #include #include @@ -184,3 +186,10 @@ class Result { }; } // namespace zeek::agent + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const zeek::agent::result::Error& err, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", err.description()); + } +}; From f5c7b0f972c59dd0d3e8183ee70423081c70c2a6 Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Thu, 4 Jan 2024 12:51:04 +0100 Subject: [PATCH 2/6] Bump `spdlog`. --- 3rdparty/spdlog | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/spdlog b/3rdparty/spdlog index d357a9e0..7e635fca 160000 --- a/3rdparty/spdlog +++ b/3rdparty/spdlog @@ -1 +1 @@ -Subproject commit d357a9e0e39263fa4acf634e74428e900ef4e9ae +Subproject commit 7e635fca68d014934b4af8a1cf874f63989352b7 From b001495f41911afc4a7b11b294a8f864a90deec7 Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Thu, 23 Feb 2023 20:59:20 +0100 Subject: [PATCH 3/6] Fix return value of scheduler loop. --- src/core/scheduler.cc | 2 +- src/main.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/scheduler.cc b/src/core/scheduler.cc index d190cfc2..d776ac8c 100644 --- a/src/core/scheduler.cc +++ b/src/core/scheduler.cc @@ -136,7 +136,7 @@ bool Scheduler::Implementation::loop() { } advance(std::chrono::system_clock::now()); - return _terminating; + return ! _terminating; } Scheduler::Scheduler() { ZEEK_AGENT_DEBUG("scheduler", "creating instance"); } diff --git a/src/main.cc b/src/main.cc index 683c7a94..75eaba1b 100644 --- a/src/main.cc +++ b/src/main.cc @@ -142,7 +142,7 @@ int zeek::agent::main(const std::vector& argv) { ZEEK_AGENT_DEBUG("main", "looping until terminated"); - while ( ! scheduler.loop() ) { + while ( scheduler.loop() ) { db.poll(); if ( zeek ) From 5c83f5cabac3f9c066186d32372098ae5931ac1a Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Thu, 4 Jan 2024 12:08:56 +0100 Subject: [PATCH 4/6] Fix help message. --- src/core/configuration.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/configuration.cc b/src/core/configuration.cc index 7e8d9eb7..9e7970f0 100644 --- a/src/core/configuration.cc +++ b/src/core/configuration.cc @@ -33,9 +33,9 @@ #include #ifndef NDEBUG -#define LOG_LEVEL_HELP "info,warning,error,critical" -#else #define LOG_LEVEL_HELP "trace,debug,info,warning,error,critical" +#else +#define LOG_LEVEL_HELP "info,warning,error,critical" #endif using namespace zeek::agent; From 5be001beea32ab4afa6262dd6988a74b12d4c0c2 Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Thu, 4 Jan 2024 08:44:10 +0100 Subject: [PATCH 5/6] Reformat Zeek scripts with `zeek-script`. --- tests/zeek/error.zeek | 2 +- tests/zeek/if-missing-table.zeek | 6 +++--- tests/zeek/query.zeek | 4 ++-- tests/zeek/requires-table.zeek | 6 +++--- tests/zeek/scheduled.zeek | 4 ++-- tests/zeek/table/files.zeek | 2 +- tests/zeek/table/processes.zeek | 2 +- tests/zeek/table/sockets.zeek | 2 +- tests/zeek/table/ssh.zeek | 2 +- tests/zeek/table/system_logs.zeek | 4 ++-- tests/zeek/table/users.zeek | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/zeek/error.zeek b/tests/zeek/error.zeek index f6fa6169..330c4f3d 100644 --- a/tests/zeek/error.zeek +++ b/tests/zeek/error.zeek @@ -15,7 +15,7 @@ event got_result() event zeek_init() { - ZeekAgent::query([$sql_stmt="SELECT foo FROM bar", $event_=got_result]); + ZeekAgent::query([ $sql_stmt="SELECT foo FROM bar", $event_=got_result ]); } event ZeekAgentAPI::agent_error_v1(ctx: ZeekAgent::Context, msg: string) diff --git a/tests/zeek/if-missing-table.zeek b/tests/zeek/if-missing-table.zeek index 22424623..a18d15dc 100644 --- a/tests/zeek/if-missing-table.zeek +++ b/tests/zeek/if-missing-table.zeek @@ -27,10 +27,10 @@ event do_terminate() event zeek_init() { - ZeekAgent::query([$sql_stmt="SELECT agent_version FROM zeek_agent", + ZeekAgent::query([ $sql_stmt="SELECT agent_version FROM zeek_agent", $event_=got_result, $cookie="Hurz", $if_missing_tables=set( - "zeek_agent")]); - schedule 5 secs { do_terminate() }; + "zeek_agent") ]); + schedule 5secs { do_terminate() }; } event ZeekAgentAPI::agent_error_v1(ctx: ZeekAgent::Context, msg: string) diff --git a/tests/zeek/query.zeek b/tests/zeek/query.zeek index 97ecf073..9eb00705 100644 --- a/tests/zeek/query.zeek +++ b/tests/zeek/query.zeek @@ -22,6 +22,6 @@ event got_result(ctx: ZeekAgent::Context, data: Columns) event zeek_init() { - ZeekAgent::query([$sql_stmt="SELECT agent_version FROM zeek_agent", - $event_=got_result, $cookie="Hurz", $schedule_=20 secs]); + ZeekAgent::query([ $sql_stmt="SELECT agent_version FROM zeek_agent", + $event_=got_result, $cookie="Hurz", $schedule_=20secs ]); } diff --git a/tests/zeek/requires-table.zeek b/tests/zeek/requires-table.zeek index f0b6e5e6..b304a5aa 100644 --- a/tests/zeek/requires-table.zeek +++ b/tests/zeek/requires-table.zeek @@ -21,9 +21,9 @@ event got_result() event zeek_init() { - ZeekAgent::query([$sql_stmt="SELECT foo FROM bar", $event_=got_result, - $requires_tables=set("bar")]); - schedule 5 secs { do_terminate() }; + ZeekAgent::query([ $sql_stmt="SELECT foo FROM bar", $event_=got_result, + $requires_tables=set("bar") ]); + schedule 5secs { do_terminate() }; } event ZeekAgentAPI::agent_error_v1(ctx: ZeekAgent::Context, msg: string) diff --git a/tests/zeek/scheduled.zeek b/tests/zeek/scheduled.zeek index f76e2aff..8b57cccd 100644 --- a/tests/zeek/scheduled.zeek +++ b/tests/zeek/scheduled.zeek @@ -31,7 +31,7 @@ event got_result(ctx: ZeekAgent::Context, data: Columns) { ZeekAgent::cancel(query_id); print "terminating soon - there should not be another 'got result' after this"; - schedule 2 secs { do_terminate() }; + schedule 2secs { do_terminate() }; } } @@ -39,5 +39,5 @@ event zeek_init() { query_id = ZeekAgent::query([ $sql_stmt="SELECT id, agent_version FROM zeek_agent", - $event_=got_result, $cookie="Hurz", $schedule_=3 secs]); + $event_=got_result, $cookie="Hurz", $schedule_=3secs ]); } diff --git a/tests/zeek/table/files.zeek b/tests/zeek/table/files.zeek index f32e17de..c779d8b3 100644 --- a/tests/zeek/table/files.zeek +++ b/tests/zeek/table/files.zeek @@ -32,5 +32,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 2 secs { do_terminate() }; + schedule 2secs { do_terminate() }; } diff --git a/tests/zeek/table/processes.zeek b/tests/zeek/table/processes.zeek index 11f21417..d75866eb 100644 --- a/tests/zeek/table/processes.zeek +++ b/tests/zeek/table/processes.zeek @@ -32,5 +32,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 2 secs { do_terminate() }; + schedule 2secs { do_terminate() }; } diff --git a/tests/zeek/table/sockets.zeek b/tests/zeek/table/sockets.zeek index 13fb5790..048b4c15 100644 --- a/tests/zeek/table/sockets.zeek +++ b/tests/zeek/table/sockets.zeek @@ -32,5 +32,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 2 secs { do_terminate() }; + schedule 2secs { do_terminate() }; } diff --git a/tests/zeek/table/ssh.zeek b/tests/zeek/table/ssh.zeek index c98bb6af..dd0434e6 100644 --- a/tests/zeek/table/ssh.zeek +++ b/tests/zeek/table/ssh.zeek @@ -36,5 +36,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 5 secs { do_terminate() }; + schedule 5secs { do_terminate() }; } diff --git a/tests/zeek/table/system_logs.zeek b/tests/zeek/table/system_logs.zeek index d6c17cfd..b2282bf0 100644 --- a/tests/zeek/table/system_logs.zeek +++ b/tests/zeek/table/system_logs.zeek @@ -10,7 +10,7 @@ @load test-setup -redef ZeekAgent_SystemLogs::query_interval = 1 sec; +redef ZeekAgent_SystemLogs::query_interval = 1sec; # We only accept the 2nd write writer so that our output doesn't depend on # runtime duration (1st write is empty). @@ -31,5 +31,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 4 secs { do_terminate() }; + schedule 4secs { do_terminate() }; } diff --git a/tests/zeek/table/users.zeek b/tests/zeek/table/users.zeek index 77369000..db685bdd 100644 --- a/tests/zeek/table/users.zeek +++ b/tests/zeek/table/users.zeek @@ -32,5 +32,5 @@ event do_terminate() event ZeekAgentAPI::agent_hello_v1(ctx: ZeekAgent::Context, columns: ZeekAgentAPI::AgentHelloV1) { - schedule 2 secs { do_terminate() }; + schedule 2secs { do_terminate() }; } From f8a284ddbdeeaf054f392acd8ad3d86aa34b713a Mon Sep 17 00:00:00 2001 From: Robin Sommer Date: Tue, 21 Feb 2023 20:59:37 +0100 Subject: [PATCH 6/6] Support attaching interactive console to a running agent. By default, an agent now creates a UNIX socket for its console that an external client can connect to. To connect to an already running agent, execute `zeek-agent -r` on the same machine (note that client and server must be running as the same user, probably `root`). Client and server can specify a different socket path through `-s `. The environment variable `ZEEK_AGENT_SOCKET` can be set to specify the path as well. Remote consoles aren't support on Windows yet. Internally, we switch the classic local console over to the same IPC mechanism we now use for remote consoles, so that there's only one implementation. For Windows, where actual IPC isn't implemented yet, we provide a dummy implementation for this purpose that just forwards data inside the main (and only) process. Closes #8. --- .clang-tidy | 1 + src/core/configuration.cc | 43 +- src/core/configuration.h | 21 +- src/io/console.cc | 662 +++++++++++++++++++++--------- src/io/console.h | 48 ++- src/main.cc | 104 ++++- src/platform/darwin/os-log-sink.h | 2 +- src/platform/linux/platform.cc | 2 +- src/platform/windows/platform.cc | 5 +- src/util/CMakeLists.txt | 7 + src/util/ascii-table.cc | 2 +- src/util/filesystem.h | 17 + src/util/fmt.h | 22 + src/util/helpers.cc | 20 + src/util/helpers.h | 4 +- src/util/socket.cc | 90 ++++ src/util/socket.h | 189 +++++++++ src/util/socket.no-ipc.cc | 93 +++++ src/util/socket.posix.cc | 192 +++++++++ 19 files changed, 1286 insertions(+), 238 deletions(-) create mode 100644 src/util/socket.cc create mode 100644 src/util/socket.h create mode 100644 src/util/socket.no-ipc.cc create mode 100644 src/util/socket.posix.cc diff --git a/.clang-tidy b/.clang-tidy index 7f613775..0f8272e6 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -10,6 +10,7 @@ Checks: 'bugprone-*, readability-*, -bugprone-easily-swappable-parameters, + -bugprone-unchecked-optional-access, -cert-err58-cpp, -clang-analyzer-cplusplus.NewDeleteLeaks, -clang-diagnostic-c++2a-designator, diff --git a/src/core/configuration.cc b/src/core/configuration.cc index 9e7970f0..04a7ceb5 100644 --- a/src/core/configuration.cc +++ b/src/core/configuration.cc @@ -43,20 +43,23 @@ using namespace zeek::agent; options::LogLevel options::default_log_level = options::LogLevel::info; options::LogType options::default_log_type = options::LogType::Stdout; filesystem::path options::default_log_path = {}; +filesystem::path options::default_socket_file_name = "zeek-agent.$$.sock"; static struct option long_driver_options[] = { // clang-format off + {"autodoc", no_argument, nullptr, 'D'}, {"config", required_argument, nullptr, 'c'}, {"execute", required_argument, nullptr, 'e'}, {"help", no_argument, nullptr, 'h'}, {"interactive", no_argument, nullptr, 'i'}, {"log-level", required_argument, nullptr, 'L'}, - {"autodoc", no_argument, nullptr, 'D'}, + {"remote", no_argument, nullptr, 'r'}, + {"socket", no_argument, nullptr, 's'}, + {"terminate-on-disconnect", no_argument, nullptr, 'N'}, {"test", no_argument, nullptr, 'T'}, {"use-mock-data", no_argument, nullptr, 'M'}, - {"terminate-on-disconnect", no_argument, nullptr, 'N'}, - {"zeek", required_argument, nullptr, 'z'}, {"version", no_argument, nullptr, 'v'}, + {"zeek", required_argument, nullptr, 'z'}, {nullptr, 0, nullptr, 0} // clang-format on }; @@ -64,6 +67,12 @@ static struct option long_driver_options[] = { static void usage(const filesystem::path& name) { auto cfg = platform::configurationFile() ? platform::configurationFile()->string() : std::string("n/a"); + auto options = Options::default_(); + std::string socket = "n/a"; + + if ( options.socket ) + socket = options.socket->string(); + // clang-format off std::cerr << "\nUsage: " << name.filename().string() << frmt( " [options]\n" @@ -74,12 +83,14 @@ static void usage(const filesystem::path& name) { " -N | --terminate-on-disconnect Terminate when remote side disconnects (for testing)\n" " -T | --test Run unit tests and exit\n" " -c | --config Load configuration from file [default: {}]\n" - " -e | --execute SQL statement to execute immediately, then quit" + " -e | --execute SQL statement to execute immediately, then quit\n" " -h | --help Show usage information\n" " -i | --interactive Spawn interactive console\n" + " -r | --remote Connect interactive console to already running agent\n" + " -s | --socket Specify socket to use for console communication [default: {}]\n" " -v | --version Print version information\n" " -z | --zeek [:port] Connect to Zeek at given address\n" - "\n", cfg); + "\n", cfg, socket); // clang-format on } @@ -101,8 +112,8 @@ Result Options::parseArgv(const std::vector& argv) { #endif while ( true ) { - int c = - getopt_long(static_cast(argv_.size()), argv_.data(), "DL:MNTc:e:hivz:", long_driver_options, nullptr); + int c = getopt_long(static_cast(argv_.size()), argv_.data(), "DL:MNTc:e:hirs:vz:", long_driver_options, + nullptr); if ( c < 0 ) return Nothing(); @@ -129,6 +140,8 @@ Result Options::parseArgv(const std::vector& argv) { case 'c': config_file = optarg; break; case 'e': execute = optarg; break; case 'i': interactive = true; break; + case 'r': mode = options::Mode::RemoteConsole; break; + case 's': socket = filesystem::path(optarg); break; case 'z': zeek_destinations.emplace_back(optarg); break; case 'v': std::cerr << "Zeek Agent v" << VersionLong << std::endl; exit(0); @@ -152,6 +165,7 @@ void Options::debugDump() const { (log_level ? options::to_string(*log_level) : "")); ZEEK_AGENT_DEBUG("configuration", "[option] log.type: {}", (log_type ? to_string(*log_type) : "")); ZEEK_AGENT_DEBUG("configuration", "[option] log.path: {}", (log_path ? log_path->string() : "")); + ZEEK_AGENT_DEBUG("configuration", "[option] socket: {}", (socket ? socket->string() : "")); ZEEK_AGENT_DEBUG("configuration", "[option] use-mock-data: {}", use_mock_data); ZEEK_AGENT_DEBUG("configuration", "[option] terminate-on-disconnect: {}", terminate_on_disconnect); ZEEK_AGENT_DEBUG("configuration", "[option] zeek.groups: {}", join(zeek_groups, ", ")); @@ -240,6 +254,21 @@ Options Options::default_() { if ( path && filesystem::is_regular_file(*path) ) options.config_file = *path; +#ifndef HAVE_WINDOWS + if ( ! options.socket ) { + const char* env = getenv("ZEEK_AGENT_SOCKET"); + if ( env && *env ) + options.socket = env; + else { + filesystem::path socket_dir = "/tmp"; + if ( auto d = platform::dataDirectory() ) + socket_dir = *d; + + options.socket = socket_dir / replace(options::default_socket_file_name, "$$", frmt("{}", getuid())); + } + } +#endif + return options; } diff --git a/src/core/configuration.h b/src/core/configuration.h index 1ecfa6f1..f4a6d1b8 100644 --- a/src/core/configuration.h +++ b/src/core/configuration.h @@ -22,13 +22,15 @@ namespace options { * couple of special modes beyond normal operation. **/ enum class Mode { - Standard, /**< normal operation */ - Test, /**< run unit tests and exit */ - AutoDoc /**< print out JSON describing table schemas and exit */ + Standard, /**< normal operation */ + RemoteConsole, /**< connect to remote agent */ + Test, /**< run unit tests and exit */ + AutoDoc /**< print out JSON describing table schemas and exit */ }; inline std::string to_string(options::Mode mode) { switch ( mode ) { + case options::Mode::RemoteConsole: return "remote console"; case options::Mode::Standard: return "standard"; case options::Mode::Test: return "test"; case options::Mode::AutoDoc: return "autodoc"; @@ -87,6 +89,7 @@ inline Result from_str(const std::string_view& t) { extern LogLevel default_log_level; extern LogType default_log_type; extern filesystem::path default_log_path; +extern filesystem::path default_socket_file_name; } // namespace options @@ -123,9 +126,16 @@ struct Options { /** Console statement/command to execute at startup, and then terminate */ std::string execute; - /** True to spawn the interactive console */ + /** True to spawn the interactive console locally. */ bool interactive = false; + /** + * Set to a socket name to spawn the interactive console connecting to a + * remote agent. If set to an empty path, expect remote at default socket + * location. + */ + std::optional interactive_remote; + /** The agent's level of logging. Default is `warn` and worse. */ std::optional log_level; @@ -135,6 +145,9 @@ struct Options { /** File path associated with logger, if current type needs one. */ std::optional log_path = {}; + /** Default socket for remote console. */ + std::optional socket; + /** True to have any tables only report mock data for testing. */ bool use_mock_data = false; diff --git a/src/io/console.cc b/src/io/console.cc index cca59748..56f63918 100644 --- a/src/io/console.cc +++ b/src/io/console.cc @@ -9,13 +9,14 @@ #include "util/color.h" #include "util/fmt.h" #include "util/helpers.h" +#include "util/socket.h" +#include "util/testing.h" #include #include #include -#include +#include #include -#include #include #include #include @@ -24,151 +25,318 @@ using namespace zeek::agent; +static const auto ProtocolVersion = "1"; +static const auto MessageTerminator = "<<<>>>"; + +// Helpers sending output back to client. +static void sendEndOfMessage(socket::Remote& remote); +static void sendError(socket::Remote& remote, const std::string& msg); +static void sendHelp(socket::Remote& remote); +static void sendResult(socket::Remote& remote, const query::Result& result, bool include_type); +static void sendSchema(socket::Remote& remote, const Schema& schema); +static void sendTables(socket::Remote& remote, const std::map& tables); +static void sendWelcome(socket::Remote& remote); + +// State of worker thread serving an ongoing query. +struct PendingQuery { + socket::Remote remote; // remote console to send output to + std::unique_ptr thread; // worker thread serving query + std::optional query_id; // while a query is running, its ID + std::mutex done_mutex; // mutex to flag when a query has been fully processed + std::condition_variable done_cv; // condition variable to flag when a query has been fully processed + std::atomic needs_join = false; // flag to server thread to clean up thread +}; + template<> -struct Pimpl::Implementation { +struct Pimpl::Implementation { // One time initialization from main thread. void init(); - // Main interactive loop running inside thread. - void repl(); + // Clean up any state before destruction. + void done(); - // Executes a command or query. - void execute(const std::string& cmd, bool terminate = false); + // Main loop for the console server running inside its own thread; will + // block until terminated. + void poll(); - // Performance a query against the database. - void query(const std::string& stmt, std::optional subscription, bool terminate = false); + // Executes a command/query entered on the console. + void execute(socket::Remote remote, const std::string& cmd, bool terminate = false); - // Prints a message to the console. - void message(const std::string& msg); + // Performs a query against the database. This returns immediately after + // spawning a worker thread performing the query. + void query(const socket::Remote& remote, const std::string& stmt, + std::optional subscription, bool terminate = false); - // Prints an error to the console. - void error(const std::string& msg); + filesystem::path _socket_path; // as passed into constructor + Database* _db = nullptr; // as passed into constructor + Scheduler* _scheduler = nullptr; // as passed into constructor - // Prints a query result to the console. - void printResult(const query::Result& result, bool include_type); + Socket _socket; // IPC socket for communicating with console clients + std::map _tables; // copy of table schema for thread-safety + std::unique_ptr _thread; // console's thread + std::list> _pending_queries; // queries waiting for results +}; - // Prints a liste of all tables. - void printTables(); +void ConsoleServer::Implementation::init() { + for ( auto t : _db->tables() ) + // Create a copy of the table schema while we are in the main thread. + _tables.emplace(t->name(), t->schema()); - // Prints the schema for a table. - Result printSchema(const std::string& table); + if ( auto rc = _socket.bind(_socket_path) ) + ZEEK_AGENT_DEBUG("console-server", "opened socket {}", _socket_path); + else + logger()->warn("console server: {}", rc.error()); +} - // Prints a help test to the console. - void help(); +void ConsoleServer::Implementation::done() { + filesystem::remove(_socket_path); - // Prints an initial welcome message to the console. - void welcome(); + for ( auto& p : _pending_queries ) { + p->done_cv.notify_all(); + p->thread->join(); + } - Database* _db = nullptr; // as passed into constructor - Scheduler* _scheduler = nullptr; // as passed into constructor - SignalManager* _signal_mgr = nullptr; // as passed into constructor + _thread->join(); +} - std::string _scheduled_statement; // pre-scheduled statement +void ConsoleServer::Implementation::poll() { + if ( ! _socket ) + return; - std::map _tables; // copy of table schema for thread-safety + ZEEK_AGENT_DEBUG("console-server", "reading from socket {}", _socket_path); - std::unique_ptr _thread; // console's thread - std::optional _current_query; // while a query is running, its ID - std::mutex _query_done_mutex; // mutex to flag when a query has been fully processed - std::condition_variable _query_done_cv; // condition variable to flag when a query has been fully processed - replxx::Replxx _rx; // instance of the REPL -}; + while ( ! _scheduler->terminating() ) { + auto result = _socket.read(); + if ( ! result ) { + logger()->warn("console server: receive failed: {}", result.error()); + continue; + } -void Console::Implementation::init() { - for ( auto t : _db->tables() ) - // Create a copy of the table schema while we are in the main thread. - _tables.emplace(t->name(), t->schema()); -} + if ( ! result->has_value() ) + continue; + + auto msg = trim((*result)->first); + auto remote = (*result)->second; -void Console::Implementation::execute(const std::string& cmd, bool terminate) { - ZEEK_AGENT_DEBUG("console", "executing: {}", cmd); + if ( msg == MessageTerminator ) + break; - auto check_terminate = [&]() { - if ( terminate ) - _scheduler->terminate(); - }; + ZEEK_AGENT_DEBUG("console-server", "received command: {}", msg); + execute(remote, msg, false); - if ( cmd == ".tables" ) { - printTables(); - check_terminate(); + // Clean up worker threads. + auto i = _pending_queries.begin(); + while ( i != _pending_queries.end() ) { + auto cur = i++; + if ( (*cur)->needs_join ) { + ZEEK_AGENT_DEBUG("console-server", "joining query worker thread"); + (*cur)->thread->join(); + _pending_queries.erase(cur); + } + } } - else if ( cmd == ".quit" || cmd == ".exit" ) - _scheduler->terminate(); + ZEEK_AGENT_DEBUG("console-server", "done reading from socket {}", _socket_path); +} - else if ( cmd == ".help" ) { - help(); - check_terminate(); +void ConsoleServer::Implementation::execute(socket::Remote remote, const std::string& cmd, bool terminate) { + // We perform anything that's quick synchronously inside the main console + // server thread. For queries, we spawn a worker thread to serve them. + if ( cmd == ".version" ) { + remote << ProtocolVersion << std::endl; + sendEndOfMessage(remote); } + else if ( cmd == ".welcome" ) + sendWelcome(remote); + + else if ( cmd == ".tables" ) + sendTables(remote, _tables); + + else if ( cmd == ".help" ) + sendHelp(remote); + + else if ( cmd == ".terminate" ) + _scheduler->terminate(); + else if ( cmd.substr(0, 7) == ".diffs " ) - query(cmd.substr(7), query::SubscriptionType::Differences); + query(remote, cmd.substr(7), query::SubscriptionType::Differences); else if ( cmd.substr(0, 21) == ".snapshot-plus-diffs " ) - query(cmd.substr(21), query::SubscriptionType::SnapshotPlusDifferences); + query(remote, cmd.substr(21), query::SubscriptionType::SnapshotPlusDifferences); else if ( cmd.substr(0, 8) == ".events " ) - query(cmd.substr(8), query::SubscriptionType::Events); + query(remote, cmd.substr(8), query::SubscriptionType::Events); else if ( cmd.substr(0, 8) == ".schema " ) { if ( auto m = split(trim(cmd.substr(8))); m.size() == 1 && ! m[0].empty() ) { - if ( auto rc = printSchema(m[0]); ! rc ) - error(rc.error()); + auto t = _tables.find(m[0]); + if ( t != _tables.end() ) + sendSchema(remote, t->second); + else + sendError(remote, "no such table"); } else - error("cannot parse table name"); - - check_terminate(); + sendError(remote, "cannot parse table name"); } else if ( cmd.substr(0, 11) == ".snapshots " ) - query(cmd.substr(11), query::SubscriptionType::Snapshots); + query(remote, cmd.substr(11), query::SubscriptionType::Snapshots); - else if ( cmd[0] == '.' ) { - error("unknown command"); - check_terminate(); + else if ( cmd == ".ctrlc" ) { + for ( auto& p : _pending_queries ) { + if ( p->remote == remote ) + p->done_cv.notify_all(); + } } + else if ( cmd[0] == '.' ) + sendError(remote, frmt("unknown command: {}", split(cmd).front())); + else - query(cmd, {}, terminate); + query(remote, cmd, {}, terminate); + + if ( auto err = remote.error() ) + logger()->warn("console send failed: {}", *err); } -void Console::Implementation::repl() { - filesystem::path history_path; +void ConsoleServer::Implementation::query(const socket::Remote& remote, const std::string& stmt, + std::optional subscription, bool terminate) { + // We spawn a worker thread here to serve the query, and then return back to caller immediately. + _pending_queries.emplace_back(std::make_unique()); + auto pending = _pending_queries.back().get(); - // Runs in its own thread. - if ( auto dir = platform::dataDirectory() ) { - history_path = *dir / "history"; - _rx.history_load(history_path.string()); - } + Query query = {.sql_stmt = stmt, + .subscription = subscription, + .schedule = 2s, + .terminate = terminate, + .cookie = "", - welcome(); + .callback_result = + [pending, subscription](query::ID id, const query::Result& result) { + sendResult(pending->remote, result, + subscription && *subscription != query::SubscriptionType::Snapshots); - while ( ! _scheduler->terminating() ) { - auto raw_input = _rx.input(color::yellow("> ")); - if ( ! raw_input ) { - if ( errno == EAGAIN ) - continue; + if ( subscription && *subscription == query::SubscriptionType::Snapshots ) + pending->remote << std::endl; + + if ( ! pending->remote ) { + // Cancel query on error, but still try to get + // the error message across. + sendError(pending->remote, frmt("console send failed: {}", *pending->remote.error())); + pending->done_cv.notify_all(); + } + }, - // EOF -> exit - raw_input = ".quit"; - } + .callback_done = + [pending](query::ID id, bool regular_shutdown) { + std::unique_lock lock(pending->done_mutex); + pending->remote << std::endl; + pending->done_cv.notify_all(); + }}; - auto input = trim(raw_input); - if ( input.empty() ) - continue; + pending->remote = remote; + pending->thread = + std::make_unique([this, pending, query = std::move(query), scheduler = _scheduler]() { + ZEEK_AGENT_DEBUG("console-server", "starting query worker thread"); - _rx.history_add(input); + std::unique_lock lock(pending->done_mutex); - if ( ! history_path.empty() ) - _rx.history_sync(history_path.string()); + scheduler->schedule([this, pending, &query]() { + std::unique_lock lock(pending->done_mutex); + + if ( auto id = _db->query(query) ) + pending->query_id = *id; + + else { + sendError(pending->remote, id.error()); + pending->done_cv.notify_all(); + } + }); + + pending->done_cv.wait(lock); + sendEndOfMessage(pending->remote); + + if ( auto err = pending->remote.error() ) + logger()->warn("console send failed: {}", *err); + + if ( pending->query_id ) { + // Move canceling of query into main thread. + auto id = *pending->query_id; + scheduler->schedule([this, id]() { _db->cancel(id); }); + pending->query_id.reset(); + } + + pending->needs_join = true; + }); +} + +ConsoleServer::ConsoleServer(const filesystem::path& socket_path, Database* db, Scheduler* scheduler) { + ZEEK_AGENT_DEBUG("console-server", "creating instance"); + pimpl()->_socket_path = socket_path; + pimpl()->_db = db; + pimpl()->_scheduler = scheduler; +} + +ConsoleServer::~ConsoleServer() { + ZEEK_AGENT_DEBUG("console-server", "destroying instance"); + stop(); +} + +void ConsoleServer::start() { + ZEEK_AGENT_DEBUG("console-server", "starting"); - execute(input, false); +#ifdef HAVE_WINDOWS + static const HANDLE handle = GetStdHandle(STD_OUTPUT_HANDLE); + DWORD flags; + GetConsoleMode(handle, &flags); + flags |= ENABLE_PROCESSED_OUTPUT; + flags |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + SetConsoleMode(handle, flags); +#endif + + pimpl()->init(); + pimpl()->_thread = std::make_unique([this]() { pimpl()->poll(); }); +} + +void ConsoleServer::stop() { + if ( pimpl()->_thread ) { + ZEEK_AGENT_DEBUG("console-server", "stopping"); + pimpl()->done(); } } -void Console::Implementation::printResult(const query::Result& result, bool include_type) { +static void sendEndOfMessage(socket::Remote& remote) { remote << MessageTerminator << std::endl << std::flush; } + +static void sendError(socket::Remote& remote, const std::string& msg) { + remote << frmt("error: {}", msg) << std::endl; + sendEndOfMessage(remote); +} + +static void sendHelp(socket::Remote& remote) { + remote << R"( +Query + + Example: SELECT * FROM processes WHERE uid = 100 + +Commands + + .help display this help + .diffs continuously reschedule query, showing added or removed entries each time + .snapshot-plus-diffs show initial snapshot, then continuously reschedule query, showing added or removed entries each time + .events continuously reschedule query, showing new entries each time + .schema display the schema for the given table + .snapshots continuously reschedule query, showing all entries each time + .tables list available tables + .terminate terminate agent and quit console + .exit quit console, but leave agent running +)" << std::endl; + + sendEndOfMessage(remote); +} + +static void sendResult(socket::Remote& remote, const query::Result& result, bool include_type) { if ( result.columns.empty() ) return; @@ -201,161 +369,214 @@ void Console::Implementation::printResult(const query::Result& result, bool incl table.addRow(std::move(values)); } - table.print(std::cout, include_header); + table.print(remote, include_header); } -void Console::Implementation::printTables() { - AsciiTable out; - out.addHeader({"Name", "Description"}); - - for ( const auto& [name, schema] : _tables ) - out.addRow({name, schema.summary}); - - out.print(std::cout); -} - -Result Console::Implementation::printSchema(const std::string& table) { - auto t = _tables.find(table); - if ( t == _tables.end() ) - return result::Error("no such table"); - +static void sendSchema(socket::Remote& remote, const Schema& schema) { AsciiTable out; out.addHeader({"Column", "Type", "Description"}); bool has_parameters = false; - for ( const auto& c : t->second.columns ) { + for ( const auto& c : schema.columns ) { if ( ! c.is_parameter ) out.addRow({c.name, to_string(c.type), c.summary}); else has_parameters = true; } - std::cout << std::endl; - out.print(std::cout); - std::cout << std::endl; + remote << std::endl; + out.print(remote); + remote << std::endl; if ( has_parameters ) { AsciiTable out; out.addHeader({"Table Parameter", "Type", "Description"}); - for ( const auto& c : t->second.columns ) { + for ( const auto& c : schema.columns ) { if ( c.is_parameter ) out.addRow({ltrim(c.name, "_"), to_string(c.type), c.summary}); } - out.print(std::cout); - std::cout << std::endl; + out.print(remote); + remote << std::endl; } - return Nothing(); + sendEndOfMessage(remote); } -void Console::Implementation::query(const std::string& stmt, std::optional subscription, - bool terminate) { - Query query = {.sql_stmt = stmt, - .subscription = subscription, - .schedule = 2s, - .terminate = terminate, - .cookie = "", - .callback_result = - [&](query::ID id, const query::Result& result) { - printResult(result, subscription && *subscription != query::SubscriptionType::Snapshots); +static void sendTables(socket::Remote& remote, const std::map& tables) { + AsciiTable out; + out.addHeader({"Name", "Description"}); - if ( subscription && *subscription == query::SubscriptionType::Snapshots ) - std::cout << std::endl; - }, - .callback_done = - [&](query::ID id, bool regular_shutdown) { - std::unique_lock lock(_query_done_mutex); - _query_done_cv.notify_all(); - }}; + for ( const auto& [name, schema] : tables ) + out.addRow({name, schema.summary}); - std::unique_ptr sigint_handler; - if ( ! terminate ) - // Temporarily install our our SIGINT handler while the query is running. - sigint_handler = std::make_unique(_signal_mgr, SIGINT, [this]() { - std::unique_lock lock(_query_done_mutex); - _query_done_cv.notify_all(); - }); + out.print(remote); - { - std::unique_lock lock(_query_done_mutex); + sendEndOfMessage(remote); +} - _current_query.reset(); +static void sendWelcome(socket::Remote& remote) { + remote << R"( +Welcome to Zeek Agent v2. - _scheduler->schedule([this, query]() { - std::unique_lock lock(_query_done_mutex); +Enter query or command to execute. Type `.help` for help, and `.quit` for exit. +)" << std::endl; - if ( auto id = _db->query(query) ) - _current_query = *id; - else { - error(id.error()); - _query_done_cv.notify_all(); - } + sendEndOfMessage(remote); +} + + +///////////////////////////////// + +template<> +struct Pimpl::Implementation { + // One time initialization from main thread. + void init(); + + // Clean up any state before destruction. + void done(); + + // Main interactive loop running inside thread; won't return until termination + void repl(); + + // Executes a command or query, returns output + std::optional execute(const std::string& cmd, bool echo = true); + + filesystem::path _socket_path; // as passed into constructor + Scheduler* _scheduler = nullptr; // as passed into constructor + SignalManager* _signal_mgr = nullptr; // as passed into constructor + + std::pair + _scheduled_statement; // pre-scheduled statement; bool indicates if output is to be echoed + std::optional _scheduled_result; // output of pre-scheduled statement + + Socket _socket; // IPC socket for communicating with console server + socket::Remote _remote; // remote console server for sending commands to + std::unique_ptr _sigint; // custom CTRL-C handler while client is running + std::unique_ptr _thread; // console's thread + replxx::Replxx _rx; // instance of the REPL + std::atomic _ctrlc = 0; // number of times CTRL-C has been pressed during command execution +}; + +void ConsoleClient::Implementation::init() { + filesystem::path client_socket = frmt("{}.client", _socket_path); + if ( auto rc = _socket.bind(client_socket); ! rc ) { + logger()->error("console client: {}", rc.error()); + return; + } + + ZEEK_AGENT_DEBUG("console-client", "opened socket {}", client_socket); + + _remote = socket::Remote(&_socket, _socket_path); + ZEEK_AGENT_DEBUG("console-client", "connected to remote socket {}", _socket_path); + + if ( _signal_mgr ) + _sigint = std::make_unique(_signal_mgr, SIGINT, [this]() { + ++_ctrlc; + _socket.write(".ctrlc\n", _remote); }); +} + +void ConsoleClient::Implementation::done() { _sigint.reset(); } - _query_done_cv.wait(lock); +void ConsoleClient::Implementation::repl() { + // Runs in its own thread. - if ( _current_query ) { - // Move cancelling of query into main thread. - auto id = *_current_query; - _scheduler->schedule([this, id]() { _db->cancel(id); }); - _current_query.reset(); + filesystem::path history_path; + if ( auto dir = platform::dataDirectory() ) { + history_path = *dir / "history"; + _rx.history_load(history_path.string()); + } + + ZEEK_AGENT_DEBUG("console-client", "reading from socket {}", _socket_path); + + execute(".version", false); // No version check for now, we only have one version. + execute(".welcome"); + + while ( ! _scheduler->terminating() ) { + auto raw_input = _rx.input(color::yellow("> ")); + if ( ! raw_input ) { + if ( errno == EAGAIN ) + continue; + + // EOF -> exit + raw_input = ".quit"; } + + auto input = trim(raw_input); + + if ( input.empty() ) + continue; + + _rx.history_add(input); + + if ( ! history_path.empty() ) + _rx.history_sync(history_path.string()); + + if ( trim(input) == ".quit" || trim(input) == ".exit" ) + return; + + execute(input); } - std::cout << std::endl; + ZEEK_AGENT_DEBUG("console-client", "done reading from socket {}", _socket_path); } -void Console::Implementation::message(const std::string& msg) { _rx.print("%s\n", msg.c_str()); } +std::optional ConsoleClient::Implementation::execute(const std::string& cmd, bool echo) { + ZEEK_AGENT_DEBUG("console-client", "sending command: {}", cmd); -void Console::Implementation::error(const std::string& msg) { _rx.print("error: %s\n", msg.c_str()); } + std::optional output; -void Console::Implementation::welcome() { - _rx.print(R"( -Welcome to Zeek Agent v2. + _remote << cmd << std::endl; + if ( auto err = _remote.error() ) + throw FatalError(frmt("failed to send command to server: {}", *err)); -Enter query or command to execute. Type `.help` for help, and `.quit` for exit. + _ctrlc = 0; + while ( _ctrlc < 2 && ! _scheduler->terminating() ) { + auto result = _socket.read(); + if ( ! result ) + throw FatalError(frmt("receive failed: {}", result.error())); -)"); -} + if ( ! result->has_value() ) + continue; -void Console::Implementation::help() { - _rx.print(R"( -Query + auto msg = trim((*result)->first); + ZEEK_AGENT_DEBUG("console-client", "received output: {}", msg); - Example: SELECT * FROM processes WHERE uid = 100 + if ( msg == MessageTerminator ) + break; -Commands + if ( echo ) + _rx.print("%s", (*result)->first.c_str()); - .help display this help - .quit terminate agent - .diffs continously reschedule query, showing added or removed entries each time - .snapshot-plus-diffs show initial snapshot, then continously reschedule query, showing added or removed entries each time - .events continously reschedule query, showing new entries each time - .schema
display the schema for the given table - .snapshots continously reschedule query, showing all entries each time - .tables list available tables + if ( ! output ) + output = msg; + else + output->append(msg); + } -)"); + return output; } -Console::Console(Database* db, Scheduler* scheduler, SignalManager* signal_mgr) { - ZEEK_AGENT_DEBUG("console", "creating instance"); - pimpl()->_db = db; +ConsoleClient::ConsoleClient(const filesystem::path& socket, Scheduler* scheduler, SignalManager* signal_mgr) { + ZEEK_AGENT_DEBUG("console-client", "creating instance"); + pimpl()->_socket_path = socket; pimpl()->_scheduler = scheduler; pimpl()->_signal_mgr = signal_mgr; } -Console::~Console() { - ZEEK_AGENT_DEBUG("console", "destroying instance"); +ConsoleClient::~ConsoleClient() { + ZEEK_AGENT_DEBUG("console-client", "destroying instance"); stop(); } -void Console::scheduleStatementWithTermination(std::string stmt) { pimpl()->_scheduled_statement = std::move(stmt); } +void ConsoleClient::scheduleStatementWithTermination(std::string stmt, bool echo) { + pimpl()->_scheduled_statement = {std::move(stmt), echo}; +} -void Console::start() { - ZEEK_AGENT_DEBUG("console", "starting"); +void ConsoleClient::start(bool run_repl) { + ZEEK_AGENT_DEBUG("console-client", "starting"); #ifdef HAVE_WINDOWS static const HANDLE handle = GetStdHandle(STD_OUTPUT_HANDLE); @@ -367,19 +588,70 @@ void Console::start() { #endif pimpl()->init(); + pimpl()->_thread = std::make_unique([this, run_repl]() { + try { + if ( pimpl()->_scheduled_statement.first.size() ) + pimpl()->_scheduled_result = + pimpl()->execute(pimpl()->_scheduled_statement.first, pimpl()->_scheduled_statement.second); + else if ( run_repl ) + pimpl()->repl(); + else { + ZEEK_AGENT_DEBUG("console-client", "not starting REPL - waiting until terminated"); + while ( ! pimpl()->_scheduler->terminating() ) + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + } catch ( const FatalError& e ) { + logger()->error("{}", e.what()); + } catch ( const InternalError& e ) { + logger()->error("internal error: {}", e.what()); + } - pimpl()->_thread = std::make_unique([this]() { - if ( pimpl()->_scheduled_statement.size() ) - pimpl()->execute(pimpl()->_scheduled_statement, true); - else - pimpl()->repl(); + pimpl()->_scheduler->terminate(); }); } -void Console::stop() { +void ConsoleClient::stop() { if ( pimpl()->_thread ) { - ZEEK_AGENT_DEBUG("console", "stopping"); - pimpl()->_query_done_cv.notify_all(); + ZEEK_AGENT_DEBUG("console-client", "stopping"); pimpl()->_thread->join(); } + + pimpl()->done(); +} + +TEST_SUITE("console") { + Configuration cfg; + auto socket = filesystem::path(frmt("/tmp/zeek-agent-test-socket.{}", getpid())); + + TEST_CASE("client/server") { + Scheduler scheduler; + Database db(&cfg, &scheduler); + + ConsoleServer server(socket, &db, &scheduler); + server.start(); + + ConsoleClient client(socket, &scheduler, nullptr); + client.start(false); + + CHECK_EQ(client.pimpl()->execute(".version", false), ProtocolVersion); + scheduler.terminate(); + } + + TEST_CASE("pre-scheduled statement") { + Scheduler scheduler; + Database db(&cfg, &scheduler); + + ConsoleServer server(socket, &db, &scheduler); + server.start(); + + ConsoleClient client(socket, &scheduler, nullptr); + client.scheduleStatementWithTermination(".version", false); + client.start(); + + while ( scheduler.loop() ) + std::this_thread::sleep_for(std::chrono::microseconds(100)); + + REQUIRE(client.pimpl()->_scheduled_result); + CHECK_EQ(*client.pimpl()->_scheduled_result, ProtocolVersion); + } } diff --git a/src/io/console.h b/src/io/console.h index e0b1523d..002c3c5d 100644 --- a/src/io/console.h +++ b/src/io/console.h @@ -2,6 +2,7 @@ #pragma once +#include "util/filesystem.h" #include "util/pimpl.h" #include @@ -23,28 +24,55 @@ class SignalManager; * * All public methods are thread-safe. */ -class Console : public Pimpl { +class ConsoleServer : public Pimpl { public: /** - * Constructor. + * Constructor. TODO: Update * * @param database database to use for queries; observer only, doesn't take ownership - * @param scheduler scheduler to use for any timeers; observer only, doesn't take ownership - * @param signal_mgr signal manager to install handlers with; observer only, doesn't take ownership + * @param scheduler scheduler to use for any timers; observer only, doesn't take ownership */ - Console(Database* db, Scheduler* scheduler, SignalManager* signal_mgr); - ~Console(); + ConsoleServer(const filesystem::path& socket, Database* db, Scheduler* scheduler); + ~ConsoleServer(); + + /** Starts a console server thread. */ + void start(); + + /** Stops the console server thread. */ + void stop(); +}; + +class ConsoleClient : public Pimpl { +public: + /** + * Constructor. TODO: Update + * + * @param database database to use for queries; observer only, doesn't take ownership + * @param scheduler scheduler to use for any timers; observer only, doesn't take ownership + * @param signal_mgr signal manager to install handlers with; observer only, doesn't take ownership; can be left + * null for testing purposes (will prevent aborting with SIGINT) + */ + ConsoleClient(const filesystem::path& socket, Scheduler* scheduler, SignalManager* signal_mgr); + ~ConsoleClient(); /** * Schedule a single statement for execution, to then exit once it has * finished. Must be called before `start(). + * + * @param stmt statement to execute + * @param echo whether to echo the result to tty */ - void scheduleStatementWithTermination(std::string stmt); + void scheduleStatementWithTermination(std::string stmt, bool echo = true); - /** Starts a console thread. */ - void start(); + /** + * Starts a console server thread. + * + * @param run_repl whether to run the REPL loop in the client; only turn + * off for testing purposes + */ + void start(bool run_repl = true); - /** Stops the console thread. */ + /** Stops the console server thread. */ void stop(); }; diff --git a/src/main.cc b/src/main.cc index 75eaba1b..8a295dd4 100644 --- a/src/main.cc +++ b/src/main.cc @@ -10,11 +10,14 @@ #include "io/console.h" #include "io/zeek.h" #include "platform/platform.h" +#include "spdlog/common.h" #include "util/fmt.h" #include "util/helpers.h" +#include "util/socket.h" #include #include +#include #include #ifdef HAVE_DARWIN @@ -36,6 +39,44 @@ static int main(const std::vector& argv); using namespace zeek::agent; +static int remoteConsole(const std::vector& argv) { + if ( ! Socket::supportsIPC() ) { + logger()->error("fatal error: remote console not supported on this platform"); + return 1; + } + + options::default_log_level = spdlog::level::err; + + Configuration cfg; + auto rc = cfg.initFromArgv(argv); + if ( ! rc ) { + std::cerr << rc.error() << std::endl; + return 1; + } + + auto socket = cfg.options().socket; + if ( ! socket ) { + logger()->error("no socket specified"); + return 1; + } + + SignalManager signal_mgr({SIGINT}); + Scheduler scheduler; + + ConsoleClient client(*socket, &scheduler, &signal_mgr); + + if ( ! cfg.options().execute.empty() ) + client.scheduleStatementWithTermination(cfg.options().execute); + + client.start(); + + while ( scheduler.loop() ) { + // nothing to do + } + + return 0; +} + int main(int argc, char** argv) { // Start with a stateless pass over our command line options to get our // mode of operation. This isn't using any OS-specific functionality yet, @@ -56,23 +97,42 @@ int main(int argc, char** argv) { switch ( options.mode ) { case options::Mode::Standard: { + // Need to create this in main thread, which means before, on macOS, we + // branch over into the NetworkExtension. + // + // TODO: Don't remember why this can't be a unique_ptr. + signal_mgr = new SignalManager({SIGINT}); + #ifdef HAVE_DARWIN // Our network extension needs to take over the primary thread, so we move // our main logic into a new thread. Also note that the network extension // needs to start up as early as possible, in particular (it appears) // before we start using the configuration system. - auto _ = std::make_unique([argv_]() { zeek::agent::main(argv_); }); + auto _ = std::make_unique([argv_]() { + int rc = zeek::agent::main(argv_); + delete signal_mgr; + exit(rc); + }); + platform::darwin::enterNetworkExtensionMode(); // won't return cannot_be_reached(); #else // Can run inside main thread. - return zeek::agent::main(argv_); + auto rc = zeek::agent::main(argv_); + delete signal_mgr; + return rc; #endif } + case options::Mode::RemoteConsole: return remoteConsole(argv_); + case options::Mode::Test: { #ifndef DOCTEST_CONFIG_DISABLE - if ( ! options.log_level ) { + if ( auto level = options.log_level ) { + options::default_log_level = *level; + logger()->set_level(*level); + } + else { options::default_log_level = options::LogLevel::off; logger()->set_level(options::LogLevel::off); } @@ -97,7 +157,7 @@ int main(int argc, char** argv) { // needs to be the network extension. int zeek::agent::main(const std::vector& argv) { logger()->info("Zeek Agent {} starting up", VersionLong); - atexit(log_termination); + (void)atexit(log_termination); auto _ = ScopeGuard([]() { platform::done(); }); @@ -117,21 +177,37 @@ int zeek::agent::main(const std::vector& argv) { platform::init(&cfg); Scheduler scheduler; - signal_mgr = new SignalManager({SIGINT}); sigint = new signal::Handler(signal_mgr, SIGINT, [&]() { scheduler.terminate(); }); Database db(&cfg, &scheduler); for ( const auto& t : Database::registeredTables() ) db.addTable(t.second.get()); - std::unique_ptr console; - if ( cfg.options().interactive || ! cfg.options().execute.empty() ) { - console = std::make_unique(&db, &scheduler, signal_mgr); + std::unique_ptr server; + std::unique_ptr client; + +#ifdef HAVE_WINDOWS + filesystem::path socket = "/zeek-agent"; // dummy name used just internally +#else + filesystem::path socket; + if ( auto s = cfg.options().socket ) + socket = *s; + +#endif + if ( ! socket.empty() ) { + server = std::make_unique(socket, &db, &scheduler); + + if ( cfg.options().interactive || ! cfg.options().execute.empty() ) + client = std::make_unique(socket, &scheduler, signal_mgr); + + server->start(); - if ( ! cfg.options().execute.empty() ) - console->scheduleStatementWithTermination(cfg.options().execute); + if ( client ) { + client->start(); - console->start(); + if ( ! cfg.options().execute.empty() ) + client->scheduleStatementWithTermination(cfg.options().execute); + } } std::unique_ptr zeek; @@ -153,20 +229,14 @@ int zeek::agent::main(const std::vector& argv) { platform::done(); delete sigint; - delete signal_mgr; - return 0; } catch ( const FatalError& e ) { logger()->error("fatal error: {}", e.what()); delete sigint; - delete signal_mgr; - return 1; } catch ( const InternalError& e ) { logger()->error("internal error: {}", e.what()); delete sigint; - delete signal_mgr; - return 1; } } diff --git a/src/platform/darwin/os-log-sink.h b/src/platform/darwin/os-log-sink.h index 12ea7242..5521bbd4 100644 --- a/src/platform/darwin/os-log-sink.h +++ b/src/platform/darwin/os-log-sink.h @@ -10,7 +10,7 @@ namespace zeek::agent::platform::darwin { /** Custom spdlog sink writing to OSLog. */ -class OSLogSink : public spdlog::sinks::base_sink { +class OSLogSink final : public spdlog::sinks::base_sink { public: OSLogSink(); ~OSLogSink() override; diff --git a/src/platform/linux/platform.cc b/src/platform/linux/platform.cc index 03c5ebac..698ee456 100644 --- a/src/platform/linux/platform.cc +++ b/src/platform/linux/platform.cc @@ -39,7 +39,7 @@ Result platform::setenv(const char* name, const char* value, int overwr std::optional platform::configurationFile() { // TODO: These paths aren't necessarily right yet. filesystem::path exec = PathFind::FindExecutable(); - return exec / "../etc" / "zeek-agent.conf"; + return filesystem::weakly_canonical(exec.parent_path() / "../etc" / "zeek-agent.conf"); } std::optional platform::dataDirectory() { diff --git a/src/platform/windows/platform.cc b/src/platform/windows/platform.cc index 7f559c53..11b87b5e 100644 --- a/src/platform/windows/platform.cc +++ b/src/platform/windows/platform.cc @@ -272,7 +272,10 @@ void WMIManager::GetUserData(const std::wstring& key, bool system_accounts, std: info.sid = narrowWstring(var.bstrVal); VariantClear(&var); - std::wstring path_query = frmt(L"SELECT LocalPath from Win32_UserProfile WHERE SID = \"{}\"", var.bstrVal); + // TODO: Using frmt() here leads a compiler error "caused by a read of + // a variable outside its lifetime". Not sure why. + std::wstring path_query = + fmt::format(L"SELECT LocalPath from Win32_UserProfile WHERE SID = \"{}\"", var.bstrVal); if ( auto user_enum = GetQueryEnumerator(path_query) ) { IWbemClassObjectPtr user_obj = nullptr; diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index f11f9d7b..58b31094 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -5,4 +5,11 @@ target_sources(zeek-agent ascii-table.cc helpers.cc result.cc + socket.cc ) + +if ( HAVE_POSIX ) + target_sources(zeek-agent PRIVATE socket.posix.cc) +else () + target_sources(zeek-agent PRIVATE socket.no-ipc.cc) +endif () diff --git a/src/util/ascii-table.cc b/src/util/ascii-table.cc index b6a6ed68..896c6f7a 100644 --- a/src/util/ascii-table.cc +++ b/src/util/ascii-table.cc @@ -52,7 +52,7 @@ void AsciiTable::printRow(std::ostream& out, const std::vector& row out << fill_left << value << fill_right; } - out << sep << '\n'; + out << '\n' << std::flush; } void AsciiTable::print(std::ostream& out, bool include_header) { diff --git a/src/util/filesystem.h b/src/util/filesystem.h index 203981c7..711fe765 100644 --- a/src/util/filesystem.h +++ b/src/util/filesystem.h @@ -2,7 +2,24 @@ #pragma once +#include "autogen/config.h" + +#include + #include /** Type alias. */ namespace filesystem = ghc::filesystem; + +namespace zeek::agent { + +#ifdef HAVE_WINDOWS +namespace platform::windows { +std::string narrowWstring(const std::wstring& wstr); // provided by platform.cc +} // namespace platform::windows +inline std::string path_to_string(const filesystem::path& p) { return platform::windows::narrowWstring(p.native()); } +#else +inline std::string path_to_string(const filesystem::path& p) { return p.native(); } +#endif + +} // namespace zeek::agent diff --git a/src/util/fmt.h b/src/util/fmt.h index 0811d727..a16c5dd5 100644 --- a/src/util/fmt.h +++ b/src/util/fmt.h @@ -2,6 +2,8 @@ #pragma once +#include "util/filesystem.h" + #include #include #include @@ -30,6 +32,12 @@ std::string to_string(const T& t) { return t.str(); } +/** Renders class instances through their `str()` method. */ +template<> +inline std::string to_string(const filesystem::path& t) { + return zeek::agent::path_to_string(t); +} + /** Fallback for strings. */ inline std::string to_string(const std::string& s) { return s; } @@ -41,3 +49,17 @@ struct fmt::formatter : fmt::formatter { return fmt::format_to(ctx.out(), "{}", json.dump()); } }; + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const filesystem::path& p, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), "{}", zeek::agent::path_to_string(p)); + } +}; + +template<> +struct fmt::formatter : fmt::formatter { + auto format(const wchar_t& c, format_context& ctx) const -> decltype(ctx.out()) { + return fmt::format_to(ctx.out(), L"{}", c); + } +}; diff --git a/src/util/helpers.cc b/src/util/helpers.cc index 0b4181cb..2fdedc4a 100644 --- a/src/util/helpers.cc +++ b/src/util/helpers.cc @@ -79,6 +79,21 @@ std::pair zeek::agent::rsplit1(std::string s, const st return std::make_pair("", std::move(s)); } +std::string zeek::agent::replace(const std::string& s, const std::string& o, const std::string& n) { + if ( o.empty() ) + return s; + + auto x = s; + + size_t i = 0; + while ( (i = x.find(o, i)) != std::string::npos ) { + x.replace(i, o.length(), n); + i += n.length(); + } + + return x; +} + static std::string base62_encode(uint64_t i) { static const char* alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; @@ -342,4 +357,9 @@ TEST_SUITE("Helpers") { CHECK(endsWith("abcd", "cd")); CHECK(! endsWith("abcd", "ab")); } + + TEST_CASE("replace") { + CHECK_EQ(replace("abcd", "ab", "xy"), "xycd"); + CHECK_EQ(replace("abcd", "QWQW", "xy"), "abcd"); + } } diff --git a/src/util/helpers.h b/src/util/helpers.h index d0863ff3..5db3d657 100644 --- a/src/util/helpers.h +++ b/src/util/helpers.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -114,6 +115,7 @@ inline std::string to_string(Interval t) { return std::string(b.str()); } + /** Aborts with an internal error saying we should not be where we are. */ #ifdef HAVE_MSVC __declspec(noreturn) extern void cannot_be_reached(); @@ -325,7 +327,7 @@ extern std::pair rsplit1(std::string s, const std::str * * \note This function is not UTF8-aware. */ -std::string replace(std::string s, std::string o, std::string n); +extern std::string replace(const std::string& s, const std::string& o, const std::string& n); /** * Returns true if a string begins with another. diff --git a/src/util/socket.cc b/src/util/socket.cc new file mode 100644 index 00000000..2b51f25e --- /dev/null +++ b/src/util/socket.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2021 by the Zeek Project. See LICENSE for details. +// +// Socket implementation for POSIX systems. + +#include "socket.h" + +#include "util/fmt.h" +#include "util/testing.h" + +using namespace zeek::agent; + +socket::SocketBuffer& socket::SocketBuffer::operator=(const SocketBuffer& other) { + if ( this == &other ) + return *this; + + _socket = other._socket; + _remote = other._remote; + return *this; +} + +int socket::SocketBuffer::sync() { + if ( auto rc = _socket->write(str(), *_remote); ! rc ) { + logger()->debug("failed to send message to socket: {}", rc.error()); + _remote->setError(rc.error()); + } + + str(""); + return 0; +} + +socket::Remote& socket::Remote::operator=(const Remote& other) noexcept { + if ( this == &other ) + return *this; + + _dst = other._dst; + _sbuf = SocketBuffer(other._sbuf._socket, this); + _sout = std::make_unique(&_sbuf); + return *this; +} + +TEST_SUITE("socket") { + TEST_CASE("read-and-write") { + auto path1 = filesystem::path(frmt("/tmp/zeek-agent-test-socket.{}.1", getpid())); + auto path2 = filesystem::path(frmt("/tmp/zeek-agent-test-socket.{}.2", getpid())); + + Socket socket1; + REQUIRE(! socket1); + REQUIRE(socket1.bind(path1)); + REQUIRE(socket1); + + Socket socket2; + REQUIRE(! socket2); + REQUIRE(socket2.bind(path2)); + REQUIRE(socket2); + + socket1.write("Hello, Socket 2!", {&socket1, path2}); + + auto result = socket2.read(); + REQUIRE(result); + REQUIRE(*result); + auto [data_1, remote_sender_1] = **result; + CHECK_EQ(data_1, "Hello, Socket 2!"); + + remote_sender_1 << "Hello, Socket 1!" << std::flush; + + result = socket1.read(); + REQUIRE(result); + REQUIRE(*result); + auto [data_2, remote_sender_2] = **result; + CHECK_EQ(data_2, "Hello, Socket 1!"); + + CHECK(remote_sender_1); + CHECK(remote_sender_2); + CHECK(socket1); + CHECK(socket2); + } + + TEST_CASE("unknown-remote") { + auto path = filesystem::path(frmt("/tmp/zeek-agent-test-socket.{}", getpid())); + + Socket socket; + REQUIRE(socket.bind(path)); + + socket::Remote remote(&socket, filesystem::path("/DOES-NOT-EXIST")); + remote << "xyz" << std::flush; + + CHECK(! remote); + CHECK(remote.error()); + } +} diff --git a/src/util/socket.h b/src/util/socket.h new file mode 100644 index 00000000..91afba6e --- /dev/null +++ b/src/util/socket.h @@ -0,0 +1,189 @@ +// Copyright (c) 2021 by the Zeek Project. See LICENSE for details. + +#pragma once + +#include "core/logger.h" +#include "util/helpers.h" +#include "util/pimpl.h" + +#include +#include +#include +#include + +namespace zeek::agent { + +class Socket; + +namespace socket { + +class Remote; + +/*** + * Opaque address handle identifying a remote socket endpoint. It's derived + * internally from the socket paths. Externally, + * there's no further semantics associated with the content of this string. + */ +using Address = std::string; + +/** + * Private helper class for `Remote` implementing a `stringbuf` variant that + * sends output to a socket. + **/ +class SocketBuffer : public std::stringbuf { +private: + friend class Remote; + + SocketBuffer(Socket* socket = nullptr, Remote* remote = nullptr) : _socket(socket), _remote(remote) {} + SocketBuffer(const SocketBuffer& other) : _socket(other._socket), _remote(other._remote) {} + + int sync() override; + + SocketBuffer& operator=(const SocketBuffer& other); + + Socket* _socket = nullptr; + Remote* _remote = {}; +}; + +/** + * Remote endpoint of a socket, accepting `<<` stream operations for sending + * output to that destination. + */ +class Remote { +public: + /** + * Constructor. + * + * @param local local socket the remote endpoint is associated with + * @param dst file system path identifying remote endpoint + */ + Remote(Socket* local, const filesystem::path& dst) + : _dst(pathToDestination(dst)), _sbuf(local, this), _sout(std::make_unique(&_sbuf)) {} + + /** + * Constructor. + * + * @param local local socket the remote endpoint is associated with + * @param dst opaque handle identifying the remote endpoint, as returned by `read()` or `destination()`. + */ + Remote(Socket* local = nullptr, Address dst = {}) + : _dst(std::move(dst)), _sbuf(local, this), _sout(std::make_unique(&_sbuf)) {} + + /** Copy constructor. */ + Remote(const Remote& other) + : _dst(other._dst), _sbuf(other._sbuf._socket, this), _sout(std::make_unique(&_sbuf)) {} + + /** Returns an opaque handle identifying the remote endpoint. */ + Address destination() const { return _dst; } + + /** Returns `ostream` sending output to the remote endpoint. */ + std::ostream& stream() { return *_sout; } + + /** Returns any error that a previous I/O operation has triggered. */ + const auto& error() const { return _error; } + + /*o Implicit conversion to an `ostream` sending output out to the socket. */ + operator std::ostream&() { return stream(); } + + /** Returns true if no error has been recorded by a previous I/O operation. */ + explicit operator bool() const { return ! _error.has_value(); } + + Remote& operator=(const Remote& other) noexcept; + bool operator==(const Remote& other) const { return _dst == other._dst; } + + /** Wrapper to make the remote endpoint compatible with `ostream-`style `<<` output. */ + template + Remote& operator<<(const T& t) { + stream() << t; + return *this; + } + + using _manip_type = std::ostream&(std::ostream&); + + /** Wrapper to make the remote endpoint compatible with `ostream-`style `<<` output. */ + Remote& operator<<(_manip_type m) { + stream() << m; + return *this; + } + +protected: + friend class SocketBuffer; + + void setError(const result::Error& err) { + _sout->setstate(std::ios_base::failbit); + _error = err; + } + +private: + // Convert a file system path into an opaque handle identifying the remote endpoint. + Address pathToDestination(const filesystem::path& path); + + Address _dst = {}; // opaque handle of remote endpoint + socket::SocketBuffer _sbuf; // stream buffer bound to the remote endpoint + std::unique_ptr _sout; // `ostream` using `sbuf` as its buffer + std::optional _error; // error state +}; + +} // namespace socket + +/** + * Provides a socket for IPC. + * + * This is a helper class that provides IPC functionality for our interactive + * console. Behind the scenes, we may implement sockets differently per + * platform. On POSIX systems, we use Unix datagramm sockets. Windows is not + * currently implemented. + */ +class Socket : public Pimpl { +public: + /** Constructor. */ + Socket(); + + /** Destructor. */ + ~Socket(); + + /** + * Binds the socket to a local path, setting it up for communication. + * + * @param path a local file system path identifying the socket; this is + * what remote endpoints will use to address the socket. + */ + Result bind(const filesystem::path& path); + + /** Result type of `read()`. */ + using ReadResult = std::optional>; + + /** + * Reads one message from the socket. If no input is currently available, + * this will block briefly and then return with an unset optional. + * + * @returns the optional will be either set to the pair of received message + * and its remote sender, or remain unset if no message is currently + * pending; if an error occurs, the result will reflect it + */ + Result read(); + + /** + * Sends one message to the given destination. + * + * @param data message to send; it will be send atomically, i.e., the + * receiver will always receive the whole message from a single call to + * `read()` + * @returns success if the message was sent, or a corresponding error otherwise + */ + Result write(const std::string& data, const socket::Remote& dst); + + /** Returns true if the socket is open and ready for communication. */ + bool isActive() const; + + /** Returns the result of `isActive() && ! error()`. */ + explicit operator bool() const { return isActive(); } + + /** + * Returns true if the socket implementation supports true IPC between + * independent processes. + */ + static bool supportsIPC(); +}; + +} // namespace zeek::agent diff --git a/src/util/socket.no-ipc.cc b/src/util/socket.no-ipc.cc new file mode 100644 index 00000000..ee3216e8 --- /dev/null +++ b/src/util/socket.no-ipc.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2021 by the Zeek Project. See LICENSE for details. +// +// Limited Socket implementation supporting only communication within the same +// process. This is for platforms where we have not implemented IPC support +// yet. + +#include "socket.h" + +#include "core/logger.h" +#include "platform/platform.h" +#include "util/fmt.h" +#include "util/helpers.h" + +#include +#include +#include + +using namespace zeek::agent; + +socket::Address socket::Remote::pathToDestination(const filesystem::path& path) { return to_string(path); } + +// Global map of messages queued for each path/address. +using Message = std::pair; +std::map> messages; // messages queued for each path + +// Lock protecting access to messages. +std::mutex mutex; + +template<> +struct Pimpl::Implementation { + // Binds the socket to a local path, setting it up for communication. + Result bind(const filesystem::path& path); + + // Reads one message from the socket. If no input is currently available, + Result read(); + + // Sends one message to the currently active destination. This will fail + Result write(const std::string& data, const socket::Remote& dst); + + Socket* _socket = nullptr; // socket that this implementation belongs to + filesystem::path _path; // path the socket is bound to + socket::Address _idx; // map into messages +}; + +Result Socket::Implementation::bind(const filesystem::path& path) { + const std::scoped_lock lock(mutex); + + _path = path; + _idx = to_string(_path); + messages[_idx] = {}; + return Nothing(); +} + +Result Socket::Implementation::read() { + const std::scoped_lock lock(mutex); + + auto i = messages.find(_idx); + if ( i == messages.end() ) + return result::Error("socket not bound", to_string(_path)); + + if ( i->second.empty() ) + return {std::nullopt}; + + auto msg = i->second.front(); + i->second.pop_front(); + + return std::make_optional(std::make_pair(msg.first, socket::Remote(_socket, msg.second))); +} + +Result Socket::Implementation::write(const std::string& data, const socket::Remote& dst) { + const std::scoped_lock lock(mutex); + + auto i = messages.find(dst.destination()); + if ( i == messages.end() ) + return result::Error("socket not bound", to_string(dst.destination())); + + i->second.emplace_back(data, _idx); + return Nothing(); +} + +Socket::Socket() { pimpl()->_socket = this; } + +Socket::~Socket() {} + +bool Socket::isActive() const { return ! pimpl()->_path.empty(); }; + +Result Socket::bind(const filesystem::path& path) { return pimpl()->bind(path); } + +Result Socket::read() { return pimpl()->read(); } + +Result Socket::write(const std::string& data, const socket::Remote& dst) { return pimpl()->write(data, dst); } + +bool Socket::supportsIPC() { return false; } diff --git a/src/util/socket.posix.cc b/src/util/socket.posix.cc new file mode 100644 index 00000000..548d67ef --- /dev/null +++ b/src/util/socket.posix.cc @@ -0,0 +1,192 @@ +// Copyright (c) 2021 by the Zeek Project. See LICENSE for details. +// +// Socket implementation for POSIX systems. + +#include "socket.h" + +#include "core/logger.h" +#include "util/helpers.h" + +#include +#include + +using namespace zeek::agent; + +static const auto SocketBufferSize = 32768U; + +// Converts an opaque address handle into a sockaddr_un. +static struct sockaddr_un dst2sock(const socket::Address& dst) { + struct sockaddr_un sock; + memcpy(&sock, dst.data(), sizeof(sock)); + return sock; +} + +// Converts a sockaddr_un into an opaque address handle. +static socket::Address sock2dst(const struct sockaddr_un& dst) { + return {reinterpret_cast(&dst), sizeof(dst)}; +} + +socket::Address socket::Remote::pathToDestination(const filesystem::path& path) { + struct sockaddr_un dst; + + if ( strlen(path.c_str()) >= sizeof(dst.sun_path) ) + throw FatalError(frmt("socket path too long: {}", path.native())); + + bzero(&dst, sizeof(dst)); + dst.sun_family = AF_UNIX; + strncpy(dst.sun_path, path.c_str(), sizeof(dst.sun_path) - 1); + dst.sun_path[sizeof(dst.sun_path) - 1] = '\0'; + return sock2dst(dst); +} + +template<> +struct Pimpl::Implementation { + // One-time initialization. + void init(); + + // One-time initialization. + void done(); + + // Binds the socket to a local path, setting it up for communication. + Result bind(const filesystem::path& path); + + // Reads one message from the socket. If no input is currently available, + Result read(); + + // Sends one message to the currently active destination. This will fail + Result write(const std::string& data, const socket::Remote& dst); + + Socket* _socket = nullptr; // socket that this implementation belongs to + int _fd = -1; // socket's fd + filesystem::path _path; // path the socket is bound to +}; + +void Socket::Implementation::init() {} + +void Socket::Implementation::done() { + if ( _fd >= 0 ) + close(_fd); + + if ( ! _path.empty() ) + unlink(_path.c_str()); +} + +Result Socket::Implementation::bind(const filesystem::path& path) { + if ( _fd >= 0 ) + return result::Error("socket already bound"); + + int flags; + struct sockaddr_un local; + + if ( strlen(path.c_str()) >= sizeof(local.sun_path) ) + return result::Error(frmt("socket path too long: {}", path.native())); + + auto fd = ::socket(AF_UNIX, SOCK_DGRAM, 0); + if ( fd < 0 ) + return result::Error(frmt("cannot create socket: {}", strerror(errno))); + + ScopeGuard _([&]() { + if ( fd >= 0 ) + close(fd); + }); + + const int bufsize = SocketBufferSize; + if ( setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) < 0 ) + logger()->warn("cannot set socket receive buffer size: {}", strerror(errno)); + + if ( setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) < 0 ) + logger()->warn("cannot set socket send buffer size: {}", strerror(errno)); + + // Let operations time out so that our I/O methods don't block and the + // caller can also check for termination. + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 50000; + if ( setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0 ) + return result::Error(frmt("cannot set socket timeout: {}", strerror(errno))); + + bzero(&local, sizeof(local)); + local.sun_family = AF_UNIX; + strncpy(local.sun_path, path.c_str(), sizeof(local.sun_path) - 1); + local.sun_path[sizeof(local.sun_path) - 1] = '\0'; + unlink(path.c_str()); + + { + // Grant only the current user the permission to access the socket. + auto old_umask = umask(0077); + ScopeGuard _([&]() { umask(old_umask); }); + + if ( ::bind(fd, (struct sockaddr*)&local, sizeof(local)) < 0 ) + return result::Error(frmt("cannot bind to socket: {}", strerror(errno))); + } + + _path = path; + _fd = fd; + fd = -1; + + return Nothing(); +} + +Result Socket::Implementation::read() { + if ( _fd < 0 ) + return result::Error("socket not open"); + + struct sockaddr_un sender; + socklen_t sender_size = sizeof(sender); + + char buffer[SocketBufferSize]; + auto len = recvfrom(_fd, buffer, sizeof(buffer), 0, (struct sockaddr*)&sender, &sender_size); + if ( len < 0 ) { + if ( errno != EAGAIN || errno != EWOULDBLOCK ) + return result::Error(strerror(errno)); + + return {std::nullopt}; + } + + return std::make_optional(std::make_pair(std::string(buffer, len), socket::Remote(_socket, sock2dst(sender)))); +} + +Result Socket::Implementation::write(const std::string& data, const socket::Remote& dst) { + if ( _fd < 0 ) + return result::Error("socket not open"); + + if ( data.empty() ) + return Nothing(); + + int attempts = 0; + while ( attempts++ < 50 ) { + struct sockaddr_un sockaddr = dst2sock(dst.destination()); + auto len = sendto(_fd, data.data(), data.size(), 0, (struct sockaddr*)&sockaddr, sizeof(sockaddr)); + if ( len >= 0 ) + return Nothing(); + + if ( errno == ENOBUFS ) { + usleep(100); // give client a chance to catch up + continue; + } + + if ( errno == EAGAIN || errno == EWOULDBLOCK ) + return Nothing(); // time out + + break; + } + + return result::Error(strerror(errno)); +} + +Socket::Socket() { + pimpl()->_socket = this; + pimpl()->init(); +} + +Socket::~Socket() { pimpl()->done(); } + +bool Socket::isActive() const { return pimpl()->_fd >= 0; }; + +Result Socket::bind(const filesystem::path& path) { return pimpl()->bind(path); } + +Result Socket::read() { return pimpl()->read(); } + +Result Socket::write(const std::string& data, const socket::Remote& dst) { return pimpl()->write(data, dst); } + +bool Socket::supportsIPC() { return true; }