Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
contrun committed Jan 10, 2025
1 parent ab91390 commit fdf4a12
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
23 changes: 21 additions & 2 deletions src/fiber/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,22 @@ pub enum ChannelCommand {
Update(UpdateCommand, RpcReplyPort<Result<(), String>>),
ForwardTlcResult(ForwardTlcResult),
#[cfg(test)]
ReloadState(),
ReloadState(ReloadParams),
}

#[cfg(test)]
#[derive(Debug)]
pub struct ReloadParams {
pub notify_changes: bool,
}

#[cfg(test)]
impl Default for ReloadParams {
fn default() -> Self {
Self {
notify_changes: true,
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -1827,11 +1842,15 @@ where
Ok(())
}
#[cfg(test)]
ChannelCommand::ReloadState() => {
ChannelCommand::ReloadState(reload_params) => {
*state = self
.store
.get_channel_actor_state(&state.get_id())
.expect("load channel state failed");
let ReloadParams { notify_changes } = reload_params;
if notify_changes {
state.notify_owned_channel_updated(&self.network).await;
}
Ok(())
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/fiber/tests/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5002,8 +5002,8 @@ async fn test_send_payment_with_disable_channel() {
// sleep for a while
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

// begin to set channel disable
node_2.disable_channel(channels[1]).await;
// begin to set channel disable, but do not notify the network
node_2.disable_channel_stealthy(channels[1]).await;
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

let message = |rpc_reply| -> NetworkActorMessage {
Expand Down
30 changes: 25 additions & 5 deletions src/fiber/tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::fiber::channel::ChannelActorState;
use crate::fiber::channel::ChannelActorStateStore;
use crate::fiber::channel::ChannelCommand;
use crate::fiber::channel::ChannelCommandWithId;
use crate::fiber::channel::ReloadParams;
use crate::fiber::graph::NetworkGraphStateStore;
use crate::fiber::graph::PaymentSession;
use crate::fiber::graph::PaymentSessionStatus;
Expand Down Expand Up @@ -675,14 +676,18 @@ impl NetworkNode {
res
}

pub async fn update_channel_actor_state(&mut self, state: ChannelActorState) {
pub async fn update_channel_actor_state(
&mut self,
state: ChannelActorState,
reload_params: Option<ReloadParams>,
) {
let channel_id = state.id.clone();
self.store.insert_channel_actor_state(state);
self.network_actor
.send_message(NetworkActorMessage::Command(
NetworkActorCommand::ControlFiberChannel(ChannelCommandWithId {
channel_id,
command: ChannelCommand::ReloadState(),
command: ChannelCommand::ReloadState(reload_params.unwrap_or_default()),
}),
))
.expect("network actor is live");
Expand All @@ -696,7 +701,8 @@ impl NetworkNode {
) {
let mut channel_actor_state = self.get_channel_actor_state(channel_id);
channel_actor_state.to_local_amount = new_to_local_amount;
self.update_channel_actor_state(channel_actor_state).await;
self.update_channel_actor_state(channel_actor_state, None)
.await;
}

pub async fn update_channel_remote_balance(
Expand All @@ -706,13 +712,27 @@ impl NetworkNode {
) {
let mut channel_actor_state = self.get_channel_actor_state(channel_id);
channel_actor_state.to_remote_amount = new_to_remote_amount;
self.update_channel_actor_state(channel_actor_state).await;
self.update_channel_actor_state(channel_actor_state, None)
.await;
}

pub async fn disable_channel(&mut self, channel_id: Hash256) {
let mut channel_actor_state = self.get_channel_actor_state(channel_id);
channel_actor_state.local_tlc_info.enabled = false;
self.update_channel_actor_state(channel_actor_state).await;
self.update_channel_actor_state(channel_actor_state, None)
.await;
}

pub async fn disable_channel_stealthy(&mut self, channel_id: Hash256) {
let mut channel_actor_state = self.get_channel_actor_state(channel_id);
channel_actor_state.local_tlc_info.enabled = false;
self.update_channel_actor_state(
channel_actor_state,
Some(ReloadParams {
notify_changes: false,
}),
)
.await;
}

pub fn get_payment_session(&self, payment_hash: Hash256) -> Option<PaymentSession> {
Expand Down

0 comments on commit fdf4a12

Please sign in to comment.