Skip to content

Commit

Permalink
[Fix CQL Sticher 4/4] Move StreamID assignment to ParseFramesLoop (#1732
Browse files Browse the repository at this point in the history
)

Summary: Populates a map of streamIDs to deque of frames in
`ParseFramesLoop` instead of `ParseFrames`. This should provide a small
efficiency boost, as we won't have to loop over the frames twice. This
PR relies on #1761 due to the way timestamps are updated using
`ParseResult`.

Related issues: #1375

Type of change: /kind cleanup

Test Plan: Updated parsing tests to use new interface

Signed-off-by: Benjamin Kilimnik <[email protected]>
  • Loading branch information
benkilimnik authored Nov 9, 2023
1 parent a2ff394 commit 4926254
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 325 deletions.
6 changes: 4 additions & 2 deletions src/stirling/source_connectors/socket_tracer/data_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void DataStream::ProcessBytesToFrames(message_type_t type, TStateType* state) {

bool keep_processing = has_new_events_ || attempt_sync || conn_closed();

protocols::ParseResult parse_result;
protocols::ParseResult<TKey> parse_result;
parse_result.state = ParseState::kNeedsMoreData;
parse_result.end_position = 0;

Expand Down Expand Up @@ -134,7 +134,9 @@ void DataStream::ProcessBytesToFrames(message_type_t type, TStateType* state) {
keep_processing = false;
}

stat_valid_frames_ += parse_result.frame_positions.size();
for (const auto& [stream, positions] : parse_result.frame_positions) {
stat_valid_frames_ += positions.size();
}
stat_invalid_frames_ += parse_result.invalid_frames;
stat_raw_data_gaps_ += keep_processing;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ inline bool operator==(const StartEndPos& lhs, const StartEndPos& rhs) {
}

// A ParseResult returns a vector of parsed frames, and also some position markers.
template <typename TKey>
struct ParseResult {
// Positions of frame start and end positions in the source buffer.
std::vector<StartEndPos> frame_positions;
absl::flat_hash_map<TKey, std::vector<StartEndPos>> frame_positions;
// Position of where parsing ended consuming the source buffer.
// This is total bytes successfully consumed.
size_t end_position;
Expand Down Expand Up @@ -98,9 +99,9 @@ struct ParseResult {
* @return ParseResult with locations where parseable frames were found in the source buffer.
*/
template <typename TKey, typename TFrameType, typename TStateType = NoState>
ParseResult ParseFrames(message_type_t type, DataStreamBuffer* data_stream_buffer,
absl::flat_hash_map<TKey, std::deque<TFrameType>>* frames,
bool resync = false, TStateType* state = nullptr) {
ParseResult<TKey> ParseFrames(message_type_t type, DataStreamBuffer* data_stream_buffer,
absl::flat_hash_map<TKey, std::deque<TFrameType>>* frames,
bool resync = false, TStateType* state = nullptr) {
std::string_view buf = data_stream_buffer->Head();

size_t start_pos = 0;
Expand All @@ -121,32 +122,43 @@ ParseResult ParseFrames(message_type_t type, DataStreamBuffer* data_stream_buffe
buf.remove_prefix(start_pos);
}

// Parse and append new frames to the frames vector.
std::deque<TFrameType> new_frames = std::deque<TFrameType>();
ParseResult result = ParseFramesLoop(type, buf, &new_frames, state);
// Maintain a map of previous sizes.
absl::flat_hash_map<TKey, size_t> prev_sizes;
for (const auto& [stream_id, deque] : *frames) {
prev_sizes[stream_id] = deque.size();
}

VLOG(1) << absl::Substitute("Parsed $0 new frames", new_frames.size());
// Parse and append new frames to the map of stream ID to deque of frames
ParseResult<TKey> result = ParseFramesLoop(type, buf, frames, state);

// Match timestamps with the parsed frames.
for (size_t i = 0; i < result.frame_positions.size(); ++i) {
auto& f = result.frame_positions[i];
f.start += start_pos;
f.end += start_pos;

auto& msg = new_frames[i];
StatusOr<uint64_t> timestamp_ns_status =
data_stream_buffer->GetTimestamp(data_stream_buffer->position() + f.end);
LOG_IF(ERROR, !timestamp_ns_status.ok()) << timestamp_ns_status.ToString();
msg.timestamp_ns = timestamp_ns_status.ValueOr(0);
// Compute the number of newly parsed frames for each stream
size_t total_new_frames = 0;
for (const auto& [stream_id, positions] : result.frame_positions) {
total_new_frames += positions.size();
if (prev_sizes.find(stream_id) != prev_sizes.end()) {
total_new_frames -= prev_sizes[stream_id];
}
}
result.end_position += start_pos;
VLOG(1) << absl::Substitute("Parsed $0 new frames", total_new_frames);

// Parse frames into map
for (auto& frame : new_frames) {
// GetStreamID returns 0 by default if not implemented in protocol.
TKey key = GetStreamID<TKey, TFrameType>(&frame);
(*frames)[key].push_back(std::move(frame));
// Match timestamps with the parsed frames.
for (auto& [stream_id, positions] : result.frame_positions) {
size_t offset = prev_sizes[stream_id]; // Retrieve the initial offset for this stream_id

for (auto& f : positions) {
f.start += start_pos;
f.end += start_pos;

// Retrieve the message using the current offset
auto& msg = (*frames)[stream_id][offset];
offset++;
StatusOr<uint64_t> timestamp_ns_status =
data_stream_buffer->GetTimestamp(data_stream_buffer->position() + f.end);
LOG_IF(ERROR, !timestamp_ns_status.ok()) << timestamp_ns_status.ToString();
msg.timestamp_ns = timestamp_ns_status.ValueOr(0);
}
}
result.end_position += start_pos;
return result;
}

Expand All @@ -164,10 +176,11 @@ ParseResult ParseFrames(message_type_t type, DataStreamBuffer* data_stream_buffe
* @return ParseResult with locations where parseable frames were found in the source buffer.
*/
// TODO(oazizi): Convert tests to use ParseFrames() instead of ParseFramesLoop().
template <typename TFrameType, typename TStateType = NoState>
ParseResult ParseFramesLoop(message_type_t type, std::string_view buf,
std::deque<TFrameType>* frames, TStateType* state = nullptr) {
std::vector<StartEndPos> frame_positions;
template <typename TKey, typename TFrameType, typename TStateType = NoState>
ParseResult<TKey> ParseFramesLoop(message_type_t type, std::string_view buf,
absl::flat_hash_map<TKey, std::deque<TFrameType>>* frames,
TStateType* state = nullptr) {
absl::flat_hash_map<TKey, std::vector<StartEndPos>> frame_positions;
const size_t buf_size = buf.size();
ParseState s = ParseState::kSuccess;
size_t bytes_processed = 0;
Expand Down Expand Up @@ -225,12 +238,15 @@ ParseResult ParseFramesLoop(message_type_t type, std::string_view buf,
size_t end_position = bytes_processed - 1;

if (push) {
frame_positions.push_back({start_position, end_position});
// GetStreamID returns 0 by default if not implemented in protocol.
TKey key = GetStreamID<TKey, TFrameType>(&frame);
frame_positions[key].push_back({start_position, end_position});
(*frames)[key].push_back(std::move(frame));
frame_bytes += (end_position - start_position) + 1;
frames->push_back(std::move(frame));
}
}
return ParseResult{std::move(frame_positions), bytes_processed, s, invalid_count, frame_bytes};
return ParseResult<TKey>{std::move(frame_positions), bytes_processed, s, invalid_count,
frame_bytes};
}

} // namespace protocols
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ TEST_F(EventParserTest, BasicProtocolParsing) {
std::vector<SocketDataEvent> events = CreateEvents(event_messages);

AddEvents(events);
ParseResult res = ParseFrames(message_type_t::kRequest, &data_buffer_, &word_frames);
ParseResult<stream_id_t> res = ParseFrames(message_type_t::kRequest, &data_buffer_, &word_frames);

EXPECT_EQ(ParseState::kSuccess, res.state);
EXPECT_THAT(res.frame_positions,
stream_id_t stream_id = 0;
EXPECT_THAT(res.frame_positions[stream_id],
ElementsAre(StartEndPos{0, 7}, StartEndPos{8, 14}, StartEndPos{15, 22},
StartEndPos{23, 29}, StartEndPos{30, 35}, StartEndPos{36, 43}));
EXPECT_EQ(res.end_position, 44);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <absl/container/flat_hash_map.h>
#include "src/stirling/source_connectors/socket_tracer/protocols/common/test_utils.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/cql/parse.h"

namespace px {
Expand Down Expand Up @@ -63,29 +65,33 @@ class CQLParserTest : public ::testing::Test {};
TEST_F(CQLParserTest, Basic) {
auto frame_view = CreateStringView<char>(CharArrayStringView<uint8_t>(kQueryFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kSuccess);
ASSERT_EQ(frames.size(), 1);
EXPECT_EQ(frames[0].hdr.version & 0x80, 0);
EXPECT_EQ(frames[0].hdr.version & 0x7f, 4);
EXPECT_EQ(frames[0].hdr.flags, 0);
EXPECT_EQ(frames[0].hdr.stream, 6);
EXPECT_EQ(frames[0].hdr.opcode, Opcode::kQuery);
EXPECT_EQ(frames[0].hdr.length, 60);
EXPECT_THAT(frames[0].msg, testing::HasSubstr("SELECT * FROM system.schema_keyspaces ;"));
ASSERT_EQ(TotalDequeSize(frames), 1);
std::deque<Frame> expected_stream = frames[6];
EXPECT_EQ(expected_stream[0].hdr.version & 0x80, 0);
EXPECT_EQ(expected_stream[0].hdr.version & 0x7f, 4);
EXPECT_EQ(expected_stream[0].hdr.flags, 0);
EXPECT_EQ(expected_stream[0].hdr.stream, 6);
EXPECT_EQ(expected_stream[0].hdr.opcode, Opcode::kQuery);
EXPECT_EQ(expected_stream[0].hdr.length, 60);
EXPECT_THAT(expected_stream[0].msg,
testing::HasSubstr("SELECT * FROM system.schema_keyspaces ;"));
}

TEST_F(CQLParserTest, NeedsMoreData) {
std::string_view frame_view = CreateStringView<char>(CharArrayStringView<uint8_t>(kQueryFrame));
frame_view.remove_suffix(10);

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kNeedsMoreData);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

TEST_F(CQLParserTest, BadOpcode) {
Expand All @@ -95,11 +101,12 @@ TEST_F(CQLParserTest, BadOpcode) {
std::string_view frame_view =
CreateStringView<char>(CharArrayStringView<uint8_t>(kBadOpcodeFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kInvalid);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

TEST_F(CQLParserTest, LengthTooLarge) {
Expand All @@ -110,11 +117,12 @@ TEST_F(CQLParserTest, LengthTooLarge) {
std::string_view frame_view =
CreateStringView<char>(CharArrayStringView<uint8_t>(kBadLengthFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kInvalid);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

TEST_F(CQLParserTest, LengthNegative) {
Expand All @@ -125,11 +133,12 @@ TEST_F(CQLParserTest, LengthNegative) {
std::string_view frame_view =
CreateStringView<char>(CharArrayStringView<uint8_t>(kBadLengthFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kInvalid);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

TEST_F(CQLParserTest, VersionTooOld) {
Expand All @@ -140,11 +149,12 @@ TEST_F(CQLParserTest, VersionTooOld) {
std::string_view frame_view =
CreateStringView<char>(CharArrayStringView<uint8_t>(kBadLengthFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kInvalid);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

TEST_F(CQLParserTest, VersionTooNew) {
Expand All @@ -155,11 +165,12 @@ TEST_F(CQLParserTest, VersionTooNew) {
std::string_view frame_view =
CreateStringView<char>(CharArrayStringView<uint8_t>(kBadLengthFrame));

std::deque<Frame> frames;
ParseResult parse_result = ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);
absl::flat_hash_map<stream_id_t, std::deque<Frame>> frames;
ParseResult<stream_id_t> parse_result =
ParseFramesLoop(message_type_t::kRequest, frame_view, &frames);

ASSERT_EQ(parse_result.state, ParseState::kInvalid);
ASSERT_EQ(frames.size(), 0);
ASSERT_EQ(TotalDequeSize(frames), 0);
}

} // namespace cass
Expand Down
Loading

0 comments on commit 4926254

Please sign in to comment.