From 43cbbb50f9bab8dece273e764d56a2c467032ac0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20L=C3=B6nnhager?= <david.l@mullvad.net>
Date: Wed, 29 Jan 2025 15:11:39 +0100
Subject: [PATCH] Track IPv6 connectivity on Android
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Jonatan Rhoidn <jonatan.rhodin@mullvad.net>
Co-authored-by: David Göransson <david.goransson@mullvad.net>
---
 android/CHANGELOG.md                          |   3 +
 .../util/ConnectivityManagerUtilKtTest.kt     | 215 +++++++++++++-----
 .../mullvad/talpid/ConnectivityListener.kt    |  33 ++-
 .../net/mullvad/talpid/TalpidVpnService.kt    |   7 +-
 .../net/mullvad/talpid/model/Connectivity.kt  |   8 +
 .../talpid/util/ConnectivityManagerUtil.kt    | 163 +++++--------
 .../UnderlyingConnectivityStatusResolver.kt   |  69 ++++++
 mullvad-jni/src/classes.rs                    |   2 +
 talpid-core/src/connectivity_listener.rs      |  62 ++---
 talpid-types/src/net/mod.rs                   |  27 +--
 10 files changed, 359 insertions(+), 230 deletions(-)
 create mode 100644 android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/Connectivity.kt
 create mode 100644 android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/UnderlyingConnectivityStatusResolver.kt

diff --git a/android/CHANGELOG.md b/android/CHANGELOG.md
index 8e84283b32e3..bbc43c3bf23d 100644
--- a/android/CHANGELOG.md
+++ b/android/CHANGELOG.md
@@ -34,6 +34,9 @@ Line wrap the file at 100 chars.                                              Th
 ### Removed
 - Remove Google's resolvers from encrypted DNS proxy.
 
+### Fixed
+- Will no longer try to connect over IPv6 if IPv6 is not available.
+
 
 ## [android/2024.10-beta2] - 2024-12-20
 
diff --git a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/talpid/util/ConnectivityManagerUtilKtTest.kt b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/talpid/util/ConnectivityManagerUtilKtTest.kt
index e8ccd7fdf69c..354b6f585dfa 100644
--- a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/talpid/util/ConnectivityManagerUtilKtTest.kt
+++ b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/talpid/util/ConnectivityManagerUtilKtTest.kt
@@ -1,11 +1,17 @@
 package net.mullvad.mullvadvpn.talpid.util
 
 import android.net.ConnectivityManager
+import android.net.LinkAddress
+import android.net.LinkProperties
 import android.net.Network
+import android.net.NetworkCapabilities
 import app.cash.turbine.test
 import io.mockk.every
 import io.mockk.mockk
 import io.mockk.mockkStatic
+import io.mockk.verify
+import java.net.Inet4Address
+import java.net.Inet6Address
 import kotlin.test.assertEquals
 import kotlin.time.Duration.Companion.milliseconds
 import kotlin.time.Duration.Companion.seconds
@@ -13,10 +19,11 @@ import kotlinx.coroutines.channels.awaitClose
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.callbackFlow
 import kotlinx.coroutines.test.runTest
+import net.mullvad.talpid.model.Connectivity
 import net.mullvad.talpid.util.NetworkEvent
+import net.mullvad.talpid.util.UnderlyingConnectivityStatusResolver
+import net.mullvad.talpid.util.defaultNetworkEvents
 import net.mullvad.talpid.util.hasInternetConnectivity
-import net.mullvad.talpid.util.networkEvents
-import net.mullvad.talpid.util.networksWithInternetConnectivity
 import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test
 
@@ -31,19 +38,22 @@ class ConnectivityManagerUtilKtTest {
     /** User being online, the listener should emit once with `true` */
     @Test
     fun userIsOnline() = runTest {
-        val network = mockk<Network>()
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf(network)
-        every { connectivityManager.networkEvents(any()) } returns
+        val network = mockk<Network>(relaxed = true)
+        val linkProperties = mockLinkProperties(true, true)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 delay(100.milliseconds) // Simulate connectivity listener being a bit slow
                 send(NetworkEvent.Available(network))
+                delay(100.milliseconds) // Simulate connectivity listener being a bit slow
+                send(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
             // Since initial state and listener both return `true` within debounce we only see one
             // event
-            assertEquals(true, awaitItem())
+            assertEquals(Connectivity.Status(true, true), awaitItem())
             expectNoEvents()
         }
     }
@@ -51,12 +61,12 @@ class ConnectivityManagerUtilKtTest {
     /** User being offline, the listener should emit once with `false` */
     @Test
     fun userIsOffline() = runTest {
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf()
-        every { connectivityManager.networkEvents(any()) } returns callbackFlow { awaitClose {} }
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+        every { connectivityManager.defaultNetworkEvents() } returns callbackFlow { awaitClose {} }
 
-        connectivityManager.hasInternetConnectivity().test {
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
             // Initially offline and no network events, so we should get a single `false` event
-            assertEquals(false, awaitItem())
+            assertEquals(Connectivity.Status(false, false), awaitItem())
             expectNoEvents()
         }
     }
@@ -64,19 +74,22 @@ class ConnectivityManagerUtilKtTest {
     /** User starting offline and then turning on a online after a while */
     @Test
     fun initiallyOfflineThenBecomingOnline() = runTest {
-        every { connectivityManager.networksWithInternetConnectivity() } returns emptySet()
-        every { connectivityManager.networkEvents(any()) } returns
+        val network = mockk<Network>()
+        val linkProperties = mockLinkProperties(true, true)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 // Simulate offline for a little while
                 delay(5.seconds)
                 // Then become online
                 send(NetworkEvent.Available(mockk()))
+                send(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
-            assertEquals(false, awaitItem())
-            assertEquals(true, awaitItem())
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
+            assertEquals(Connectivity.Status(false, false), awaitItem())
+            assertEquals(Connectivity.Status(true, true), awaitItem())
             expectNoEvents()
         }
     }
@@ -85,46 +98,23 @@ class ConnectivityManagerUtilKtTest {
     @Test
     fun initiallyOnlineAndThenTurningBecomingOffline() = runTest {
         val network = mockk<Network>()
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf(network)
-        every { connectivityManager.networkEvents(any()) } returns
+        val linkProperties = mockLinkProperties(true, true)
+
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 // Starting as online
                 send(NetworkEvent.Available(network))
+                send(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
                 delay(5.seconds)
                 // Then becoming offline
                 send(NetworkEvent.Lost(network))
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
-            assertEquals(true, awaitItem())
-            assertEquals(false, awaitItem())
-            expectNoEvents()
-        }
-    }
-
-    /**
-     * User turning on Airplane mode as our connectivity listener starts so we never get any
-     * onAvailable event from our listener. Initial value will be `true`, followed by no
-     * `networkEvent` and then turning on network again after 5 seconds
-     */
-    @Test
-    fun incorrectInitialValueThenBecomingOnline() = runTest {
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf(mockk())
-        every { connectivityManager.networkEvents(any()) } returns
-            callbackFlow {
-                delay(5.seconds)
-                send(NetworkEvent.Available(mockk()))
-                awaitClose {}
-            }
-
-        connectivityManager.hasInternetConnectivity().test {
-            // Initial value is connected
-            assertEquals(true, awaitItem())
-            // Debounce time has passed, and we never received any network events, so we are offline
-            assertEquals(false, awaitItem())
-            // Network is back online
-            assertEquals(true, awaitItem())
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
+            assertEquals(Connectivity.Status(true, true), awaitItem())
+            assertEquals(Connectivity.Status(false, false), awaitItem())
             expectNoEvents()
         }
     }
@@ -133,26 +123,34 @@ class ConnectivityManagerUtilKtTest {
     @Test
     fun roamingFromCellularToWifi() = runTest {
         val wifiNetwork = mockk<Network>()
+        val wifiNetworkLinkProperties = mockLinkProperties(true, false)
         val cellularNetwork = mockk<Network>()
+        val cellularNetworkLinkProperties = mockLinkProperties(true, false)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
 
-        every { connectivityManager.networksWithInternetConnectivity() } returns
-            setOf(cellularNetwork)
-        every { connectivityManager.networkEvents(any()) } returns
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 send(NetworkEvent.Available(cellularNetwork))
+                send(
+                    NetworkEvent.LinkPropertiesChanged(
+                        cellularNetwork,
+                        cellularNetworkLinkProperties,
+                    )
+                )
                 delay(5.seconds)
                 // Turning on WiFi, we'll have duplicate networks until phone decides to turn of
                 // cellular
                 send(NetworkEvent.Available(wifiNetwork))
+                send(NetworkEvent.LinkPropertiesChanged(wifiNetwork, wifiNetworkLinkProperties))
                 delay(30.seconds)
                 // Phone turning off cellular network
                 send(NetworkEvent.Lost(cellularNetwork))
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
             // We should always only see us being online
-            assertEquals(true, awaitItem())
+            assertEquals(Connectivity.Status(ipv4 = true, ipv6 = false), awaitItem())
             expectNoEvents()
         }
     }
@@ -161,23 +159,32 @@ class ConnectivityManagerUtilKtTest {
     @Test
     fun roamingFromWifiToCellular() = runTest {
         val wifiNetwork = mockk<Network>()
+        val wifiNetworkLinkProperties = mockLinkProperties(true, false)
         val cellularNetwork = mockk<Network>()
+        val cellularNetworkLinkProperties = mockLinkProperties(true, false)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
 
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf(wifiNetwork)
-        every { connectivityManager.networkEvents(any()) } returns
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 send(NetworkEvent.Available(wifiNetwork))
+                send(NetworkEvent.LinkPropertiesChanged(wifiNetwork, wifiNetworkLinkProperties))
                 delay(5.seconds)
                 send(NetworkEvent.Lost(wifiNetwork))
                 // We will have no network for a little time until cellular chip is on.
                 delay(150.milliseconds)
                 send(NetworkEvent.Available(cellularNetwork))
+                send(
+                    NetworkEvent.LinkPropertiesChanged(
+                        cellularNetwork,
+                        cellularNetworkLinkProperties,
+                    )
+                )
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
             // We should always only see us being online, small offline state is caught by debounce
-            assertEquals(true, awaitItem())
+            assertEquals(Connectivity.Status(ipv4 = true, ipv6 = false), awaitItem())
             expectNoEvents()
         }
     }
@@ -186,31 +193,115 @@ class ConnectivityManagerUtilKtTest {
     @Test
     fun slowRoamingFromWifiToCellular() = runTest {
         val wifiNetwork = mockk<Network>()
+        val wifiNetworkLinkProperties = mockLinkProperties(false, true)
         val cellularNetwork = mockk<Network>()
+        val cellularNetworkLinkProperties = mockLinkProperties(false, true)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
 
-        every { connectivityManager.networksWithInternetConnectivity() } returns setOf(wifiNetwork)
-        every { connectivityManager.networkEvents(any()) } returns
+        every { connectivityManager.defaultNetworkEvents() } returns
             callbackFlow {
                 send(NetworkEvent.Available(wifiNetwork))
+                send(NetworkEvent.LinkPropertiesChanged(wifiNetwork, wifiNetworkLinkProperties))
                 delay(5.seconds)
                 send(NetworkEvent.Lost(wifiNetwork))
                 // We will have no network for a little time until cellular chip is on.
                 delay(500.milliseconds)
                 send(NetworkEvent.Available(cellularNetwork))
+                send(
+                    NetworkEvent.LinkPropertiesChanged(
+                        cellularNetwork,
+                        cellularNetworkLinkProperties,
+                    )
+                )
                 awaitClose {}
             }
 
-        connectivityManager.hasInternetConnectivity().test {
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
             // Wifi is online
-            assertEquals(true, awaitItem())
+            assertEquals(Connectivity.Status(false, true), awaitItem())
             // We didn't get any network within debounce time, so we are offline
-            assertEquals(false, awaitItem())
+            assertEquals(Connectivity.Status(false, false), awaitItem())
             // Cellular network is online
-            assertEquals(true, awaitItem())
+            assertEquals(Connectivity.Status(false, true), awaitItem())
+            expectNoEvents()
+        }
+    }
+
+    /** Switching between networks with different configurations. */
+    @Test
+    fun roamingFromWifiWithIpv6OnlyToWifiWithIpv4Only() = runTest {
+        val ipv6Network = mockk<Network>()
+        val ipv6NetworkLinkProperties = mockLinkProperties(false, true)
+        val ipv4Network = mockk<Network>()
+        val ipv4NetworkLinkProperties = mockLinkProperties(true, false)
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+
+        every { connectivityManager.defaultNetworkEvents() } returns
+            callbackFlow {
+                send(NetworkEvent.Available(ipv6Network))
+                send(NetworkEvent.LinkPropertiesChanged(ipv6Network, ipv6NetworkLinkProperties))
+                delay(5.seconds)
+                send(NetworkEvent.Lost(ipv6Network))
+                delay(100.milliseconds)
+                send(NetworkEvent.Available(ipv4Network))
+                send(NetworkEvent.LinkPropertiesChanged(ipv4Network, ipv4NetworkLinkProperties))
+                awaitClose {}
+            }
+
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
+            // Ipv6 network is online
+            assertEquals(Connectivity.Status(false, true), awaitItem())
+            // Ipv4 network is online
+            assertEquals(Connectivity.Status(true, false), awaitItem())
             expectNoEvents()
         }
     }
 
+    /** Vpn network should NOT check link properties but should rather use socket implementation */
+    @Test
+    fun checkVpnNetworkUsingSocketImplementation() = runTest {
+        val vpnNetwork = mockk<Network>()
+        val capabilities = mockk<NetworkCapabilities>()
+        every { capabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) } returns
+            false
+        val mockResolver = mockk<UnderlyingConnectivityStatusResolver>()
+        every { mockResolver.currentStatus() } returns Connectivity.Status(true, true)
+
+        every { connectivityManager.defaultNetworkEvents() } returns
+            callbackFlow {
+                send(NetworkEvent.Available(vpnNetwork))
+                send(NetworkEvent.CapabilitiesChanged(vpnNetwork, capabilities))
+                awaitClose {}
+            }
+
+        connectivityManager.hasInternetConnectivity(mockResolver).test {
+            // Network is online
+            assertEquals(Connectivity.Status(true, true), awaitItem())
+        }
+
+        verify(exactly = 1) { mockResolver.currentStatus() }
+    }
+
+    private fun mockLinkProperties(ipv4: Boolean, ipv6: Boolean) =
+        mockk<LinkProperties> {
+            val linkAddresses = buildList {
+                if (ipv4) {
+                    val linkIpv4Address: LinkAddress = mockk()
+                    val ipv4Address: Inet4Address = mockk()
+                    every { linkIpv4Address.address } returns ipv4Address
+                    add(linkIpv4Address)
+                }
+                if (ipv6) {
+                    val linkIpv6Address: LinkAddress = mockk()
+                    val ipv6Address: Inet6Address = mockk()
+                    every { linkIpv6Address.address } returns ipv6Address
+                    add(linkIpv6Address)
+                }
+            }
+
+            every { this@mockk.linkAddresses } returns linkAddresses
+        }
+
     companion object {
         private const val CONNECTIVITY_MANAGER_UTIL_CLASS =
             "net.mullvad.talpid.util.ConnectivityManagerUtilKt"
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
index c918d762ff0b..ede883a83706 100644
--- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt
@@ -5,6 +5,8 @@ import android.net.LinkProperties
 import java.net.InetAddress
 import kotlin.collections.ArrayList
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.FlowPreview
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.SharingStarted
@@ -15,13 +17,22 @@ import kotlinx.coroutines.flow.onEach
 import kotlinx.coroutines.flow.receiveAsFlow
 import kotlinx.coroutines.flow.stateIn
 import kotlinx.coroutines.launch
+import kotlinx.coroutines.plus
+import kotlinx.coroutines.runBlocking
+import net.mullvad.talpid.model.Connectivity
 import net.mullvad.talpid.model.NetworkState
 import net.mullvad.talpid.util.RawNetworkState
+import net.mullvad.talpid.util.UnderlyingConnectivityStatusResolver
+import net.mullvad.talpid.util.activeRawNetworkState
 import net.mullvad.talpid.util.defaultRawNetworkStateFlow
 import net.mullvad.talpid.util.hasInternetConnectivity
+import net.mullvad.talpid.util.resolveConnectivityStatus
 
-class ConnectivityListener(private val connectivityManager: ConnectivityManager) {
-    private lateinit var _isConnected: StateFlow<Boolean>
+class ConnectivityListener(
+    private val connectivityManager: ConnectivityManager,
+    private val resolver: UnderlyingConnectivityStatusResolver,
+) {
+    private lateinit var _isConnected: StateFlow<Connectivity>
     // Used by JNI
     val isConnected
         get() = _isConnected.value
@@ -37,6 +48,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
     val currentDnsServers: ArrayList<InetAddress>
         get() = _mutableNetworkState.value?.dnsServers ?: ArrayList()
 
+    @OptIn(FlowPreview::class)
     fun register(scope: CoroutineScope) {
         // Consider implementing retry logic for the flows below, because registering a listener on
         // the default network may fail if the network on Android 11
@@ -53,12 +65,19 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
 
         _isConnected =
             connectivityManager
-                .hasInternetConnectivity()
-                .onEach { notifyConnectivityChange(it) }
+                .hasInternetConnectivity(resolver)
+                .onEach { notifyConnectivityChange(it.ipv4, it.ipv6) }
                 .stateIn(
-                    scope,
+                    scope + Dispatchers.IO,
                     SharingStarted.Eagerly,
-                    true, // Assume we have internet until we know otherwise
+                    // Has to happen on IO to avoid NetworkOnMainThreadException, we actually don't
+                    // send any traffic just open a socket to detect the IP version.
+                    runBlocking(Dispatchers.IO) {
+                        resolveConnectivityStatus(
+                            connectivityManager.activeRawNetworkState(),
+                            resolver,
+                        )
+                    },
                 )
     }
 
@@ -80,7 +99,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
             linkProperties?.dnsServersWithoutFallback(),
         )
 
-    private external fun notifyConnectivityChange(isConnected: Boolean)
+    private external fun notifyConnectivityChange(isIPv4: Boolean, isIPv6: Boolean)
 
     private external fun notifyDefaultNetworkChange(networkState: NetworkState?)
 }
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
index a227c9a77016..e353b8cc552c 100644
--- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt
@@ -26,6 +26,7 @@ import net.mullvad.talpid.model.CreateTunResult.OtherAlwaysOnApp
 import net.mullvad.talpid.model.CreateTunResult.OtherLegacyAlwaysOnVpn
 import net.mullvad.talpid.model.TunConfig
 import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported
+import net.mullvad.talpid.util.UnderlyingConnectivityStatusResolver
 
 open class TalpidVpnService : LifecycleVpnService() {
     private var activeTunStatus by
@@ -48,7 +49,11 @@ open class TalpidVpnService : LifecycleVpnService() {
     @CallSuper
     override fun onCreate() {
         super.onCreate()
-        connectivityListener = ConnectivityListener(getSystemService<ConnectivityManager>()!!)
+        connectivityListener =
+            ConnectivityListener(
+                getSystemService<ConnectivityManager>()!!,
+                UnderlyingConnectivityStatusResolver(::protect),
+            )
         connectivityListener.register(lifecycleScope)
     }
 
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/Connectivity.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/Connectivity.kt
new file mode 100644
index 000000000000..b87eaaacc8d9
--- /dev/null
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/Connectivity.kt
@@ -0,0 +1,8 @@
+package net.mullvad.talpid.model
+
+sealed class Connectivity {
+    data class Status(val ipv4: Boolean, val ipv6: Boolean) : Connectivity()
+
+    // Required by jni
+    data object PresumeOnline : Connectivity()
+}
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt
index 89ddf425f58f..7a0208eaa1b7 100644
--- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt
@@ -5,8 +5,9 @@ import android.net.ConnectivityManager.NetworkCallback
 import android.net.LinkProperties
 import android.net.Network
 import android.net.NetworkCapabilities
-import android.net.NetworkRequest
 import co.touchlab.kermit.Logger
+import java.net.Inet4Address
+import java.net.Inet6Address
 import kotlin.time.Duration.Companion.milliseconds
 import kotlinx.coroutines.FlowPreview
 import kotlinx.coroutines.channels.awaitClose
@@ -16,13 +17,12 @@ import kotlinx.coroutines.flow.callbackFlow
 import kotlinx.coroutines.flow.debounce
 import kotlinx.coroutines.flow.distinctUntilChanged
 import kotlinx.coroutines.flow.map
-import kotlinx.coroutines.flow.mapNotNull
-import kotlinx.coroutines.flow.onStart
 import kotlinx.coroutines.flow.scan
+import net.mullvad.talpid.model.Connectivity
 
 private val CONNECTIVITY_DEBOUNCE = 300.milliseconds
 
-internal fun ConnectivityManager.defaultNetworkEvents(): Flow<NetworkEvent> = callbackFlow {
+fun ConnectivityManager.defaultNetworkEvents(): Flow<NetworkEvent> = callbackFlow {
     val callback =
         object : NetworkCallback() {
             override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) {
@@ -68,56 +68,6 @@ internal fun ConnectivityManager.defaultNetworkEvents(): Flow<NetworkEvent> = ca
     awaitClose { unregisterNetworkCallback(callback) }
 }
 
-fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow<NetworkEvent> =
-    callbackFlow {
-        val callback =
-            object : NetworkCallback() {
-                override fun onLinkPropertiesChanged(
-                    network: Network,
-                    linkProperties: LinkProperties,
-                ) {
-                    super.onLinkPropertiesChanged(network, linkProperties)
-                    trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
-                }
-
-                override fun onAvailable(network: Network) {
-                    super.onAvailable(network)
-                    trySendBlocking(NetworkEvent.Available(network))
-                }
-
-                override fun onCapabilitiesChanged(
-                    network: Network,
-                    networkCapabilities: NetworkCapabilities,
-                ) {
-                    super.onCapabilitiesChanged(network, networkCapabilities)
-                    trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities))
-                }
-
-                override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
-                    super.onBlockedStatusChanged(network, blocked)
-                    trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked))
-                }
-
-                override fun onLosing(network: Network, maxMsToLive: Int) {
-                    super.onLosing(network, maxMsToLive)
-                    trySendBlocking(NetworkEvent.Losing(network, maxMsToLive))
-                }
-
-                override fun onLost(network: Network) {
-                    super.onLost(network)
-                    trySendBlocking(NetworkEvent.Lost(network))
-                }
-
-                override fun onUnavailable() {
-                    super.onUnavailable()
-                    trySendBlocking(NetworkEvent.Unavailable)
-                }
-            }
-        registerNetworkCallback(networkRequest, callback)
-
-        awaitClose { unregisterNetworkCallback(callback) }
-    }
-
 internal fun ConnectivityManager.defaultRawNetworkStateFlow(): Flow<RawNetworkState?> =
     defaultNetworkEvents().scan(null as RawNetworkState?) { state, event -> state.reduce(event) }
 
@@ -153,7 +103,7 @@ sealed interface NetworkEvent {
     data class Lost(val network: Network) : NetworkEvent
 }
 
-internal data class RawNetworkState(
+data class RawNetworkState(
     val network: Network,
     val linkProperties: LinkProperties? = null,
     val networkCapabilities: NetworkCapabilities? = null,
@@ -161,66 +111,57 @@ internal data class RawNetworkState(
     val maxMsToLive: Int? = null,
 )
 
-private val nonVPNInternetNetworksRequest =
-    NetworkRequest.Builder()
-        .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
-        .build()
-
-private sealed interface InternalConnectivityEvent {
-    data class Available(val network: Network) : InternalConnectivityEvent
-
-    data class Lost(val network: Network) : InternalConnectivityEvent
-}
+internal fun ConnectivityManager.activeRawNetworkState(): RawNetworkState? =
+    try {
+        activeNetwork?.let { currentNetwork: Network ->
+            RawNetworkState(
+                network = currentNetwork,
+                linkProperties = getLinkProperties(currentNetwork),
+                networkCapabilities = getNetworkCapabilities(currentNetwork),
+            )
+        }
+    } catch (_: RuntimeException) {
+        Logger.e(
+            "Unable to get active network or properties and capabilities of the active network"
+        )
+        null
+    }
 
 /**
- * Return a flow notifying us if we have internet connectivity. Initial state will be taken from
- * `allNetworks` and then updated when network events occur. Important to note that `allNetworks`
- * may return a network that we never get updates from if turned off at the moment of the initial
- * query.
+ * Return a flow with the current internet connectivity status. The status is based on current
+ * default network and depending on if it is a VPN. If it is not a VPN we check the network
+ * properties directly and if it is a VPN we use a socket to check the underlying network. A
+ * debounce is applied to avoid emitting too many events and to avoid setting the app in an offline
+ * state when switching networks.
  */
 @OptIn(FlowPreview::class)
-fun ConnectivityManager.hasInternetConnectivity(): Flow<Boolean> =
-    networkEvents(nonVPNInternetNetworksRequest)
-        .mapNotNull {
-            when (it) {
-                is NetworkEvent.Available -> InternalConnectivityEvent.Available(it.network)
-                is NetworkEvent.Lost -> InternalConnectivityEvent.Lost(it.network)
-                else -> null
-            }
-        }
-        .scan(emptySet<Network>()) { networks, event ->
-            when (event) {
-                is InternalConnectivityEvent.Lost -> networks - event.network
-                is InternalConnectivityEvent.Available -> networks + event.network
-            }.also { Logger.d("Networks: $it") }
-        }
-        // NetworkEvents are slow, can several 100 millis to arrive. If we are online, we don't
-        // want to emit a false offline with the initial accumulator, so we wait a bit before
-        // emitting, and rely on `networksWithInternetConnectivity`.
-        //
-        // Also if our initial state was "online", but it just got turned off we might not see
-        // any updates for this network even though we already were registered for updated, and
-        // thus we can't drop initial value accumulator value.
+fun ConnectivityManager.hasInternetConnectivity(
+    resolver: UnderlyingConnectivityStatusResolver
+): Flow<Connectivity.Status> =
+    this.defaultRawNetworkStateFlow()
         .debounce(CONNECTIVITY_DEBOUNCE)
-        .onStart {
-            // We should not use this as initial state in scan, because it may contain networks
-            // that won't be included in `networkEvents` updates.
-            emit(networksWithInternetConnectivity().also { Logger.d("Networks (Initial): $it") })
-        }
-        .map { it.isNotEmpty() }
+        .map { resolveConnectivityStatus(it, resolver) }
         .distinctUntilChanged()
 
-@Suppress("DEPRECATION")
-fun ConnectivityManager.networksWithInternetConnectivity(): Set<Network> =
-    // Currently the use of `allNetworks` (which is deprecated in favor of listening to network
-    // events) is our only option because network events does not give us the initial state fast
-    // enough.
-    allNetworks
-        .filter {
-            val capabilities = getNetworkCapabilities(it) ?: return@filter false
-
-            capabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) &&
-                capabilities.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-        }
-        .toSet()
+internal fun resolveConnectivityStatus(
+    currentRawNetworkState: RawNetworkState?,
+    resolver: UnderlyingConnectivityStatusResolver,
+): Connectivity.Status =
+    if (currentRawNetworkState.isVpn()) {
+        // If the default network is a VPN we need to use a socket to check
+        // the underlying network
+        resolver.currentStatus()
+    } else {
+        // If the default network is not a VPN we can check the addresses
+        // directly
+        currentRawNetworkState.toConnectivityStatus()
+    }
+
+private fun RawNetworkState?.toConnectivityStatus() =
+    Connectivity.Status(
+        ipv4 = this?.linkProperties?.linkAddresses?.any { it.address is Inet4Address } == true,
+        ipv6 = this?.linkProperties?.linkAddresses?.any { it.address is Inet6Address } == true,
+    )
+
+private fun RawNetworkState?.isVpn(): Boolean =
+    this?.networkCapabilities?.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) == false
diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/UnderlyingConnectivityStatusResolver.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/UnderlyingConnectivityStatusResolver.kt
new file mode 100644
index 000000000000..620288fb7295
--- /dev/null
+++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/UnderlyingConnectivityStatusResolver.kt
@@ -0,0 +1,69 @@
+package net.mullvad.talpid.util
+
+import arrow.core.Either
+import arrow.core.raise.result
+import co.touchlab.kermit.Logger
+import java.net.DatagramSocket
+import java.net.Inet4Address
+import java.net.Inet6Address
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import net.mullvad.talpid.model.Connectivity
+
+/** This class is used to check the ip version of the underlying network when a VPN is active. */
+class UnderlyingConnectivityStatusResolver(
+    private val protect: (socket: DatagramSocket) -> Boolean
+) {
+    fun currentStatus(): Connectivity.Status =
+        Connectivity.Status(ipv4 = hasIPv4(), ipv6 = hasIPv6())
+
+    private fun hasIPv4(): Boolean =
+        hasIpVersion(Inet4Address.getByName(PUBLIC_IPV4_ADDRESS), protect)
+
+    private fun hasIPv6(): Boolean =
+        hasIpVersion(Inet6Address.getByName(PUBLIC_IPV6_ADDRESS), protect)
+
+    // Fake a connection to a public ip address using a UDP socket.
+    // We don't care about the result of the connection, only that it is possible to create.
+    // This is done this way since otherwise there is not way to check the availability of an ip
+    // version on the underlying network if the VPN is turned on.
+    // Since we are protecting the socket it will use the underlying network regardless
+    // if the VPN is turned on or not.
+    // If the ip version is not supported on the underlying network it will trigger a socket
+    // exception. Otherwise we assume it is available.
+    private fun hasIpVersion(
+        ip: InetAddress,
+        protect: (socket: DatagramSocket) -> Boolean,
+    ): Boolean =
+        result {
+                // Open socket
+                val socket = openSocket().bind()
+
+                val protected = protect(socket)
+
+                // Protect so we can get underlying network
+                if (!protected) {
+                    // We shouldn't be doing this if we don't have a VPN, then we should of checked
+                    // the network directly.
+                    Logger.w("Failed to protect socket")
+                }
+
+                // "Connect" to public ip to see IP version is available
+                val address = InetSocketAddress(ip, 1)
+                socket.connectSafe(address).bind()
+            }
+            .isSuccess
+
+    private fun openSocket(): Either<Throwable, DatagramSocket> =
+        Either.catch { DatagramSocket() }.onLeft { Logger.e("Could not open socket or bind port") }
+
+    private fun DatagramSocket.connectSafe(address: InetSocketAddress): Either<Throwable, Unit> =
+        Either.catch { connect(address.address, address.port) }
+            .onLeft { Logger.e("Socket could not be set up") }
+            .also { close() }
+
+    companion object {
+        private const val PUBLIC_IPV4_ADDRESS = "1.1.1.1"
+        private const val PUBLIC_IPV6_ADDRESS = "2606:4700:4700::1001"
+    }
+}
diff --git a/mullvad-jni/src/classes.rs b/mullvad-jni/src/classes.rs
index f773d3adca24..6245bd901f7b 100644
--- a/mullvad-jni/src/classes.rs
+++ b/mullvad-jni/src/classes.rs
@@ -18,4 +18,6 @@ pub const CLASSES: &[&str] = &[
     "net/mullvad/talpid/ConnectivityListener",
     "net/mullvad/talpid/TalpidVpnService",
     "net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride",
+    "net/mullvad/talpid/model/Connectivity$Status",
+    "net/mullvad/talpid/model/Connectivity$PresumeOnline",
 ];
diff --git a/talpid-core/src/connectivity_listener.rs b/talpid-core/src/connectivity_listener.rs
index 767673123eac..1e6e504f7dac 100644
--- a/talpid-core/src/connectivity_listener.rs
+++ b/talpid-core/src/connectivity_listener.rs
@@ -98,36 +98,41 @@ impl ConnectivityListener {
 
     /// Return the current offline/connectivity state
     pub fn connectivity(&self) -> Connectivity {
-        self.get_is_connected()
-            .map(|connected| Connectivity::Status { connected })
-            .unwrap_or_else(|error| {
-                log::error!(
-                    "{}",
-                    error.display_chain_with_msg("Failed to check connectivity status")
-                );
-                Connectivity::PresumeOnline
-            })
+        self.get_is_connected().unwrap_or_else(|error| {
+            log::error!(
+                "{}",
+                error.display_chain_with_msg("Failed to check connectivity status")
+            );
+            Connectivity::PresumeOnline
+        })
     }
 
-    fn get_is_connected(&self) -> Result<bool, Error> {
+    fn get_is_connected(&self) -> Result<Connectivity, Error> {
         let env = JnixEnv::from(
             self.jvm
                 .attach_current_thread_as_daemon()
                 .map_err(Error::AttachJvmToThread)?,
         );
 
-        let is_connected =
-            env.call_method(self.android_listener.as_obj(), "isConnected", "()Z", &[]);
+        let is_connected = env.call_method(
+            self.android_listener.as_obj(),
+            "isConnected",
+            "()Lnet/mullvad/talpid/model/Connectivity;",
+            &[],
+        );
 
-        match is_connected {
-            Ok(JValue::Bool(JNI_TRUE)) => Ok(true),
-            Ok(JValue::Bool(_)) => Ok(false),
-            value => Err(Error::InvalidMethodResult(
-                "ConnectivityListener",
-                "isConnected",
-                format!("{:?}", value),
-            )),
-        }
+        let is_connected = match is_connected {
+            Ok(JValue::Object(object)) => object,
+            value => {
+                return Err(Error::InvalidMethodResult(
+                    "ConnectivityListener",
+                    "isConnected",
+                    format!("{:?}", value),
+                ))
+            }
+        };
+
+        Ok(Connectivity::from_java(&env, is_connected))
     }
 
     /// Return the current DNS servers according to Android
@@ -160,9 +165,10 @@ impl ConnectivityListener {
 #[unsafe(no_mangle)]
 #[allow(non_snake_case)]
 pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnectivityChange(
-    _: JNIEnv<'_>,
-    _: JObject<'_>,
-    connected: jboolean,
+    _env: JNIEnv<'_>,
+    _obj: JObject<'_>,
+    is_ipv4: jboolean,
+    is_ipv6: jboolean,
 ) {
     let Some(tx) = &*CONNECTIVITY_TX.lock().unwrap() else {
         // No sender has been registered
@@ -170,10 +176,14 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnec
         return;
     };
 
-    let connected = JNI_TRUE == connected;
+    let is_ipv4 = JNI_TRUE == is_ipv4;
+    let is_ipv6 = JNI_TRUE == is_ipv6;
 
     if tx
-        .unbounded_send(Connectivity::Status { connected })
+        .unbounded_send(Connectivity::Status {
+            ipv4: is_ipv4,
+            ipv6: is_ipv6,
+        })
         .is_err()
     {
         log::warn!("Failed to send offline change event");
diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs
index 8b57ba5410d5..19b04918e55d 100644
--- a/talpid-types/src/net/mod.rs
+++ b/talpid-types/src/net/mod.rs
@@ -1,4 +1,6 @@
 use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
+#[cfg(target_os = "android")]
+use jnix::FromJava;
 use obfuscation::ObfuscatorConfig;
 use serde::{Deserialize, Serialize};
 #[cfg(windows)]
@@ -566,19 +568,15 @@ pub fn all_of_the_internet() -> Vec<ipnetwork::IpNetwork> {
 /// Information about the host's connectivity, such as the preesence of
 /// configured IPv4 and/or IPv6.
 #[derive(Debug, Clone, Copy, PartialEq)]
+#[cfg_attr(target_os = "android", derive(FromJava))]
+#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.talpid.model"))]
 pub enum Connectivity {
-    #[cfg(not(target_os = "android"))]
     Status {
         /// Whether IPv4 connectivity seems to be available on the host.
         ipv4: bool,
         /// Whether IPv6 connectivity seems to be available on the host.
         ipv6: bool,
     },
-    #[cfg(target_os = "android")]
-    Status {
-        /// Whether _any_ connectivity seems to be available on the host.
-        connected: bool,
-    },
     /// On/offline status could not be verified, but we have no particular
     /// reason to believe that the host is offline.
     PresumeOnline,
@@ -592,7 +590,6 @@ impl Connectivity {
 
     /// If no IP4 nor IPv6 routes exist, we have no way of reaching the internet
     /// so we consider ourselves offline.
-    #[cfg(not(target_os = "android"))]
     pub fn is_offline(&self) -> bool {
         matches!(
             self,
@@ -606,23 +603,7 @@ impl Connectivity {
     /// Whether IPv6 connectivity seems to be available on the host.
     ///
     /// If IPv6 status is unknown, `false` is returned.
-    #[cfg(not(target_os = "android"))]
     pub fn has_ipv6(&self) -> bool {
         matches!(self, Connectivity::Status { ipv6: true, .. })
     }
-
-    /// Whether IPv6 connectivity seems to be available on the host.
-    ///
-    /// If IPv6 status is unknown, `false` is returned.
-    #[cfg(target_os = "android")]
-    pub fn has_ipv6(&self) -> bool {
-        self.is_online()
-    }
-
-    /// If the host does not have configured IPv6 routes, we have no way of
-    /// reaching the internet so we consider ourselves offline.
-    #[cfg(target_os = "android")]
-    pub fn is_offline(&self) -> bool {
-        matches!(self, Connectivity::Status { connected: false })
-    }
 }