Skip to content

Commit

Permalink
Merge pull request #184: Check for claimed nameplates
Browse files Browse the repository at this point in the history
  • Loading branch information
piegamesde authored Mar 8, 2023
2 parents 32369ac + 3daaded commit 46eceb0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 8 deletions.
5 changes: 3 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ async fn parse_and_connect(
)?;
}
let (server_welcome, wormhole) =
magic_wormhole::Wormhole::connect_with_code(app_config, code).await?;
magic_wormhole::Wormhole::connect_with_code(app_config, code, false).await?;
print_welcome(term, &server_welcome)?;
(wormhole, server_welcome.code)
},
Expand Down Expand Up @@ -860,7 +860,8 @@ async fn send_many(
}

let (_server_welcome, wormhole) =
magic_wormhole::Wormhole::connect_with_code(transfer::APP_CONFIG, code.clone()).await?;
magic_wormhole::Wormhole::connect_with_code(transfer::APP_CONFIG, code.clone(), false)
.await?;
send_in_background(
relay_hints.clone(),
Arc::clone(&file_path),
Expand Down
12 changes: 12 additions & 0 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub enum WormholeError {
PakeFailed,
#[error("Cannot decrypt a received message")]
Crypto,
#[error("Nameplate is unclaimed: {}", _0)]
UnclaimedNameplate(Nameplate),
}

impl WormholeError {
Expand Down Expand Up @@ -165,6 +167,7 @@ impl Wormhole {
pub async fn connect_with_code(
config: AppConfig<impl serde::Serialize + Send + Sync + 'static>,
code: Code,
expect_claimed_nameplate: bool,
) -> Result<(WormholeWelcome, Self), WormholeError> {
let AppConfig {
id: appid,
Expand All @@ -174,6 +177,15 @@ impl Wormhole {
let (mut server, welcome) = RendezvousServer::connect(&appid, &rendezvous_url).await?;

let nameplate = code.nameplate();

if expect_claimed_nameplate {
let nameplate = code.nameplate();
let nameplates = server.list_nameplates().await?;
if !nameplates.contains(&nameplate) {
return Err(WormholeError::UnclaimedNameplate(nameplate));
}
}

let mailbox = server.claim_open(nameplate).await?;
log::debug!("Connected to mailbox {}", mailbox);

Expand Down
21 changes: 21 additions & 0 deletions src/core/rendezvous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ impl RendezvousError {

type MessageQueue = VecDeque<EncryptedMessage>;

#[derive(Clone, Debug, derive_more::Display)]
#[display(fmt = "{:?}", _0)]
struct NameplateList(Vec<Nameplate>);

#[cfg(not(target_family = "wasm"))]
struct WsConnection {
connection: async_tungstenite::WebSocketStream<async_tungstenite::async_std::ConnectStream>,
Expand Down Expand Up @@ -174,6 +178,9 @@ impl WsConnection {
Some(InboundMessage::Error { error, orig: _ }) => {
break Err(RendezvousError::Server(error.into()));
},
Some(InboundMessage::Nameplates { nameplates }) => {
break Ok(RendezvousReply::Nameplates(NameplateList(nameplates)))
},
Some(other) => {
break Err(RendezvousError::protocol(format!(
"Got unexpected message type from server '{}'",
Expand Down Expand Up @@ -273,6 +280,7 @@ enum RendezvousReply {
Released,
Claimed(Mailbox),
Closed,
Nameplates(NameplateList),
}

#[derive(Clone, Debug, derive_more::Display)]
Expand Down Expand Up @@ -528,6 +536,19 @@ impl RendezvousServer {
.is_some()
}

/**
* Gets the list of currently claimed nameplates.
* This can be called at any time.
*/
pub async fn list_nameplates(&mut self) -> Result<Vec<Nameplate>, RendezvousError> {
self.send_message(&OutboundMessage::List).await?;
let nameplate_reply = self.receive_reply().await?;
match nameplate_reply {
RendezvousReply::Nameplates(x) => Ok(x.0),
other => Err(RendezvousError::invalid_message("nameplates", other)),
}
}

pub async fn release_nameplate(&mut self) -> Result<(), RendezvousError> {
let nameplate = &mut self
.state
Expand Down
38 changes: 32 additions & 6 deletions src/core/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> {
let code = code_rx.await?;
log::info!("Got code over local: {}", &code);
let (welcome, wormhole) =
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?;
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true)
.await?;
if let Some(welcome) = &welcome.welcome {
log::info!("Got welcome: {}", welcome);
}
Expand Down Expand Up @@ -150,7 +151,8 @@ pub async fn test_4096_file_rust2rust() -> eyre::Result<()> {
let code = code_rx.await?;
log::info!("Got code over local: {}", &code);
let (welcome, wormhole) =
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?;
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true)
.await?;
if let Some(welcome) = &welcome.welcome {
log::info!("Got welcome: {}", welcome);
}
Expand Down Expand Up @@ -223,7 +225,8 @@ pub async fn test_empty_file_rust2rust() -> eyre::Result<()> {
let code = code_rx.await?;
log::info!("Got code over local: {}", &code);
let (welcome, wormhole) =
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code).await?;
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code, true)
.await?;
if let Some(welcome) = &welcome.welcome {
log::info!("Got welcome: {}", welcome);
}
Expand Down Expand Up @@ -302,6 +305,7 @@ pub async fn test_send_many() -> eyre::Result<()> {
let (_welcome, wormhole) = Wormhole::connect_with_code(
transfer::APP_CONFIG.id(TEST_APPID),
sender_code.clone(),
false,
)
.await?;
senders.push(async_std::task::spawn(async move {
Expand Down Expand Up @@ -329,7 +333,8 @@ pub async fn test_send_many() -> eyre::Result<()> {
for i in 0..5usize {
log::info!("Receiving file #{}", i);
let (_welcome, wormhole) =
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code.clone()).await?;
Wormhole::connect_with_code(transfer::APP_CONFIG.id(TEST_APPID), code.clone(), true)
.await?;
log::info!("Got key: {}", &wormhole.key);
let req = crate::transfer::request_file(
wormhole,
Expand Down Expand Up @@ -389,6 +394,7 @@ pub async fn test_wrong_code() -> eyre::Result<()> {
APP_CONFIG,
/* Making a wrong code here by appending bullshit */
Code::new(&nameplate, "foo-bar"),
true,
)
.await;

Expand All @@ -411,9 +417,9 @@ pub async fn test_crowded() -> eyre::Result<()> {
let (welcome, connector1) = Wormhole::connect_without_code(APP_CONFIG, 2).await?;
log::info!("This test's code is: {}", &welcome.code);

let connector2 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone());
let connector2 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone(), true);

let connector3 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone());
let connector3 = Wormhole::connect_with_code(APP_CONFIG, welcome.code.clone(), true);

match futures::try_join!(connector1, connector2, connector3).unwrap_err() {
magic_wormhole::WormholeError::ServerError(
Expand All @@ -427,6 +433,26 @@ pub async fn test_crowded() -> eyre::Result<()> {
Ok(())
}

#[async_std::test]
pub async fn test_connect_with_code_expecting_nameplate() -> eyre::Result<()> {
// the max nameplate number is 999, so this will not impact a real nameplate
let code = Code("1000-guitarist-revenge".to_owned());
let connector = Wormhole::connect_with_code(APP_CONFIG, code, true)
.await
.unwrap_err();
match connector {
magic_wormhole::WormholeError::UnclaimedNameplate(x) => {
assert_eq!(x, magic_wormhole::core::Nameplate("1000".to_owned()));
},
other => panic!(
"Got wrong error type {:?}. Expected `NameplateNotFound`",
other
),
}

Ok(())
}

#[test]
fn test_phase() {
let p = Phase::PAKE;
Expand Down

0 comments on commit 46eceb0

Please sign in to comment.