Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable support for MQTT stitcher in stirling #1918

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,6 @@ namespace protocols {

namespace mqtt {

// This is modeling a 4 bit field specifying the control packet type
enum class MqttControlPacketType : uint8_t {
CONNECT = 1,
CONNACK = 2,
PUBLISH = 3,
PUBACK = 4,
PUBREC = 5,
PUBREL = 6,
PUBCOMP = 7,
SUBSCRIBE = 8,
SUBACK = 9,
UNSUBSCRIBE = 10,
UNSUBACK = 11,
PINGREQ = 12,
PINGRESP = 13,
DISCONNECT = 14,
AUTH = 15
};

enum class PropertyCode : uint8_t {
PayloadFormatIndicator = 0x01,
MessageExpiryInterval = 0x02,
Expand Down Expand Up @@ -654,7 +635,8 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder,
}
}

ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result) {
ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result,
mqtt::StateWrapper* state) {
CTX_DCHECK(type == message_type_t::kRequest || type == message_type_t::kResponse);
if (buf->size() < 2) {
return ParseState::kNeedsMoreData;
Expand Down Expand Up @@ -724,6 +706,27 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul
return ParseState::kInvalid;
}

// Updating the state for PUBLISH based on whether it is duplicate
if (control_packet_type == MqttControlPacketType::PUBLISH) {
if (result->dup) {
if (type == message_type_t::kRequest) {
state->send[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] += 1;
} else {
state->recv[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] += 1;
}
} else {
if (type == message_type_t::kRequest) {
state->send[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] = 0;
} else {
state->recv[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] = 0;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to store a key in the state for PUBLISH'es that don't have any dups?

The ::operator[] method will insert into the map if the key doesn't exist. So I believe the += operator will work as expected without the non dup branch.

}

if (ParsePayload(result, &decoder, control_packet_type) == ParseState::kInvalid) {
return ParseState::kInvalid;
}
Expand All @@ -736,8 +739,8 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul

template <>
ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* result,
NoState* /*state*/) {
return mqtt::ParseFrame(type, buf, result);
mqtt::StateWrapper* state) {
return mqtt::ParseFrame(type, buf, result, state);
}

template <>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ namespace protocols {

template <>
ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* frame,
NoState* state);
mqtt::StateWrapper* state);

template <>
size_t FindFrameBoundary<mqtt::Message>(message_type_t type, std::string_view buf, size_t start_pos,
NoState* state);
mqtt::StateWrapper* state);

template <>
mqtt::packet_id_t GetStreamID(mqtt::Message* message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace stirling {
namespace protocols {
namespace mqtt {

// MatchKey layout, || control_packet_type (4 bits) | dup (1 bit) | qos (2 bits) | retain (1 bit) ||
typedef uint8_t MatchKey;
ddelnano marked this conversation as resolved.
Show resolved Hide resolved

constexpr MatchKey UnmatchedResp = 0xff;
Expand Down Expand Up @@ -64,12 +65,10 @@ std::map<MatchKey, MatchKey> MapRequestToResponse = {
{0xa0, 0xb0},
// PINGREQ to PINGRESP
{0xc0, 0xd0},
// AUTH to AUTH
{0xf0, 0xf0},
// DISCONNECT to Dummy response
{0xe0, UnmatchedResp},
// AUTH to Dummy response
{0xf0, UnmatchedResp}};

std::set<std::tuple<uint32_t, uint32_t>> UnansweredPublish;
{0xe0, UnmatchedResp}};

// Possible to have the server sending PUBLISH with same packet identifier as client PUBLISH before
// it sends PUBACK, causing server PUBLISH to be put into response deque instead of request deque.
Expand All @@ -82,7 +81,7 @@ MatchKey getMatchKey(mqtt::Message& frame) {

RecordsWithErrorCount<Record> StitchFrames(
absl::flat_hash_map<packet_id_t, std::deque<Message>>* req_frames,
absl::flat_hash_map<packet_id_t, std::deque<Message>>* resp_frames) {
absl::flat_hash_map<packet_id_t, std::deque<Message>>* resp_frames, mqtt::StateWrapper* state) {
std::vector<Record> entries;
int error_count = 0;

Expand All @@ -104,19 +103,24 @@ RecordsWithErrorCount<Record> StitchFrames(
// finding the closest appropriate response from response deque in terms of timestamp and type
// for each request in the request deque
for (mqtt::Message& req_frame : req_deque) {
// if the request is a PUBLISH (QOS 1 or QOS 2) with dup false, entry needs to be made in
// UnansweredPublish
if (req_frame.control_packet_type == 3 && req_frame.header_fields["qos"] != 0 &&
!req_frame.dup) {
UnansweredPublish.insert(std::tuple<uint32_t, uint32_t>(
req_frame.header_fields["packet_identifier"], req_frame.header_fields["qos"]));
}
// if the request is a duplicate PUBLISH, find out if the original PUBLISH has been matched
if (req_frame.control_packet_type == 3 && req_frame.header_fields["qos"] != 0 &&
req_frame.dup) {
if (UnansweredPublish.find(std::tuple<uint32_t, uint32_t>(
req_frame.header_fields["packet_identifier"], req_frame.header_fields["qos"])) ==
UnansweredPublish.end()) {
const MqttControlPacketType control_packet_type =
magic_enum::enum_cast<MqttControlPacketType>(req_frame.control_packet_type).value();
// If the frame is PUBLISH, and there are duplicates in the deque, then mark the frame as
// consumed and match the latest duplicate with its response (if the response exists in the
// response deque)
if (control_packet_type == MqttControlPacketType::PUBLISH) {
std::tuple<uint32_t, uint32_t> unique_publish_identifier = std::tuple<uint32_t, uint32_t>(
req_frame.header_fields["packet_identifier"], req_frame.header_fields["qos"]);
if (req_frame.type == message_type_t::kRequest &&
state->send[unique_publish_identifier] > 0) {
state->send[unique_publish_identifier] -= 1;
req_frame.consumed = true;
continue;
}

if (req_frame.type == message_type_t::kResponse &&
state->recv[unique_publish_identifier] > 0) {
state->recv[unique_publish_identifier] -= 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to my earlier comment about avoiding excess inserts into state->{recv,send}. Can we use ::find() instead of ::operator[]?

req_frame.consumed = true;
continue;
}
Expand Down Expand Up @@ -165,17 +169,6 @@ RecordsWithErrorCount<Record> StitchFrames(
}
mqtt::Message& resp_frame = *response_frame_iter;

// if the response is PUBACK/PUBREC, then remove the associated (packet_identifier, qos) tuple
// from the UnansweredPublish set
if (resp_frame.control_packet_type == 4) {
UnansweredPublish.erase(
std::tuple<uint64_t, uint64_t>(resp_frame.header_fields["packet_identifier"], 1));
}
if (resp_frame.control_packet_type == 5) {
UnansweredPublish.erase(
std::tuple<uint64_t, uint64_t>(resp_frame.header_fields["packet_identifier"], 2));
}

req_frame.consumed = true;
resp_frame.consumed = true;
entries.push_back({std::move(req_frame), std::move(resp_frame)});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,22 @@ namespace mqtt {
*
* @param req_frames: deque of all request frames.
* @param resp_frames: deque of all response frames.
* @param resp_frames: state holding send and recv maps, which are key-value pairs of (packet id,
* qos) and dup counter.
* @return A vector of entries to be appended to table store.
*/
RecordsWithErrorCount<Record> StitchFrames(
absl::flat_hash_map<packet_id_t, std::deque<Message>>* req_frames,
absl::flat_hash_map<packet_id_t, std::deque<Message>>* resp_frames);
absl::flat_hash_map<packet_id_t, std::deque<Message>>* resp_frames, mqtt::StateWrapper* state);

} // namespace mqtt

template <>
inline RecordsWithErrorCount<mqtt::Record> StitchFrames(
absl::flat_hash_map<mqtt::packet_id_t, std::deque<mqtt::Message>>* req_messages,
absl::flat_hash_map<mqtt::packet_id_t, std::deque<mqtt::Message>>* res_messages,
NoState* /* state */) {
return mqtt::StitchFrames(req_messages, res_messages);
mqtt::StateWrapper* state) {
return mqtt::StitchFrames(req_messages, res_messages, state);
}

} // namespace protocols
Expand Down
Loading
Loading