Skip to content

Commit

Permalink
agent: refactor Connectors to DiscoverConnectors
Browse files Browse the repository at this point in the history
Refactors the `Connectors` trait to be more high-level and specific to
Discovers.  Also refactors the `MockConnectors` that are used by integration
tests to simplify that, since it's not a more specific trait.
  • Loading branch information
psFried committed Nov 12, 2024
1 parent fee595c commit 4b00d17
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 175 deletions.
8 changes: 4 additions & 4 deletions crates/agent/src/controlplane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
DefaultRetryPolicy, DraftPublication, NoopFinalize, PublicationResult, Publisher,
UpdateInferredSchemas,
},
Connectors, DiscoverHandler,
DiscoverConnectors, DiscoverHandler,
};

macro_rules! unwrap_single {
Expand Down Expand Up @@ -167,15 +167,15 @@ fn set_of<T: Into<String>>(s: T) -> BTreeSet<String> {

/// Implementation of `ControlPlane` that connects directly to postgres.
#[derive(Clone)]
pub struct PGControlPlane<C: Connectors> {
pub struct PGControlPlane<C: DiscoverConnectors> {
pub pool: sqlx::PgPool,
pub system_user_id: Uuid,
pub publications_handler: Publisher,
pub id_generator: models::IdGenerator,
pub discovers_handler: DiscoverHandler<C>,
}

impl<C: Connectors> PGControlPlane<C> {
impl<C: DiscoverConnectors> PGControlPlane<C> {
pub fn new(
pool: sqlx::PgPool,
system_user_id: Uuid,
Expand Down Expand Up @@ -255,7 +255,7 @@ impl<C: Connectors> PGControlPlane<C> {
}

#[async_trait::async_trait]
impl<C: Connectors> ControlPlane for PGControlPlane<C> {
impl<C: DiscoverConnectors> ControlPlane for PGControlPlane<C> {
#[tracing::instrument(level = "debug", err, skip(self))]
async fn notify_dependents(&mut self, catalog_name: String) -> anyhow::Result<()> {
let now = self.current_time();
Expand Down
8 changes: 4 additions & 4 deletions crates/agent/src/discovers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{BTreeMap, HashSet};

use crate::proxy_connectors::Connectors;
use crate::proxy_connectors::DiscoverConnectors;

use anyhow::Context;
use models::split_image_tag;
Expand Down Expand Up @@ -158,13 +158,13 @@ pub struct DiscoverHandler<C> {
pub connectors: C,
}

impl<C: Connectors> DiscoverHandler<C> {
impl<C: DiscoverConnectors> DiscoverHandler<C> {
pub fn new(connectors: C) -> Self {
Self { connectors }
}
}

impl<C: Connectors> DiscoverHandler<C> {
impl<C: DiscoverConnectors> DiscoverHandler<C> {
#[tracing::instrument(skip_all, fields(
capture_name = %req.capture_name,
data_plane_name = %req.data_plane.data_plane_name,
Expand Down Expand Up @@ -234,7 +234,7 @@ impl<C: Connectors> DiscoverHandler<C> {

let result = self
.connectors
.unary_capture(request, logs_token, task, &data_plane)
.discover(request, logs_token, task, &data_plane)
.await;

let response = match result {
Expand Down
6 changes: 3 additions & 3 deletions crates/agent/src/discovers/handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Discover, DiscoverHandler};
use crate::{draft, proxy_connectors::Connectors, HandleResult, Handler, Id};
use crate::{draft, proxy_connectors::DiscoverConnectors, HandleResult, Handler, Id};
use agent_sql::discovers::Row;
use anyhow::Context;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -36,7 +36,7 @@ impl JobStatus {
}

#[async_trait::async_trait]
impl<C: Connectors> Handler for DiscoverHandler<C> {
impl<C: DiscoverConnectors> Handler for DiscoverHandler<C> {
async fn handle(
&mut self,
pg_pool: &sqlx::PgPool,
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<C: Connectors> Handler for DiscoverHandler<C> {
}
}

impl<C: Connectors> DiscoverHandler<C> {
impl<C: DiscoverConnectors> DiscoverHandler<C> {
#[tracing::instrument(err, skip_all, fields(id=?row.id, draft_id = ?row.draft_id, user_id = %row.user_id))]
async fn process(
&mut self,
Expand Down
24 changes: 12 additions & 12 deletions crates/agent/src/integration_tests/auto_discovers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async fn test_auto_discovers_add_new_bindings() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("marmots/capture", Ok(discovered));
tokio::time::sleep(AUTO_DISCOVER_WAIT).await;
harness.run_pending_controller("marmots/capture").await;

Expand Down Expand Up @@ -226,7 +226,7 @@ async fn test_auto_discovers_add_new_bindings() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("marmots/capture", Ok(discovered));

tokio::time::sleep(AUTO_DISCOVER_WAIT).await;
harness.run_pending_controller("marmots/capture").await;
Expand Down Expand Up @@ -390,7 +390,7 @@ async fn test_auto_discovers_no_evolution() {
draft_id,
r#"{"hee": "hawwww"}"#,
false,
Box::new(Ok(discovered.clone())),
Ok(discovered.clone()),
)
.await;
assert!(result.job_status.is_success());
Expand Down Expand Up @@ -418,7 +418,7 @@ async fn test_auto_discovers_no_evolution() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(new_discovered)));
.mock_discover("mules/capture", Ok(new_discovered));
harness.run_pending_controller("mules/capture").await;

let capture_state = harness.get_controller_state("mules/capture").await;
Expand Down Expand Up @@ -504,7 +504,7 @@ async fn test_auto_discovers_no_evolution() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("mules/capture", Ok(discovered));
harness.run_pending_controller("mules/capture").await;

let capture_state = harness.get_controller_state("mules/capture").await;
Expand Down Expand Up @@ -693,7 +693,7 @@ async fn test_auto_discovers_update_only() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("pikas/capture", Ok(discovered));

harness.run_pending_controller("pikas/capture").await;
let capture_state = harness.get_controller_state("pikas/capture").await;
Expand Down Expand Up @@ -742,10 +742,10 @@ async fn test_auto_discovers_update_only() {
assert!(last_success.publish_result.is_none());

// Now simulate a discover error, and expect to see the error status reported.
harness
.discover_handler
.connectors
.mock_discover(Box::new(Err("a simulated discover error".to_string())));
harness.discover_handler.connectors.mock_discover(
"pikas/capture",
Err("a simulated discover error".to_string()),
);
tokio::time::sleep(AUTO_DISCOVER_WAIT).await;
harness.run_pending_controller("pikas/capture").await;

Expand Down Expand Up @@ -791,7 +791,7 @@ async fn test_auto_discovers_update_only() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("pikas/capture", Ok(discovered));
harness.control_plane().fail_next_build(
"pikas/capture",
InjectBuildError::new(
Expand Down Expand Up @@ -860,7 +860,7 @@ async fn test_auto_discovers_update_only() {
harness
.discover_handler
.connectors
.mock_discover(Box::new(Ok(discovered)));
.mock_discover("pikas/capture", Ok(discovered));
tokio::time::sleep(AUTO_DISCOVER_WAIT).await;
harness.run_pending_controller("pikas/capture").await;

Expand Down
12 changes: 6 additions & 6 deletions crates/agent/src/integration_tests/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use sqlx::types::Uuid;
use tables::DraftRow;
use tempfile::tempdir;

use self::connectors::MockConnectors;
use self::connectors::MockDiscoverConnectors;

const FIXED_DATABASE_URL: &str = "postgresql://postgres:postgres@localhost:5432/postgres";

Expand Down Expand Up @@ -116,7 +116,7 @@ pub struct TestHarness {
#[allow(dead_code)] // only here so we don't drop it until the harness is dropped
pub builds_root: tempfile::TempDir,
pub controllers: ControllerHandler<TestControlPlane>,
pub discover_handler: DiscoverHandler<connectors::MockConnectors>,
pub discover_handler: DiscoverHandler<connectors::MockDiscoverConnectors>,
}

impl TestHarness {
Expand Down Expand Up @@ -150,7 +150,7 @@ impl TestHarness {
});

let id_gen = models::IdGenerator::new(1);
let mock_connectors = connectors::MockConnectors::default();
let mock_connectors = connectors::MockDiscoverConnectors::default();
let discover_handler = DiscoverHandler::new(mock_connectors);

let publisher = Publisher::new(
Expand Down Expand Up @@ -761,7 +761,7 @@ impl TestHarness {

self.discover_handler
.connectors
.mock_discover(mock_discover_resp);
.mock_discover(capture_name, mock_discover_resp);

let result = self
.discover_handler
Expand Down Expand Up @@ -1065,7 +1065,7 @@ impl FailBuild for InjectBuildError {
/// A wrapper around `PGControlPlane` that has a few basic capbilities for verifying
/// activation calls and simulating failures of activations and publications.
pub struct TestControlPlane {
inner: PGControlPlane<MockConnectors>,
inner: PGControlPlane<MockDiscoverConnectors>,
activations: Vec<Activation>,
fail_activations: BTreeSet<String>,
build_failures: InjectBuildFailures,
Expand Down Expand Up @@ -1098,7 +1098,7 @@ impl crate::publications::FinalizeBuild for InjectBuildFailures {
}

impl TestControlPlane {
fn new(inner: PGControlPlane<MockConnectors>) -> Self {
fn new(inner: PGControlPlane<MockDiscoverConnectors>) -> Self {
Self {
inner,
activations: Vec::new(),
Expand Down
124 changes: 33 additions & 91 deletions crates/agent/src/integration_tests/harness/connectors.rs
Original file line number Diff line number Diff line change
@@ -1,112 +1,54 @@
use crate::proxy_connectors::Connectors;
use crate::proxy_connectors::DiscoverConnectors;
use proto_flow::capture;
use std::fmt::Debug;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub trait MockCall<Req, Resp>: Send + Sync + 'static {
fn call(
&self,
req: Req,
logs_token: uuid::Uuid,
task: ops::ShardRef,
data_plane: &tables::DataPlane,
) -> anyhow::Result<Resp>;
}

impl<Req, Resp> MockCall<Req, Resp> for Result<Resp, String>
where
Resp: Clone + Send + Sync + 'static,
{
fn call(
&self,
_req: Req,
_logs_token: uuid::Uuid,
_task: ops::ShardRef,
_data_plane: &tables::DataPlane,
) -> anyhow::Result<Resp> {
self.clone().map_err(anyhow::Error::msg)
}
}

struct DefaultFail;
impl<Req, Resp> MockCall<Req, Resp> for DefaultFail
where
Req: Debug,
{
fn call(
&self,
req: Req,
_logs_token: uuid::Uuid,
_task: ops::ShardRef,
_data_plane: &tables::DataPlane,
) -> anyhow::Result<Resp> {
Err(anyhow::anyhow!("default mock failure for request: {req:?}"))
}
}

pub type MockDiscover =
Box<dyn MockCall<capture::request::Discover, capture::response::Discovered>>;
pub type MockDiscover = Result<capture::response::Discovered, String>;

#[derive(Clone)]
pub struct MockConnectors {
discover: Arc<Mutex<MockDiscover>>,
pub struct MockDiscoverConnectors {
mocks: Arc<Mutex<HashMap<models::Capture, MockDiscover>>>,
}

impl Default for MockConnectors {
impl Default for MockDiscoverConnectors {
fn default() -> Self {
MockConnectors {
discover: Arc::new(Mutex::new(Box::new(DefaultFail))),
MockDiscoverConnectors {
mocks: Arc::new(Mutex::new(HashMap::new())),
}
}
}

impl MockConnectors {
pub fn mock_discover(&mut self, respond: MockDiscover) {
let mut lock = self.discover.lock().unwrap();
*lock = respond;
impl MockDiscoverConnectors {
pub fn mock_discover(&mut self, capture_name: &str, respond: MockDiscover) {
let mut lock = self.mocks.lock().unwrap();
lock.insert(models::Capture::new(capture_name), respond);
}
}

/// Currently, `MockConnectors` only supports capture Discover RPCs.
/// Publications do not yet use this for validate RPCs, but the plan is to do
/// that at some point, so that we can more easily test the publication logic.
impl Connectors for MockConnectors {
async fn unary_capture<'a>(
impl DiscoverConnectors for MockDiscoverConnectors {
async fn discover<'a>(
&'a self,
mut req: capture::Request,
logs_token: uuid::Uuid,
task: ops::ShardRef,
data_plane: &'a tables::DataPlane,
) -> anyhow::Result<capture::Response> {
if let Some(discover) = req.discover.take() {
let locked = self.discover.lock().unwrap();
return locked
.call(discover, logs_token, task, data_plane)
.map(|resp| capture::Response {
discovered: Some(resp),
..Default::default()
});
}
Err(anyhow::anyhow!("unhandled capture request type: {req:?}"))
}

async fn unary_derive<'a>(
&'a self,
_req: proto_flow::derive::Request,
_logs_token: uuid::Uuid,
_task: ops::ShardRef,
_data_plane: &'a tables::DataPlane,
) -> anyhow::Result<proto_flow::derive::Response> {
unimplemented!("mock connectors do not yet handle unary_derive calls");
}

async fn unary_materialize<'a>(
&'a self,
_req: proto_flow::materialize::Request,
_logs_token: uuid::Uuid,
_task: ops::ShardRef,
task: ops::ShardRef,
_data_plane: &'a tables::DataPlane,
) -> anyhow::Result<proto_flow::materialize::Response> {
unimplemented!("mock connectors do not yet handle unary_materialize calls");
) -> anyhow::Result<capture::Response> {
let Some(discover) = req.discover.take() else {
anyhow::bail!("unexpected capture request type: {req:?}")
};

let locked = self.mocks.lock().unwrap();
let capture = models::Capture::new(&task.name);
let Some(mock) = locked.get(&capture) else {
anyhow::bail!("no mock for capture: {capture}");
};

tracing::debug!(req = ?discover, resp = ?mock, "responding with mock discovered response");
mock.clone()
.map_err(|err_str| anyhow::anyhow!("{err_str}"))
.map(|dr| capture::Response {
discovered: Some(dr),
..Default::default()
})
}
}
Loading

0 comments on commit 4b00d17

Please sign in to comment.