diff --git a/src/stirling/source_connectors/socket_tracer/BUILD.bazel b/src/stirling/source_connectors/socket_tracer/BUILD.bazel index 4dc5e9308ac..f738dd92d8c 100644 --- a/src/stirling/source_connectors/socket_tracer/BUILD.bazel +++ b/src/stirling/source_connectors/socket_tracer/BUILD.bazel @@ -204,6 +204,7 @@ pl_cc_bpf_test( ":cc_library", "//src/common/exec:cc_library", "//src/stirling/source_connectors/socket_tracer/testing:cc_library", + "//src/stirling/source_connectors/socket_tracer/testing/container_images:curl_container", "//src/stirling/testing:cc_library", "//src/stirling/testing/demo_apps/go_http:go_http_fixture", ], diff --git a/src/stirling/source_connectors/socket_tracer/http_trace_bpf_test.cc b/src/stirling/source_connectors/socket_tracer/http_trace_bpf_test.cc index 528fc398a74..9d366c20da1 100644 --- a/src/stirling/source_connectors/socket_tracer/http_trace_bpf_test.cc +++ b/src/stirling/source_connectors/socket_tracer/http_trace_bpf_test.cc @@ -21,6 +21,7 @@ #include "src/common/testing/testing.h" #include "src/stirling/core/data_table.h" #include "src/stirling/source_connectors/socket_tracer/socket_trace_connector.h" +#include "src/stirling/source_connectors/socket_tracer/testing/container_images/curl_container.h" #include "src/stirling/source_connectors/socket_tracer/testing/socket_trace_bpf_test_fixture.h" #include "src/stirling/testing/demo_apps/go_http/go_http_fixture.h" @@ -150,6 +151,82 @@ TEST_F(GoHTTPTraceTest, LargePostMessage) { 131096); } +class CurlHTTPTraceTest : public SocketTraceBPFTestFixture { + protected: + CurlHTTPTraceTest() : SocketTraceBPFTestFixture() {} + + void SetUp() override { + SocketTraceBPFTestFixture::SetUp(); + go_http_fixture_.LaunchServer(); + } + + void TearDown() override { + SocketTraceBPFTestFixture::TearDown(); + go_http_fixture_.ShutDown(); + } + + testing::GoHTTPFixture go_http_fixture_; + + DataTable data_table_{/*id*/ 0, kHTTPTable}; +}; + +TEST_F(CurlHTTPTraceTest, XFormURLEncodedRequest) { + StartTransferDataThread(); + + // Uncomment to enable tracing: + // FLAGS_stirling_conn_trace_pid = go_http_fixture_.server_pid(); + + ::px::stirling::testing::CurlContainer client; + auto payload = R"( +{ + "commands": [ + { + "server": "api.use-case.svc.cluster.local:5011", + "action": "req2", + "telemetry": "uninstrumented", + "params": {} + }, + { + "server": "repo.use-case.svc.cluster.local:5012", + "action": "add_user", + "telemetry": "uninstrumented", + "params": { + "name": "John Doe", + "email": "fd2@doe.com" + } + } + ] +} +)"; + + auto uri = absl::Substitute("http://127.0.0.1:$0/post", go_http_fixture_.server_port()); + auto body = absl::Substitute("'action=$0'", payload); + ASSERT_OK(client.Run(std::chrono::seconds{60}, {"--network=host"}, + {"-XPOST", "--data-urlencode", body, uri})); + client.Wait(); + + StopTransferDataThread(); + + std::vector tablets = ConsumeRecords(kHTTPTableNum); + ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets); + + // We do expect to trace the server. + const std::vector target_record_indices = + testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, go_http_fixture_.server_pid()); + ASSERT_THAT(target_record_indices, SizeIs(1)); + + const size_t target_record_idx = target_record_indices.front(); + + EXPECT_THAT( + std::string(record_batch[kHTTPReqHeadersIdx]->Get(target_record_idx)), + AllOf( + HasSubstr(R"("Content-Type":"application/x-www-form-urlencoded")"), + HasSubstr(absl::Substitute(R"(Host":"127.0.0.1:$0")", go_http_fixture_.server_port())))); + + EXPECT_THAT(record_batch[kHTTPReqBodyIdx]->Get(target_record_idx), + StrEq(body)); +} + struct TraceRoleTestParam { endpoint_role_t role; size_t client_records_count; diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.cc b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.cc index 54ff841b5a9..d6fe4c531c8 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.cc @@ -43,7 +43,7 @@ namespace stirling { namespace protocols { namespace http { -void PreProcessMessage(Message* message) { +void PreProcessRespMessage(Message* message) { // Parse the flags on the first time only. static const HTTPHeaderFilter kHTTPResponseHeaderFilter = ParseHTTPHeaderFilters(FLAGS_http_response_header_filters); @@ -75,6 +75,19 @@ void PreProcessMessage(Message* message) { } } +void PreProcessReqMessage(Message* message) { + // Unlike responses, leave the body intact for messages that don't specify a Content-Type + auto content_type_iter = message->headers.find(http::kContentType); + if (content_type_iter == message->headers.end()) { + return; + } + + if (message->type == message_type_t::kRequest && + content_type_iter->second == "application/x-www-form-urlencoded") { + message->body = HTTPUrlDecode(message->body); + } +} + } // namespace http } // namespace protocols } // namespace stirling diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.h b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.h index 856984e232f..e8a4d493c20 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.h @@ -45,7 +45,8 @@ namespace http { RecordsWithErrorCount ProcessMessages(std::deque* req_messages, std::deque* resp_messages); -void PreProcessMessage(Message* message); +void PreProcessRespMessage(Message* message); +void PreProcessReqMessage(Message* message); } // namespace http diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher_test.cc index 23ee2741519..b9ad28a7b33 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher_test.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/stitcher_test.cc @@ -19,6 +19,7 @@ #include #include +#include "src/common/json/json.h" #include "src/stirling/source_connectors/socket_tracer/protocols/http/stitcher.h" namespace px { @@ -26,11 +27,12 @@ namespace stirling { namespace protocols { namespace http { +using ::px::utils::ToJSONString; using ::testing::Contains; using ::testing::Pair; using ::testing::StrEq; -TEST(PreProcessRecordTest, GzipCompressedContentIsDecompressed) { +TEST(PreProcessRespRecordTest, GzipCompressedContentIsDecompressed) { Message message; message.type = message_type_t::kResponse; message.headers.insert({kContentEncoding, "gzip"}); @@ -41,27 +43,65 @@ TEST(PreProcessRecordTest, GzipCompressedContentIsDecompressed) { 0x85, 0x92, 0xd4, 0xe2, 0x12, 0x2e, 0x00, 0x8c, 0x2d, 0xc0, 0xfa, 0x0f, 0x00, 0x00, 0x00}; message.body.assign(reinterpret_cast(compressed_bytes), sizeof(compressed_bytes)); - PreProcessMessage(&message); + PreProcessRespMessage(&message); EXPECT_EQ("This is a test\n", message.body); } -TEST(PreProcessRecordTest, ContentHeaderIsNotAdded) { +// Determines if the character should be percent encoded accoridng to the URL +// encoding spec https://en.wikipedia.org/wiki/Percent-encoding +bool IsUnreservedChar(unsigned char c) { + return (c >= 0x30 && c <= 0x39) || (c >= 0x41 && c <= 0x5A) || (c >= 0x61 && c <= 0x7A) || + c == '-' || c == '_' || c == '.' || c == '~'; +} + +constexpr unsigned char hex[] = "0123456789ABCDEF"; + +std::string HTTPUrlEncode(std::string_view input) { + std::string encoded = ""; + for (auto c : input) { + if (IsUnreservedChar(c)) { + encoded.push_back(c); + } else { + encoded.push_back('%'); + encoded.push_back(hex[c >> 4]); + encoded.push_back(hex[c & 0xf]); + } + } + return encoded; +} + +TEST(PreProcessRespRecordTest, ContentHeaderIsNotAdded) { Message message; message.type = message_type_t::kResponse; message.body = "test"; message.headers.insert({kContentType, "text"}); - PreProcessMessage(&message); + PreProcessRespMessage(&message); EXPECT_EQ("", message.body); EXPECT_THAT(message.headers, Contains(Pair(kContentType, "text"))); } +TEST(PreProcessReqRecordTest, FormUrlEncodedDataIsDecoded) { + std::map> payload = { + {"commands", {"nested1", "nested2"}}, + {"params", {"name", "email"}}, + }; + auto json_str = ToJSONString(payload); + Message message; + message.type = message_type_t::kRequest; + message.body = HTTPUrlEncode(json_str); + message.headers.insert({kContentType, "application/x-www-form-urlencoded"}); + PreProcessReqMessage(&message); + EXPECT_EQ(json_str, message.body); + EXPECT_THAT(message.headers, Contains(Pair(kContentType, "application/x-www-form-urlencoded"))); +} + // Tests that when body-size is 0, the message body won't be rewritten. -TEST(PreProcessRecordTest, ZeroSizedBodyNotRewritten) { +TEST(PreProcessRespRecordTest, ZeroSizedBodyNotRewritten) { Message message; message.type = message_type_t::kResponse; message.body_size = 0; EXPECT_THAT(message.body, StrEq("-")); - PreProcessMessage(&message); + PreProcessRespMessage(&message); EXPECT_THAT(message.body, StrEq("-")); } diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/utils.cc b/src/stirling/source_connectors/socket_tracer/protocols/http/utils.cc index f24e372532f..b2988416910 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/utils.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/utils.cc @@ -74,6 +74,45 @@ HTTPHeaderFilter ParseHTTPHeaderFilters(std::string_view filters) { return result; } +// Lookup table for fast conversion of hex character to decimal value +static constexpr unsigned char hextable[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, /* 0x30 - 0x3f */ + 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x40 - 0x4f */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x50 - 0x5f */ + 0, 10, 11, 12, 13, 14, 15 /* 0x60 - 0x66 */ +}; + +#define HEX_TO_DEC(x) hextable[x - '0'] + +// Checks to see if the char is a valid hex digit [A-Fa-f0-9] +// Used to check that HEX_TO_DEC is safe to apply for the given char +bool IsHexDigit(char c) { + return (c >= 0x30 && c <= 0x39) || (c >= 0x41 && c <= 0x46) || (c >= 0x61 && c <= 0x66); +} + +std::string HTTPUrlDecode(const std::string_view input) { + std::string output; + output.reserve(input.size()); + size_t pos = 0; + size_t end = input.size(); + while (pos < end) { + auto c = input.at(pos); + if (c == '+') { + output.push_back(' '); + pos += 1; + } else if (c == '%' && (end - pos >= 2) && IsHexDigit(input[pos + 1]) && + IsHexDigit(input[pos + 2])) { + output.push_back((unsigned char)(HEX_TO_DEC(input[pos + 1]) << 4) | + HEX_TO_DEC(input[pos + 2])); + pos += 3; + } else { + output.push_back(c); + pos += 1; + } + } + return output; +} + bool IsJSONContent(const Message& message) { auto content_type_iter = message.headers.find(kContentType); if (content_type_iter == message.headers.end()) { diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/utils.h b/src/stirling/source_connectors/socket_tracer/protocols/http/utils.h index 62005364173..4b34503a1c8 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/utils.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/utils.h @@ -59,6 +59,8 @@ HTTPHeaderFilter ParseHTTPHeaderFilters(std::string_view filters); */ bool MatchesHTTPHeaders(const HeadersMap& http_headers, const HTTPHeaderFilter& filter); +std::string HTTPUrlDecode(std::string_view input); + /** * Detects the content-type of an HTTP message. Currently only checks for JSON. */ diff --git a/src/stirling/source_connectors/socket_tracer/protocols/http/utils_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/http/utils_test.cc index 93b2578e1d7..d1abaa1b3c3 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/http/utils_test.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/http/utils_test.cc @@ -76,6 +76,29 @@ TEST(ParseHTTPHeaderFiltersAndMatchTest, FiltersAreAsExpectedAndMatchesWork) { } } +TEST(HTTPUrlDecode, Decode) { + std::string input = + "action=%7B%0A++%22commands%22%3A+%5B%0A++++%7B%0A++++++%22server%22%3A+%22api.use-case.svc." + "cluster.local%3A5011%22%2C%0A++++++%22action%22%3A+%22req2%22%2C%0A++++++%22telemetry"; + std::string expected = R"(action={ + "commands": [ + { + "server": "api.use-case.svc.cluster.local:5011", + "action": "req2", + "telemetry)"; + EXPECT_EQ(HTTPUrlDecode(input), expected); +} + +TEST(HTTPUrlDecode, DecodeTruncatedInput) { + // Use input that has its first % encoded hex digit removed + std::string truncated_first_digit = "action=%"; + EXPECT_EQ(HTTPUrlDecode(truncated_first_digit), truncated_first_digit); + + // Use input that has its second % encoded hex digit removed + std::string truncated_second_digit = "action=%7"; + EXPECT_EQ(HTTPUrlDecode(truncated_second_digit), truncated_second_digit); +} + } // namespace http } // namespace protocols } // namespace stirling diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc index 95de576055f..185bd0e8299 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc @@ -1254,7 +1254,8 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke // Currently decompresses gzip content, but could handle other transformations too. // Note that we do this after filtering to avoid burning CPU cycles unnecessarily. - protocols::http::PreProcessMessage(&resp_message); + protocols::http::PreProcessRespMessage(&resp_message); + protocols::http::PreProcessReqMessage(&req_message); md::UPID upid(ctx->GetASID(), conn_tracker.conn_id().upid.pid, conn_tracker.conn_id().upid.start_time_ticks);