Skip to content

Commit

Permalink
WIP - Server callback + SSE server handler (Compiles, tests failing)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrit committed Feb 7, 2025
1 parent 600e4e7 commit b21f7aa
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 220 deletions.
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/mcp-transport-sse/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ server = ["dep:axum", "dep:tower-http"]
[dependencies]
mcp-core = { path = "../mcp-core" }
mcp-types = { path = "../mcp-types" }
async-stream = "0.3.6"
async-trait = { workspace = true }
axum = { version = "0.8.1", optional = true }
futures = "0.3.31"
Expand Down
11 changes: 1 addition & 10 deletions crates/mcp-transport-sse/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,7 @@ mod tests {
};

let mut transport = SSEClientTransport::new(params).unwrap();

// Create mock channel
let sent_messages = Arc::new(Mutex::new(Vec::new()));
let messages_to_receive = Arc::new(Mutex::new(vec![JSONRPCMessage::Notification(
JSONRPCNotification {
jsonrpc: "2.0".to_string(),
method: "test".to_string(),
params: None,
},
)]));
let sent_messages = Arc::new(Mutex::new(Vec::<JSONRPCMessage>::new()));

transport.started = true;

Expand Down
1 change: 0 additions & 1 deletion crates/mcp-transport-sse/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use mcp_core::transport::CallbackFnWithArg;
use mcp_types::JSONRPCMessage;

#[derive(Debug, thiserror::Error)]
Expand Down
81 changes: 47 additions & 34 deletions crates/mcp-transport-sse/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ impl SSEClientTransportParams {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SSEServerTransportParams {
pub listen_addr: String,
pub ws_url: String,
pub events_addr: String,
}

impl Default for SSEServerTransportParams {
fn default() -> Self {
Self {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
}
}
}
Expand All @@ -82,13 +82,16 @@ impl SSEServerTransportParams {
Err(e) => return Err(format!("Invalid listen address: {}", e)),
};

// Validate ws_url
if self.ws_url.is_empty() {
return Err("ws_url is required".to_string());
// Validate events_addr
if self.events_addr.is_empty() {
return Err("events_addr is required".to_string());
}

if !SSEClientTransportParams::is_valid_url(&self.ws_url) {
return Err(format!("ws_url is not a valid URL: {}", self.ws_url));
if !SSEClientTransportParams::is_valid_url(&self.events_addr) {
return Err(format!(
"events_addr is not a valid URL: {}",
self.events_addr
));
}

Ok(())
Expand All @@ -101,45 +104,38 @@ mod tests {

#[test]
fn test_is_valid_url() {
assert!(SSEClientTransportParams::is_valid_url(
"ws://localhost:8080"
));
assert!(SSEClientTransportParams::is_valid_url(
"http://localhost:8080"
));
assert!(SSEClientTransportParams::is_valid_url(
"https://localhost:8080"
));
assert!(SSEClientTransportParams::is_valid_url(
"wss://localhost:8080"
));
assert!(SSEClientTransportParams::is_valid_url("ws://google.com"));
assert!(SSEClientTransportParams::is_valid_url("http://google.com"));
assert!(SSEClientTransportParams::is_valid_url("https://google.com"));
assert!(SSEClientTransportParams::is_valid_url("wss://google.com"));
assert!(!SSEClientTransportParams::is_valid_url("ws://"));
assert!(!SSEClientTransportParams::is_valid_url("wss://google.com"));
assert!(!SSEClientTransportParams::is_valid_url("ws://0.0.0.0"));
assert!(!SSEClientTransportParams::is_valid_url("http://"));
assert!(!SSEClientTransportParams::is_valid_url("https://"));
assert!(!SSEClientTransportParams::is_valid_url("wss://"));
assert!(!SSEClientTransportParams::is_valid_url("wss://localhost"));
}

#[test]
fn test_server_transport_params_validation() {
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

let params = SSEServerTransportParams {
listen_addr: "".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_err());

let params = SSEServerTransportParams {
listen_addr: "invalid".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_err());
}
Expand All @@ -149,33 +145,33 @@ mod tests {
// Test valid ports
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:1".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:65535".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

// Test port 0 (valid for testing - system assigns random port)
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:0".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

// Test invalid ports
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:65536".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_err());
assert_eq!(
Expand All @@ -185,7 +181,7 @@ mod tests {

let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:99999".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_err());
assert_eq!(
Expand All @@ -196,30 +192,47 @@ mod tests {

#[test]
fn test_server_transport_params_ws_url_validation() {
// Test valid ws_url
// Test valid http urls
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
events_addr: "http://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "ws://localhost:8080".to_string(),
events_addr: "https://localhost:8080".to_string(),
};
assert!(params.validate().is_ok());

// Test empty ws_url
// Test empty events_addr
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "".to_string(),
events_addr: "".to_string(),
};
assert!(params.validate().is_err());
assert_eq!(params.validate().unwrap_err(), "ws_url is required");
assert_eq!(params.validate().unwrap_err(), "events_addr is required");

// Test invalid events_addr
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
events_addr: "invalid".to_string(),
};
assert!(params.validate().is_err());
assert_eq!(
params.validate().unwrap_err(),
"events_addr is not a valid URL: invalid"
);

// Test invalid ws_url
// Test websocket URLs are now invalid
let params = SSEServerTransportParams {
listen_addr: "127.0.0.1:8080".to_string(),
ws_url: "invalid".to_string(),
events_addr: "ws://localhost:8080".to_string(),
};
assert!(params.validate().is_err());
assert_eq!(
params.validate().unwrap_err(),
"ws_url is not a valid URL: invalid"
"events_addr is not a valid URL: ws://localhost:8080"
);
}
}
Loading

0 comments on commit b21f7aa

Please sign in to comment.