Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wait for routes #7659

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class TalpidVpnServiceFallbackDnsTest {
every { talpidVpnService.prepareVpnSafe() } returns Prepared.right()
builderMockk = mockk<VpnService.Builder>()

every { talpidVpnService getProperty "connectivityListener" } returns
mockk<ConnectivityListener>(relaxed = true)

mockkConstructor(VpnService.Builder::class)
every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk
every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@ package net.mullvad.talpid

import android.net.ConnectivityManager
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import co.touchlab.kermit.Logger
import java.net.InetAddress
import kotlin.collections.ArrayList
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.runBlocking
import net.mullvad.talpid.model.NetworkState
import net.mullvad.talpid.util.NetworkEvent
import net.mullvad.talpid.util.RawNetworkState
Expand All @@ -30,6 +32,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
get() = _isConnected.value

private lateinit var _currentNetworkState: StateFlow<NetworkState?>
private val resetNetworkState: Channel<Unit> = Channel()

// Used by JNI
val currentDefaultNetworkState: NetworkState?
Expand All @@ -44,51 +47,70 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
// the default network may fail if the network on Android 11
// https://issuetracker.google.com/issues/175055271?pli=1
_currentNetworkState =
connectivityManager
.defaultRawNetworkStateFlow()
merge(
connectivityManager.defaultRawNetworkStateFlow(),
resetNetworkState.receiveAsFlow().map { null },
)
.map { it?.toNetworkState() }
.onEach { notifyDefaultNetworkChange(it) }
.stateIn(scope, SharingStarted.Eagerly, null)

_isConnected =
hasInternetCapability()
.onEach { notifyConnectivityChange(it) }
.stateIn(scope, SharingStarted.Eagerly, false)
.stateIn(
scope,
SharingStarted.Eagerly,
true, // Assume we have internet until we know otherwise
)
}

/**
* Invalidates the network state cache. E.g when the VPN is connected or disconnected, and we
* know the last known values not to be correct anymore.
*/
fun invalidateNetworkStateCache() {
// TODO remove runBlocking
runBlocking { resetNetworkState.send(Unit) }
}

private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> =
dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }

private fun hasInternetCapability(): Flow<Boolean> {
val request =
NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
.build()
private val nonVPNNetworksRequest =
NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build()

private fun hasInternetCapability(): Flow<Boolean> {
@Suppress("DEPRECATION")
return connectivityManager
.networkEvents(request)
.scan(setOf<Network>()) { networks, event ->
.networkEvents(nonVPNNetworksRequest)
.scan(
connectivityManager.allNetworks.associateWith {
connectivityManager.getNetworkCapabilities(it)
}
) { networks, event ->
when (event) {
is NetworkEvent.Available -> {
Logger.d("Network available ${event.network}")
(networks + event.network).also {
Logger.d("Number of networks: ${it.size}")
}
}
is NetworkEvent.Lost -> {
Logger.d("Network lost ${event.network}")
(networks - event.network).also {
Logger.d("Number of networks: ${it.size}")
}
}
is NetworkEvent.CapabilitiesChanged -> {
Logger.d("Network capabilities changed ${event.network}")
(networks + (event.network to event.networkCapabilities)).also {
Logger.d("Number of networks: ${it.size}")
}
}
else -> networks
}
}
.map { it.isNotEmpty() }
.distinctUntilChanged()
.map { it.any { it.value.hasInternetCapability() } }
}

private fun NetworkCapabilities?.hasInternetCapability(): Boolean =
this?.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) == true

private fun RawNetworkState.toNetworkState(): NetworkState =
NetworkState(
network.networkHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,22 @@ open class TalpidVpnService : LifecycleVpnService() {
// Used by JNI
fun openTun(config: TunConfig): CreateTunResult =
synchronized(this) {
val tunStatus = activeTunStatus

if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) {
tunStatus
} else {
openTunImpl(config)
createTun(config).merge().also {
currentTunConfig = config
activeTunStatus = it
}
}

// Used by JNI
fun openTunForced(config: TunConfig): CreateTunResult =
synchronized(this) { openTunImpl(config) }

// Used by JNI
fun closeTun(): Unit = synchronized(this) { activeTunStatus = null }
fun closeTun(): Unit =
synchronized(this) {
connectivityListener.invalidateNetworkStateCache()
activeTunStatus = null
}

// Used by JNI
fun bypass(socket: Int): Boolean = protect(socket)

private fun openTunImpl(config: TunConfig): CreateTunResult {
val newTunStatus = createTun(config).merge()

currentTunConfig = config
activeTunStatus = newTunStatus

return newTunStatus
}

private fun createTun(
config: TunConfig
): Either<CreateTunResult.Error, CreateTunResult.Success> = either {
Expand Down Expand Up @@ -123,6 +111,7 @@ open class TalpidVpnService : LifecycleVpnService() {
builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER)
}

connectivityListener.invalidateNetworkStateCache()
val vpnInterfaceFd =
builder
.establishSafe()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,19 @@ fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow<Netw
}

internal fun ConnectivityManager.defaultRawNetworkStateFlow(): Flow<RawNetworkState?> =
defaultNetworkEvents()
.scan(
null as RawNetworkState?,
{ state, event ->
return@scan when (event) {
is NetworkEvent.Available -> RawNetworkState(network = event.network)
is NetworkEvent.BlockedStatusChanged ->
state?.copy(blockedStatus = event.blocked)
is NetworkEvent.CapabilitiesChanged ->
state?.copy(networkCapabilities = event.networkCapabilities)
is NetworkEvent.LinkPropertiesChanged ->
state?.copy(linkProperties = event.linkProperties)
is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive)
is NetworkEvent.Lost -> null
NetworkEvent.Unavailable -> null
}
},
)
defaultNetworkEvents().scan(null as RawNetworkState?) { state, event -> state.reduce(event) }

private fun RawNetworkState?.reduce(event: NetworkEvent): RawNetworkState? =
when (event) {
is NetworkEvent.Available -> RawNetworkState(network = event.network)
is NetworkEvent.BlockedStatusChanged -> this?.copy(blockedStatus = event.blocked)
is NetworkEvent.CapabilitiesChanged ->
this?.copy(networkCapabilities = event.networkCapabilities)
is NetworkEvent.LinkPropertiesChanged -> this?.copy(linkProperties = event.linkProperties)
is NetworkEvent.Losing -> this?.copy(maxMsToLive = event.maxMsToLive)
is NetworkEvent.Lost -> null
NetworkEvent.Unavailable -> null
}

sealed interface NetworkEvent {
data class Available(val network: Network) : NetworkEvent
Expand Down
41 changes: 4 additions & 37 deletions talpid-core/src/tunnel_state_machine/connected_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use futures::channel::{mpsc, oneshot};
use futures::stream::Fuse;
use futures::StreamExt;

#[cfg(target_os = "android")]
use talpid_tunnel::tun_provider::Error;
use talpid_types::net::{AllowedClients, AllowedEndpoint, TunnelParameters};
use talpid_types::tunnel::{ErrorStateCause, FirewallPolicyError};
use talpid_types::{BoxedError, ErrorExt};
Expand Down Expand Up @@ -260,14 +258,7 @@ impl ConnectedState {
let consequence = if shared_values.set_allow_lan(allow_lan) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
)
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
{
Expand Down Expand Up @@ -298,22 +289,7 @@ impl ConnectedState {
let consequence = if shared_values.set_dns_config(servers) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
match _err {
Error::InvalidDnsServers(ip_addrs) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::InvalidDnsServers(
ip_addrs,
)),
),
_ => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
),
}
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
{
Expand Down Expand Up @@ -385,17 +361,8 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
if shared_values.set_excluded_paths(paths) {
if let Err(err) = shared_values.restart_tunnel(false) {
let _ =
result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err)));
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SplitTunnelError),
)
} else {
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
let _ = result_tx.send(Ok(()));
SameState(self)
Expand Down
50 changes: 8 additions & 42 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ use crate::tunnel::{self, TunnelMonitor};

pub(crate) type TunnelCloseEvent = Fuse<oneshot::Receiver<Option<ErrorStateCause>>>;

#[cfg(target_os = "android")]
const MAX_ATTEMPTS_WITH_SAME_TUN: u32 = 5;
const MIN_TUNNEL_ALIVE_TIME: Duration = Duration::from_millis(1000);
#[cfg(target_os = "windows")]
const MAX_ATTEMPT_CREATE_TUN: u32 = 4;
Expand Down Expand Up @@ -114,20 +112,11 @@ impl ConnectingState {
ErrorStateCause::SetFirewallPolicyError(error),
)
} else {
// This is magically shimmed in on the side on Android to prep the TunConfig
// with the right DNS servers. On Android DNS is part of creating the VPN
// interface and this call should be part of start_tunnel call instead
#[cfg(target_os = "android")]
{
shared_values.prepare_tun_config(false);
if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 {
if let Err(error) =
{ shared_values.tun_provider.lock().unwrap().open_tun_forced() }
{
log::error!(
"{}",
error.display_chain_with_msg("Failed to recreate tun device")
);
}
}
}
shared_values.prepare_tun_config(false);

let connecting_state = Self::start_tunnel(
shared_values.runtime.clone(),
Expand Down Expand Up @@ -386,14 +375,7 @@ impl ConnectingState {
let consequence = if shared_values.set_allow_lan(allow_lan) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
)
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
self.reset_firewall(shared_values)
Expand Down Expand Up @@ -427,14 +409,7 @@ impl ConnectingState {
let consequence = if shared_values.set_dns_config(servers) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
)
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
SameState(self)
Expand Down Expand Up @@ -484,17 +459,8 @@ impl ConnectingState {
#[cfg(target_os = "android")]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
if shared_values.set_excluded_paths(paths) {
if let Err(err) = shared_values.restart_tunnel(false) {
let _ =
result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err)));
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SplitTunnelError),
)
} else {
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
let _ = result_tx.send(Ok(()));
SameState(self)
Expand Down
Loading
Loading