Populate client side trace's local address via tcp kprobes (#1989)
Summary: Populate client side trace's local address via tcp kprobes

This change populates client side trace's `local_addr` and `local_port`
columns for the following use cases:
1. To provide more consistency for the protocol data tables. Empty
columns make it difficult for end users to understand what is being
traced and make the tables less useful
2. To facilitate addressing a portion of the short lived process
problems (#1638)

For 2, the root of the issue is that `df.ctx["pod"]` syntax relies on
the
[px.upid_to_pod_name](https://docs.px.dev/reference/pxl/udf/upid_to_pod_name/)
function. If a PEM misses the short lived process during its metadata
update, this function fails to resolve the pod name. For client side
traces where the pod is making an outbound connection (non localhost),
the `local_addr` column provides an alternative pod name lookup for
short lived processes when the pod is long lived. This means the
following would be equivalent to the `df.ctx["pod"]` lookup:
`px.pod_id_to_pod_name(px.ip_to_pod_id(df.local_addr))`.

I intend to follow this PR with a compiler change that will make
`df.ctx["pod"]` try both methods should `px.upid_to_pod_name` fail to
resolve. This will allow the existing pxl scripts to display the
previously missed short lived processes.
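
As a rough illustration of the fallback described above, the following C++ sketch models the two lookups with plain maps. `MetadataSketch`, `ResolvePodName`, and the 64-bit `upid` key are hypothetical names invented for this example; they are not Pixie's metadata types or the planned compiler change, and an empty string stands in for "failed to resolve".

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the metadata a PEM holds. Not Pixie's real types.
struct MetadataSketch {
  // UPID -> pod name. May be missing entries for short lived processes.
  std::unordered_map<uint64_t, std::string> upid_to_pod;
  // Pod IP -> pod name. Stays valid as long as the pod itself is long lived.
  std::unordered_map<std::string, std::string> ip_to_pod;
};

// Try the UPID based lookup first; if it fails, fall back to the trace's
// local address. Returns an empty string when neither lookup resolves.
inline std::string ResolvePodName(const MetadataSketch& md, uint64_t upid,
                                  const std::string& local_addr) {
  auto by_upid = md.upid_to_pod.find(upid);
  if (by_upid != md.upid_to_pod.end()) {
    return by_upid->second;
  }
  auto by_ip = md.ip_to_pod.find(local_addr);
  if (by_ip != md.ip_to_pod.end()) {
    return by_ip->second;
  }
  return "";
}
```

In this model, a trace from a process the metadata update missed still resolves as long as its `local_addr` is a known pod IP.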

**Alternatives**

Another approach I considered was expanding our use of the `sock_alloc`
kprobe. I used ftrace on a simple curl command to see what other options
could be used (`sudo trace-cmd record -F -p function_graph curl
http://google.com`). The `socket` syscall calls `sock_alloc`, which
would be another mechanism for accessing the `struct sock`. I decided
against this approach because I don't think it's viable to assume that
the same thread/process that calls `socket` will be the one that makes
the later syscalls (which is how our BPF maps are keyed). It's common to
have a forking web server model, which means a different process/thread
can call `socket` than the ones that later read/write to it.
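
To make the concern concrete, here is a small user-space C++ model of a map keyed by `{tgid, pid}`. `SockStashSketch` and its methods are hypothetical and only mimic the keying scheme; the key layout follows `bpf_get_current_pid_tgid()`, which places the tgid in the upper 32 bits and the pid in the lower 32 bits.

```cpp
#include <cstdint>
#include <unordered_map>

// Same layout as the value returned by bpf_get_current_pid_tgid().
inline uint64_t MakeThreadKey(uint32_t tgid, uint32_t pid) {
  return (static_cast<uint64_t>(tgid) << 32) | pid;
}

// Hypothetical model of a BPF hash map keyed by the calling thread.
struct SockStashSketch {
  std::unordered_map<uint64_t, const void*> by_thread;

  // What a probe on socket() would do: remember the sock for the caller.
  void OnSocket(uint64_t caller, const void* sk) { by_thread[caller] = sk; }

  // What a later syscall probe would do: look the sock up by its own caller.
  // Returns nullptr when a different thread/process created the socket.
  const void* OnLaterSyscall(uint64_t caller) const {
    auto it = by_thread.find(caller);
    return it == by_thread.end() ? nullptr : it->second;
  }
};
```

If a parent process calls `socket` and a forked child performs the reads and writes, the child's key never matches the parent's entry, so the lookup comes back empty.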

**Probe stability**

These probes appear to be stable across our oldest and newest supported
kernels. These functions exist in the
[tcp_prot](https://elixir.bootlin.com/linux/v4.14.336/source/net/ipv4/tcp_ipv4.c#L2422) and
[tcpv6_prot](https://elixir.bootlin.com/linux/v4.14.336/source/net/ipv6/tcp_ipv6.c#L1941)
structs, and I've seen that other projects and bcc tools use these
probes. This makes me believe that these functions have a fairly well
defined interface.

Relevant Issues: #1829, #1638

Type of change: /kind feature

Test Plan: New tests verify that ipv4 and ipv6 cases work
- [x] Ran `for i in $(seq 0 1000); do curl http://google.com/$i; sleep
2; done` within a pod and verified that `local_addr` is populated with
this change and `px.pod_id_to_pod_name(px.ip_to_pod_id(df.local_addr))`
works for pod name resolution.

- [x] Verified that, without this change, the above curl test produces
traces with an empty `local_addr`

![local-addr-testing](https://github.com/user-attachments/assets/344be022-97a0-4096-8af7-8de20d741e40)
- Tested on the following k8s offerings and machine images
  - [x] GKE COS and Ubuntu
  - [x] EKS Amazon Linux 2

Changelog Message: Populate socket tracer data table `local_addr` and
`local_port` columns for client side traces.

---------

Signed-off-by: Dom Del Nano <[email protected]>
ddelnano authored Sep 5, 2024
1 parent afffb8e commit 0c1fdd2
Showing 5 changed files with 253 additions and 15 deletions.
81 changes: 67 additions & 14 deletions src/stirling/source_connectors/socket_tracer/bcc_bpf/socket_trace.c
@@ -81,6 +81,11 @@ BPF_HASH(active_accept_args_map, uint64_t, struct accept_args_t);
// Key is {tgid, pid}.
BPF_HASH(active_connect_args_map, uint64_t, struct connect_args_t);

// Map from thread to its sock* struct. This facilitates capturing
// the local address of a tcp socket during connect() syscalls.
// Key is {tgid, pid}.
BPF_HASH(tcp_connect_args_map, uint64_t, struct sock*);

// Map from thread to its ongoing write() syscall's input argument.
// Tracks write() call from entry -> exit.
// Key is {tgid, pid}.
@@ -345,19 +350,17 @@ static __inline void update_traffic_class(struct conn_info_t* conn_info,
* Perf submit functions
***********************************************************/

-static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info,
-const struct socket* socket) {
-// Use BPF_PROBE_READ_KERNEL_VAR since BCC cannot insert them as expected.
-struct sock* sk = NULL;
-BPF_PROBE_READ_KERNEL_VAR(sk, &socket->sk);
-
-struct sock_common* sk_common = &sk->__sk_common;
+static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info, const struct sock* sk) {
+const struct sock_common* sk_common = &sk->__sk_common;
uint16_t family = -1;
uint16_t lport = -1;
uint16_t rport = -1;

BPF_PROBE_READ_KERNEL_VAR(family, &sk_common->skc_family);
BPF_PROBE_READ_KERNEL_VAR(lport, &sk_common->skc_num);
+// skc_num is stored in host byte order. The rest of our user space processing
+// assumes network byte order so convert it here.
+lport = htons(lport);
BPF_PROBE_READ_KERNEL_VAR(rport, &sk_common->skc_dport);

conn_info->laddr.sa.sa_family = family;
@@ -377,12 +380,12 @@ static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info,
}

static __inline void submit_new_conn(struct pt_regs* ctx, uint32_t tgid, int32_t fd,
-const struct sockaddr* addr, const struct socket* socket,
+const struct sockaddr* addr, const struct sock* sock,
enum endpoint_role_t role, enum source_function_t source_fn) {
struct conn_info_t conn_info = {};
init_conn_info(tgid, fd, &conn_info);
-if (socket != NULL) {
-read_sockaddr_kernel(&conn_info, socket);
+if (sock != NULL) {
+read_sockaddr_kernel(&conn_info, sock);
} else if (addr != NULL) {
conn_info.raddr = *((union sockaddr_t*)addr);
}
@@ -585,6 +588,52 @@ int conn_cleanup_uprobe(struct pt_regs* ctx) {
return 0;
}

// These probes are used to capture the *sock struct during client side tracing
// of connect() syscalls. This is necessary to capture the socket's local address,
// which is not accessible via the connect() and later syscalls.
//
// This function requires that the function being probed receives a struct sock* as its
// first argument and that the active_connect_args_map is populated when this probe fires.
// This means the function being probed must be part of the connect() syscall path or similar
// syscall path.
//
// Using the struct sock* for capturing a socket's local address only works for TCP sockets.
// The equivalent UDP functions (udp_v4_connect, udp_v6_connect and udp_sendmsg) always receive a
// sock struct with a 0.0.0.0 or ::1 local address. This is deemed acceptable since our local
// address population for server side tracing relies on accept/accept4, which only applies for TCP.
//
// int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
// static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
// int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int probe_entry_populate_active_connect_sock(struct pt_regs* ctx) {
uint64_t id = bpf_get_current_pid_tgid();

const struct connect_args_t* connect_args = active_connect_args_map.lookup(&id);
if (connect_args == NULL) {
return 0;
}
struct sock* sk = (struct sock*)PT_REGS_PARM1(ctx);
tcp_connect_args_map.update(&id, &sk);

return 0;
}

int probe_ret_populate_active_connect_sock(struct pt_regs* ctx) {
uint64_t id = bpf_get_current_pid_tgid();

struct sock** sk = tcp_connect_args_map.lookup(&id);
if (sk == NULL) {
return 0;
}
struct connect_args_t* connect_args = active_connect_args_map.lookup(&id);
if (connect_args != NULL) {
connect_args->connect_sock = *sk;
}

tcp_connect_args_map.delete(&id);
return 0;
}

/***********************************************************
* BPF syscall processing functions
***********************************************************/
@@ -629,7 +678,8 @@ static __inline void process_syscall_connect(struct pt_regs* ctx, uint64_t id,
return;
}

-submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleClient, kSyscallConnect);
+submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleClient,
+kSyscallConnect);
}

static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id,
@@ -645,8 +695,11 @@ static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id,
return;
}

-submit_new_conn(ctx, tgid, ret_fd, args->addr, args->sock_alloc_socket, kRoleServer,
-kSyscallAccept);
+const struct sock* sk = NULL;
+if (args->sock_alloc_socket != NULL) {
+BPF_PROBE_READ_KERNEL_VAR(sk, &args->sock_alloc_socket->sk);
+}
+submit_new_conn(ctx, tgid, ret_fd, args->addr, sk, kRoleServer, kSyscallAccept);
}

// TODO(oazizi): This is badly broken (but better than before).
@@ -690,7 +743,7 @@ static __inline void process_implicit_conn(struct pt_regs* ctx, uint64_t id,
return;
}

-submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleUnknown, source_fn);
+submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleUnknown, source_fn);
}

static __inline bool should_send_data(uint32_t tgid, uint64_t conn_disabled_tsid,
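
The byte-order comment added to `read_sockaddr_kernel` in socket_trace.c can be illustrated in user space. This is a minimal sketch, assuming a POSIX system that provides `htons`/`ntohs` in `<arpa/inet.h>`; `LocalPortToNetworkOrder`, `PortBytes`, and `BytesOf` are hypothetical helpers written for this example and are not part of the change.

```cpp
#include <arpa/inet.h>

#include <cstdint>
#include <cstring>

// Hypothetical helper: skc_num holds the local port in host byte order, while
// skc_dport is already in network byte order. Converting the local port keeps
// both ports in the same representation for later processing.
inline uint16_t LocalPortToNetworkOrder(uint16_t skc_num_host_order) {
  return htons(skc_num_host_order);
}

// The two bytes of a 16-bit value in the order they sit in memory.
struct PortBytes {
  unsigned char first;
  unsigned char second;
};

inline PortBytes BytesOf(uint16_t value) {
  PortBytes bytes;
  std::memcpy(&bytes, &value, sizeof(value));
  return bytes;
}
```

For port 8080 (0x1F90), the converted value always has 0x1F as its first byte in memory regardless of the host's endianness, and `ntohs` recovers 8080 when user space reads it back.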
@@ -263,6 +263,7 @@ struct socket_control_event_t {

struct connect_args_t {
const struct sockaddr* addr;
const struct sock* connect_sock;
int32_t fd;
};

152 changes: 152 additions & 0 deletions src/stirling/source_connectors/socket_tracer/socket_trace_bpf_test.cc
@@ -39,13 +39,16 @@
#include "src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.hpp"
#include "src/stirling/source_connectors/socket_tracer/socket_trace_connector.h"
#include "src/stirling/source_connectors/socket_tracer/testing/client_server_system.h"
#include "src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h"
#include "src/stirling/source_connectors/socket_tracer/testing/socket_trace_bpf_test_fixture.h"
#include "src/stirling/testing/common.h"

namespace px {
namespace stirling {

using ::px::stirling::testing::FindRecordsMatchingPID;
using ::px::stirling::testing::GetLocalAddrs;
using ::px::stirling::testing::GetLocalPorts;
using ::px::stirling::testing::RecordBatchSizeIs;
using ::px::system::TCPSocket;
using ::px::system::UDPSocket;
@@ -747,6 +750,155 @@ TEST_F(NullRemoteAddrTest, IPv6Accept4WithNullRemoteAddr) {
EXPECT_EQ(records[kHTTPRemotePortIdx]->Get<types::Int64Value>(0), port);
}

using LocalAddrTest = testing::SocketTraceBPFTestFixture</* TClientSideTracing */ true>;

TEST_F(LocalAddrTest, IPv4ConnectPopulatesLocalAddr) {
StartTransferDataThread();

TCPSocket client;
TCPSocket server;

std::atomic<bool> server_ready = false;

std::thread server_thread([&server, &server_ready]() {
server.BindAndListen();
server_ready = true;
auto conn = server.Accept(/* populate_remote_addr */ true);

std::string data;

conn->Read(&data);
conn->Write(kHTTPRespMsg1);
});

// Wait for server thread to start listening.
while (!server_ready) {
}
// After server_ready, server.Accept() needs to enter the accepting state, before the client
// connection can succeed below. We don't have a simple and robust way to signal that from inside
// the server thread, so we just use sleep to avoid the race condition.
std::this_thread::sleep_for(std::chrono::seconds(1));

std::thread client_thread([&client, &server]() {
client.Connect(server);

std::string data;

client.Write(kHTTPReqMsg1);
client.Read(&data);
});

server_thread.join();
client_thread.join();

// Get the remote port seen by server from client's local port.
struct sockaddr_in client_sockaddr = {};
socklen_t client_sockaddr_len = sizeof(client_sockaddr);
struct sockaddr* client_sockaddr_ptr = reinterpret_cast<struct sockaddr*>(&client_sockaddr);
ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0);

// Close after getting the sockaddr from fd, otherwise getsockname() won't work.
client.Close();
server.Close();

StopTransferDataThread();

std::vector<TaggedRecordBatch> tablets = ConsumeRecords(kHTTPTableNum);
ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets);

std::vector<size_t> indices =
testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid());
ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices);

ASSERT_THAT(records, RecordBatchSizeIs(2));

// Make sure that the socket info resolution works.
ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv4AddrToString(client_sockaddr.sin_addr));
EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("127.0.0.1").Times(2));
EXPECT_EQ(remote_addr, "127.0.0.1");

bool found_port = false;
uint16_t port = ntohs(client_sockaddr.sin_port);
for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) {
if (lport == port) {
found_port = true;
break;
}
}
EXPECT_TRUE(found_port);
}

TEST_F(LocalAddrTest, IPv6ConnectPopulatesLocalAddr) {
StartTransferDataThread();

TCPSocket client(AF_INET6);
TCPSocket server(AF_INET6);

std::atomic<bool> server_ready = false;

std::thread server_thread([&server, &server_ready]() {
server.BindAndListen();
server_ready = true;
auto conn = server.Accept(/* populate_remote_addr */ false);

std::string data;

conn->Read(&data);
conn->Write(kHTTPRespMsg1);
});

while (!server_ready) {
}

std::thread client_thread([&client, &server]() {
client.Connect(server);

std::string data;

client.Write(kHTTPReqMsg1);
client.Read(&data);
});

server_thread.join();
client_thread.join();

// Get the remote port seen by server from client's local port.
struct sockaddr_in6 client_sockaddr = {};
socklen_t client_sockaddr_len = sizeof(client_sockaddr);
struct sockaddr* client_sockaddr_ptr = reinterpret_cast<struct sockaddr*>(&client_sockaddr);
ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0);

// Close after getting the sockaddr from fd, otherwise getsockname() won't work.
client.Close();
server.Close();

StopTransferDataThread();

std::vector<TaggedRecordBatch> tablets = ConsumeRecords(kHTTPTableNum);
ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets);

std::vector<size_t> indices =
testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid());
ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices);

ASSERT_THAT(records, RecordBatchSizeIs(2));

// Make sure that the socket info resolution works.
ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv6AddrToString(client_sockaddr.sin6_addr));
EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("::1").Times(2));
EXPECT_EQ(remote_addr, "::1");

bool found_port = false;
uint16_t port = ntohs(client_sockaddr.sin6_port);
for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) {
if (lport == port) {
found_port = true;
break;
}
}
EXPECT_TRUE(found_port);
}

// Run a UDP-based client-server system.
class UDPSocketTraceBPFTest : public SocketTraceBPFTest {
protected:
@@ -176,7 +176,7 @@ DEFINE_bool(
stirling_debug_tls_sources, gflags::BoolFromEnv("PX_DEBUG_TLS_SOURCES", false),
"If true, stirling will add additional prometheus metrics regarding the traced tls sources");

-DEFINE_uint32(stirling_bpf_loop_limit, 42,
+DEFINE_uint32(stirling_bpf_loop_limit, 41,
"The maximum number of iovecs to capture for syscalls. "
"Set conservatively for older kernels by default to keep the instruction count below "
"BPF's limit for version 4 kernels (4096 per probe).");
@@ -342,6 +342,18 @@ const auto kProbeSpecs = MakeArray<bpf_tools::KProbeSpec>({
{"close", ProbeType::kReturn, "syscall__probe_ret_close"},
{"mmap", ProbeType::kEntry, "syscall__probe_entry_mmap"},
{"sock_alloc", ProbeType::kReturn, "probe_ret_sock_alloc", /*is_syscall*/ false},
{"tcp_v4_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v4_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v6_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v6_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_sendmsg", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_sendmsg", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
{"security_socket_sendmsg", ProbeType::kEntry, "probe_entry_socket_sendmsg",
/*is_syscall*/ false, /* is_optional */ false,
std::make_shared<bpf_tools::KProbeSpec>(bpf_tools::KProbeSpec{
@@ -135,6 +135,26 @@ inline std::vector<bool> GetEncrypted(const types::ColumnWrapperRecordBatch& rb,
return encrypted;
}

inline std::vector<std::string> GetLocalAddrs(const types::ColumnWrapperRecordBatch& rb,
const int local_addr_idx,
const std::vector<size_t>& indices) {
std::vector<std::string> laddrs;
for (size_t idx : indices) {
laddrs.push_back(rb[local_addr_idx]->Get<types::StringValue>(idx));
}
return laddrs;
}

inline std::vector<int64_t> GetLocalPorts(const types::ColumnWrapperRecordBatch& rb,
const int local_port_idx,
const std::vector<size_t>& indices) {
std::vector<int64_t> ports;
for (size_t idx : indices) {
ports.push_back(rb[local_port_idx]->Get<types::Int64Value>(idx).val);
}
return ports;
}

inline std::vector<int64_t> GetRemotePorts(const types::ColumnWrapperRecordBatch& rb,
const std::vector<size_t>& indices) {
std::vector<int64_t> addrs;