Skip to content

Commit

Permalink
Move MoqtSessionPeer into its own file.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699983051
  • Loading branch information
vasilvv authored and copybara-github committed Nov 25, 2024
1 parent 464d61c commit 88f533f
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 142 deletions.
1 change: 1 addition & 0 deletions build/source_list.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,7 @@ moqt_hdrs = [
"quic/moqt/moqt_session.h",
"quic/moqt/moqt_subscribe_windows.h",
"quic/moqt/moqt_track.h",
"quic/moqt/test_tools/moqt_session_peer.h",
"quic/moqt/test_tools/moqt_simulator_harness.h",
"quic/moqt/test_tools/moqt_test_message.h",
"quic/moqt/tools/chat_client.h",
Expand Down
1 change: 1 addition & 0 deletions build/source_list.gni
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,7 @@ moqt_hdrs = [
"src/quiche/quic/moqt/moqt_session.h",
"src/quiche/quic/moqt/moqt_subscribe_windows.h",
"src/quiche/quic/moqt/moqt_track.h",
"src/quiche/quic/moqt/test_tools/moqt_session_peer.h",
"src/quiche/quic/moqt/test_tools/moqt_simulator_harness.h",
"src/quiche/quic/moqt/test_tools/moqt_test_message.h",
"src/quiche/quic/moqt/tools/chat_client.h",
Expand Down
1 change: 1 addition & 0 deletions build/source_list.json
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,7 @@
"quiche/quic/moqt/moqt_session.h",
"quiche/quic/moqt/moqt_subscribe_windows.h",
"quiche/quic/moqt/moqt_track.h",
"quiche/quic/moqt/test_tools/moqt_session_peer.h",
"quiche/quic/moqt/test_tools/moqt_simulator_harness.h",
"quiche/quic/moqt/test_tools/moqt_test_message.h",
"quiche/quic/moqt/tools/chat_client.h",
Expand Down
144 changes: 2 additions & 142 deletions quiche/quic/moqt/moqt_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <utility>
#include <vector>


#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
Expand All @@ -26,6 +25,7 @@
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/test_tools/moqt_session_peer.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/quic/test_tools/quic_test_utils.h"
Expand All @@ -46,7 +46,6 @@ using ::testing::Invoke;
using ::testing::Return;
using ::testing::StrictMock;

constexpr webtransport::StreamId kControlStreamId = 4;
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
constexpr webtransport::StreamId kOutgoingUniStreamId = 14;

Expand Down Expand Up @@ -76,145 +75,6 @@ static std::shared_ptr<MockTrackPublisher> SetupPublisher(

} // namespace

class MockFetchTask : public MoqtFetchTask {
public:
MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
(PublishedObject & output), (override));
MOCK_METHOD(absl::Status, GetStatus, (), (override));
MOCK_METHOD(FullSequence, GetLargestId, (), (const, override));

void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
callback_ = std::move(callback);
}

ObjectsAvailableCallback callback_;
};

class MoqtSessionPeer {
public:
static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream(
MoqtSession* session, webtransport::test::MockStream* stream) {
auto new_stream =
std::make_unique<MoqtSession::ControlStream>(session, stream);
session->control_stream_ = kControlStreamId;
ON_CALL(*stream, visitor()).WillByDefault(Return(new_stream.get()));
webtransport::test::MockSession* mock_session =
static_cast<webtransport::test::MockSession*>(session->session());
EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
.Times(AnyNumber())
.WillRepeatedly(Return(stream));
return new_stream;
}

static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
MoqtSession* session, webtransport::Stream* stream) {
auto new_stream =
std::make_unique<MoqtSession::IncomingDataStream>(session, stream);
return new_stream;
}

// In the test OnSessionReady, the session creates a stream and then passes
// its unique_ptr to the mock webtransport stream. This function casts
// that unique_ptr into a MoqtSession::Stream*, which is a private class of
// MoqtSession, and then casts again into MoqtParserVisitor so that the test
// can inject packets into that stream.
// This function is useful for any test that wants to inject packets on a
// stream created by the MoqtSession.
static MoqtControlParserVisitor*
FetchParserVisitorFromWebtransportStreamVisitor(
MoqtSession* session, webtransport::StreamVisitor* visitor) {
return static_cast<MoqtSession::ControlStream*>(visitor);
}

static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
RemoteTrack::Visitor* visitor,
uint64_t track_alias) {
session->remote_tracks_.try_emplace(track_alias, name, track_alias,
visitor);
session->remote_track_aliases_.try_emplace(name, track_alias);
}

static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
MoqtSubscribe& subscribe,
RemoteTrack::Visitor* visitor) {
session->active_subscribes_[subscribe_id] = {subscribe, visitor};
}

static MoqtObjectListener* AddSubscription(
MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher,
uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group,
uint64_t start_object) {
MoqtSubscribe subscribe;
subscribe.full_track_name = publisher->GetTrackName();
subscribe.track_alias = track_alias;
subscribe.subscribe_id = subscribe_id;
subscribe.start_group = start_group;
subscribe.start_object = start_object;
subscribe.subscriber_priority = 0x80;
session->published_subscriptions_.emplace(
subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>(
session, std::move(publisher), subscribe,
/*monitoring_interface=*/nullptr));
return session->published_subscriptions_[subscribe_id].get();
}

static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) {
session->published_subscriptions_.erase(subscribe_id);
}

static void UpdateSubscriberPriority(MoqtSession* session,
uint64_t subscribe_id,
MoqtPriority priority) {
session->published_subscriptions_[subscribe_id]->set_subscriber_priority(
priority);
}

static void set_peer_role(MoqtSession* session, MoqtRole role) {
session->peer_role_ = role;
}

static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
return session->remote_tracks_.find(track_alias)->second;
}

static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
session->next_subscribe_id_ = id;
}

static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
session->peer_max_subscribe_id_ = id;
}

static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) {
auto fetch_task = std::make_unique<MockFetchTask>();
MockFetchTask* return_ptr = fetch_task.get();
auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>(
fetch_id, session, std::move(fetch_task));
session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch));
// Add the fetch to the pending stream queue.
session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0);
return return_ptr;
}

static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session,
uint64_t fetch_id) {
auto it = session->incoming_fetches_.find(fetch_id);
if (it == session->incoming_fetches_.end()) {
return nullptr;
}
return it->second.get();
}

static void ValidateSubscribeId(MoqtSession* session, uint64_t id) {
session->ValidateSubscribeId(id);
}

static FullSequence LargestSentForSubscription(MoqtSession* session,
uint64_t subscribe_id) {
return *session->published_subscriptions_[subscribe_id]->largest_sent();
}
};

class MoqtSessionTest : public quic::test::QuicTest {
public:
MoqtSessionTest()
Expand Down Expand Up @@ -2191,7 +2051,7 @@ TEST_F(MoqtSessionTest, FetchReturnsOkImmediateOpen) {
MoqtDataStreamType::kStreamHeaderFetch));
return absl::OkStatus();
});
fetch_task->callback_();
fetch_task->objects_available_callback()();
EXPECT_TRUE(correct_message);
}

Expand Down
156 changes: 156 additions & 0 deletions quiche/quic/moqt/test_tools/moqt_session_peer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2023 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_session.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/web_transport/test_tools/mock_web_transport.h"
#include "quiche/web_transport/web_transport.h"

namespace moqt::test {

class MoqtSessionPeer {
public:
static constexpr webtransport::StreamId kControlStreamId = 4;

static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream(
MoqtSession* session, webtransport::test::MockStream* stream) {
auto new_stream =
std::make_unique<MoqtSession::ControlStream>(session, stream);
session->control_stream_ = kControlStreamId;
ON_CALL(*stream, visitor())
.WillByDefault(::testing::Return(new_stream.get()));
webtransport::test::MockSession* mock_session =
static_cast<webtransport::test::MockSession*>(session->session());
EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
.Times(::testing::AnyNumber())
.WillRepeatedly(::testing::Return(stream));
return new_stream;
}

static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
MoqtSession* session, webtransport::Stream* stream) {
auto new_stream =
std::make_unique<MoqtSession::IncomingDataStream>(session, stream);
return new_stream;
}

// In the test OnSessionReady, the session creates a stream and then passes
// its unique_ptr to the mock webtransport stream. This function casts
// that unique_ptr into a MoqtSession::Stream*, which is a private class of
// MoqtSession, and then casts again into MoqtParserVisitor so that the test
// can inject packets into that stream.
// This function is useful for any test that wants to inject packets on a
// stream created by the MoqtSession.
static MoqtControlParserVisitor*
FetchParserVisitorFromWebtransportStreamVisitor(
MoqtSession* session, webtransport::StreamVisitor* visitor) {
return static_cast<MoqtSession::ControlStream*>(visitor);
}

static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
RemoteTrack::Visitor* visitor,
uint64_t track_alias) {
session->remote_tracks_.try_emplace(track_alias, name, track_alias,
visitor);
session->remote_track_aliases_.try_emplace(name, track_alias);
}

static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
MoqtSubscribe& subscribe,
RemoteTrack::Visitor* visitor) {
session->active_subscribes_[subscribe_id] = {subscribe, visitor};
}

static MoqtObjectListener* AddSubscription(
MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher,
uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group,
uint64_t start_object) {
MoqtSubscribe subscribe;
subscribe.full_track_name = publisher->GetTrackName();
subscribe.track_alias = track_alias;
subscribe.subscribe_id = subscribe_id;
subscribe.start_group = start_group;
subscribe.start_object = start_object;
subscribe.subscriber_priority = 0x80;
session->published_subscriptions_.emplace(
subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>(
session, std::move(publisher), subscribe,
/*monitoring_interface=*/nullptr));
return session->published_subscriptions_[subscribe_id].get();
}

static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) {
session->published_subscriptions_.erase(subscribe_id);
}

static void UpdateSubscriberPriority(MoqtSession* session,
uint64_t subscribe_id,
MoqtPriority priority) {
session->published_subscriptions_[subscribe_id]->set_subscriber_priority(
priority);
}

static void set_peer_role(MoqtSession* session, MoqtRole role) {
session->peer_role_ = role;
}

static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
return session->remote_tracks_.find(track_alias)->second;
}

static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
session->next_subscribe_id_ = id;
}

static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
session->peer_max_subscribe_id_ = id;
}

static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) {
auto fetch_task = std::make_unique<MockFetchTask>();
MockFetchTask* return_ptr = fetch_task.get();
auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>(
fetch_id, session, std::move(fetch_task));
session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch));
// Add the fetch to the pending stream queue.
session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0);
return return_ptr;
}

static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session,
uint64_t fetch_id) {
auto it = session->incoming_fetches_.find(fetch_id);
if (it == session->incoming_fetches_.end()) {
return nullptr;
}
return it->second.get();
}

static void ValidateSubscribeId(MoqtSession* session, uint64_t id) {
session->ValidateSubscribeId(id);
}

static FullSequence LargestSentForSubscription(MoqtSession* session,
uint64_t subscribe_id) {
return *session->published_subscriptions_[subscribe_id]->largest_sent();
}
};

} // namespace moqt::test

#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
19 changes: 19 additions & 0 deletions quiche/quic/moqt/tools/moqt_mock_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_time.h"
Expand Down Expand Up @@ -100,6 +101,24 @@ class MockPublishingMonitorInterface : public MoqtPublishingMonitorInterface {
(override));
};

class MockFetchTask : public MoqtFetchTask {
public:
MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
(PublishedObject & output), (override));
MOCK_METHOD(absl::Status, GetStatus, (), (override));
MOCK_METHOD(FullSequence, GetLargestId, (), (const, override));

void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
objects_available_callback_ = std::move(callback);
}
ObjectsAvailableCallback& objects_available_callback() {
return objects_available_callback_;
};

private:
ObjectsAvailableCallback objects_available_callback_;
};

} // namespace moqt::test

#endif // QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_

0 comments on commit 88f533f

Please sign in to comment.