Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor shared code in port forwarding into traits #249

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ use vopono_core::config::providers::{UiClient, VpnProvider};
use vopono_core::config::vpn::{verify_auth, Protocol};
use vopono_core::network::application_wrapper::ApplicationWrapper;
use vopono_core::network::firewall::Firewall;
use vopono_core::network::natpmpc::Natpmpc;
use vopono_core::network::netns::NetworkNamespace;
use vopono_core::network::network_interface::{get_active_interfaces, NetworkInterface};
use vopono_core::network::piapf::Piapf;
use vopono_core::network::port_forwarding::natpmpc::Natpmpc;
use vopono_core::network::port_forwarding::piapf::Piapf;
use vopono_core::network::port_forwarding::Forwarder;
use vopono_core::network::shadowsocks::uses_shadowsocks;
use vopono_core::network::sysctl::SysCtl;
use vopono_core::network::Forwarder;
use vopono_core::util::vopono_dir;
use vopono_core::util::{get_config_file_protocol, get_config_from_alias};
use vopono_core::util::{get_existing_namespaces, get_target_subnet};
Expand Down Expand Up @@ -154,7 +154,6 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
};

// Custom port forwarding (implementation to use for --custom-config)
// TODO: Allow fully custom handling separate callback script?
let custom_port_forwarding: Option<VpnProvider> = command
.custom_port_forwarding
.map(|x| x.to_variant())
Expand All @@ -165,7 +164,7 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
.ok()
});
if custom_port_forwarding.is_some() && custom_config.is_none() {
warn!("Custom port forwarding implementation is set, but not using custom provider config file. custom-port-forwarding setting will be ignored");
error!("Custom port forwarding implementation is set, but not using custom provider config file. custom-port-forwarding setting will be ignored");
}

// Create netns only
Expand Down Expand Up @@ -622,17 +621,17 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
Some(VpnProvider::ProtonVPN) => {
vopono_core::util::open_hosts(
&ns,
vec![vopono_core::network::natpmpc::PROTONVPN_GATEWAY],
vec![vopono_core::network::port_forwarding::natpmpc::PROTONVPN_GATEWAY],
firewall,
)?;
Some(Box::new(Natpmpc::new(&ns, callback.as_ref())?))
}
Some(p) => {
warn!("Port forwarding not supported for the selected provider: {} - ignoring --port-forwarding", p);
error!("Port forwarding not supported for the selected provider: {} - ignoring --port-forwarding", p);
None
}
None => {
warn!("--port-forwarding set but --custom-port-forwarding provider not provided for --custom-config usage. Ignoring --port-forwarding");
error!("--port-forwarding set but --custom-port-forwarding provider not provided for --custom-config usage. Ignoring --port-forwarding");
None
}
}
Expand Down
2 changes: 1 addition & 1 deletion vopono_core/src/network/application_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;

use super::{netns::NetworkNamespace, Forwarder};
use super::{netns::NetworkNamespace, port_forwarding::Forwarder};
use crate::util::get_all_running_process_names;
use log::warn;

Expand Down
7 changes: 1 addition & 6 deletions vopono_core/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@ pub mod application_wrapper;
pub mod dns_config;
pub mod firewall;
pub mod host_masquerade;
pub mod natpmpc;
pub mod netns;
pub mod network_interface;
pub mod openconnect;
pub mod openfortivpn;
pub mod openvpn;
pub mod piapf;
pub mod port_forwarding;
pub mod shadowsocks;
pub mod sysctl;
pub mod veth_pair;
pub mod warp;
pub mod wireguard;

pub trait Forwarder {
fn forwarded_port(&self) -> u16;
}
73 changes: 73 additions & 0 deletions vopono_core/src/network/port_forwarding/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::sync::mpsc::Receiver;

use super::netns::NetworkNamespace;

pub mod natpmpc;
pub mod piapf;

pub trait Forwarder {
fn forwarded_port(&self) -> u16;
}

/// ThreadParams must implement these methods
pub trait ThreadParameters {
fn get_callback_command(&self) -> Option<String>;
fn get_loop_delay(&self) -> u64;
fn get_netns_name(&self) -> String;
}

pub trait ThreadLoopForwarder: Forwarder {
/// Implementation defines parameter struct passed to loop on thread
type ThreadParams: ThreadParameters;

/// Implementation defines how to refresh port
fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16>;

/// Provided common implementation for thread loop
fn thread_loop(params: Self::ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(params.get_loop_delay()));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => {
log::debug!("Thread refreshed port: {p}");
Self::callback_command(&params, p);
}
}
}
}
}

fn callback_command(params: &Self::ThreadParams, port: u16) -> Option<anyhow::Result<String>> {
params.get_callback_command().map(|callback_command|
{
let refresh_response = NetworkNamespace::exec_with_output(
&params.get_netns_name(),
&[&callback_command, &port.to_string()],
)?;
if !refresh_response.status.success() {
log::error!(
"Port forwarding callback script was unsuccessful!: stdout: {:?}, stderr: {:?}, exit code: {}",
String::from_utf8(refresh_response.stdout),
String::from_utf8(refresh_response.stderr),
refresh_response.status
);
Err(anyhow::anyhow!("Port forwarding callback script failed"))
} else if let Ok(out) = String::from_utf8(refresh_response.stdout) {
println!("{}", out);
Ok(out)
} else {
Ok("Callback script succeeded but stdout was not valid UTF8".to_string())
}
}
)
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use anyhow::Context;
use regex::Regex;
use std::sync::mpsc::{self, Receiver};
use std::sync::mpsc;
use std::{
net::{IpAddr, Ipv4Addr},
sync::mpsc::Sender,
thread::JoinHandle,
};

use super::netns::NetworkNamespace;
use super::Forwarder;
use super::{Forwarder, ThreadLoopForwarder, ThreadParameters};
use crate::network::netns::NetworkNamespace;

// TODO: Move this to ProtonVPN provider
pub const PROTONVPN_GATEWAY: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 2, 0, 1));
Expand All @@ -20,11 +20,23 @@ pub struct Natpmpc {
send_channel: Sender<bool>,
}

struct ThreadParams {
pub struct ThreadParamsImpl {
pub netns_name: String,
pub callback: Option<String>,
}

impl ThreadParameters for ThreadParamsImpl {
fn get_callback_command(&self) -> Option<String> {
self.callback.clone()
}
fn get_loop_delay(&self) -> u64 {
45
}
fn get_netns_name(&self) -> String {
self.netns_name.clone()
}
}

impl Natpmpc {
pub fn new(ns: &NetworkNamespace, callback: Option<&String>) -> anyhow::Result<Self> {
let gateway_str = PROTONVPN_GATEWAY.to_string();
Expand All @@ -49,11 +61,13 @@ impl Natpmpc {
anyhow::bail!("natpmpc failed - likely that this server does not support port forwarding, please choose another server")
}

let params = ThreadParams {
let params = ThreadParamsImpl {
netns_name: ns.name.clone(),
callback: callback.cloned(),
};

let port = Self::refresh_port(&params)?;
Self::callback_command(&params, port);

let (send, recv) = mpsc::channel::<bool>();

Expand All @@ -66,9 +80,12 @@ impl Natpmpc {
send_channel: send,
})
}
}

// TODO: Refactor these two methods into Trait shared with piapf.rs
fn refresh_port(params: &ThreadParams) -> anyhow::Result<u16> {
impl ThreadLoopForwarder for Natpmpc {
type ThreadParams = ThreadParamsImpl;

fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16> {
let gateway_str = PROTONVPN_GATEWAY.to_string();
// TODO: Cache regex
let re = Regex::new(r"Mapped public port (?P<port>\d{1,5}) protocol").unwrap();
Expand Down Expand Up @@ -102,48 +119,8 @@ impl Natpmpc {
"natpmpc assigned UDP port: {udp_port} did not equal TCP port: {tcp_port}"
)
}

if let Some(cb) = &params.callback {
let refresh_response = NetworkNamespace::exec_with_output(
&params.netns_name,
&[cb, &udp_port.to_string()],
)?;
if !refresh_response.status.success() {
log::error!(
"Port forwarding callback script was unsuccessful!: stdout: {:?}, stderr: {:?}, exit code: {}",
String::from_utf8(refresh_response.stdout),
String::from_utf8(refresh_response.stderr),
refresh_response.status
);
} else if let Ok(out) = String::from_utf8(refresh_response.stdout) {
println!("{}", out);
}
}

Ok(udp_port)
}

// Spawn thread to repeat above every 45 seconds
fn thread_loop(params: ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(45));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => log::debug!("Thread refreshed port: {p}"),
}

// TODO: Communicate port change via channel?
}
}
}
}

impl Drop for Natpmpc {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use base64::prelude::*;
use regex::Regex;
use std::sync::mpsc::{self, Receiver};
use std::sync::mpsc::{self};
use std::{sync::mpsc::Sender, thread::JoinHandle};
use which::which;

use super::netns::NetworkNamespace;
use super::Forwarder;
use super::{Forwarder, ThreadLoopForwarder, ThreadParameters};
use crate::network::netns::NetworkNamespace;

use crate::config::providers::pia::PrivateInternetAccess;
use crate::config::providers::OpenVpnProvider;
Expand All @@ -18,7 +18,7 @@ pub struct Piapf {
send_channel: Sender<bool>,
}

struct ThreadParams {
pub struct ThreadParamsImpl {
pub port: u16,
pub netns_name: String,
pub signature: String,
Expand All @@ -29,6 +29,20 @@ struct ThreadParams {
pub callback: Option<String>,
}

impl ThreadParameters for ThreadParamsImpl {
fn get_callback_command(&self) -> Option<String> {
self.callback.clone()
}

fn get_loop_delay(&self) -> u64 {
60 * 15
}

fn get_netns_name(&self) -> String {
self.netns_name.clone()
}
}

impl Piapf {
pub fn new(
ns: &NetworkNamespace,
Expand Down Expand Up @@ -147,7 +161,7 @@ impl Piapf {
.as_u16()
.expect("getSignature response missing port");

let params = ThreadParams {
let params = ThreadParamsImpl {
netns_name: ns.name.clone(),
hostname: vpn_hostname,
gateway: vpn_gateway,
Expand All @@ -157,7 +171,8 @@ impl Piapf {
port,
callback: callback.cloned(),
};
Self::refresh_port(&params)?;
let port = Self::refresh_port(&params)?;
Self::callback_command(&params, port);
let (send, recv) = mpsc::channel::<bool>();
let handle = std::thread::spawn(move || Self::thread_loop(params, recv));

Expand All @@ -168,9 +183,12 @@ impl Piapf {
send_channel: send,
})
}
}

impl ThreadLoopForwarder for Piapf {
type ThreadParams = ThreadParamsImpl;

// TODO: Refactor methods below into Trait
fn refresh_port(params: &ThreadParams) -> anyhow::Result<u16> {
fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16> {
let bind_response = NetworkNamespace::exec_with_output(
&params.netns_name,
&[
Expand Down Expand Up @@ -222,28 +240,6 @@ impl Piapf {

Ok(params.port)
}

// Spawn thread to repeat above every 15 minutes
fn thread_loop(params: ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(60 * 15));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => log::debug!("Thread refreshed port: {p}"),
}

// TODO: Communicate port change via channel?
}
}
}
}

impl Drop for Piapf {
Expand Down
Loading