-
Notifications
You must be signed in to change notification settings - Fork 389
/
Copy pathEphemeralPeerExchangeActor.swift
116 lines (99 loc) · 4.07 KB
/
EphemeralPeerExchangeActor.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//
// EphemeralPeerExchangeActor.swift
// PacketTunnel
//
// Created by Marco Nikic on 2024-04-12.
// Copyright © 2024 Mullvad VPN AB. All rights reserved.
//
import Foundation
import MullvadRustRuntimeProxy
import MullvadTypes
import NetworkExtension
import WireGuardKitTypes
/// Interface for starting, cancelling and resetting ephemeral peer negotiations.
public protocol EphemeralPeerExchangeActorProtocol {
/// Starts a new negotiation.
///
/// - Parameters:
///   - privateKey: The private key to negotiate with.
///   - enablePostQuantum: Whether post-quantum is requested for this negotiation.
///   - enableDaita: Whether DAITA is requested for this negotiation.
func startNegotiation(with privateKey: PrivateKey, enablePostQuantum: Bool, enableDaita: Bool)
/// Ends the negotiation that is currently in progress, if any.
func endCurrentNegotiation()
/// Resets the negotiation state.
///
/// In `EphemeralPeerExchangeActor` this restarts the timeout sequence and ends the current negotiation.
func reset()
}
/// Starts, cancels and resets ephemeral peer negotiations on behalf of a tunnel provider.
public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol {
    /// Wraps the negotiator of a started exchange so that it can be cancelled later.
    struct Negotiation {
        var negotiator: EphemeralPeerNegotiating

        /// Cancels the key negotiation on the wrapped negotiator.
        func cancel() {
            negotiator.cancelKeyNegotiation()
        }
    }

    /// Supplies the WireGuard functions and is handed to the negotiator as the peer receiver.
    /// Held `unowned`, so it must outlive this object.
    unowned let packetTunnel: any TunnelProvider

    /// The negotiation that was most recently started successfully, or `nil` when there is none.
    internal var negotiation: Negotiation?

    private var timer: DispatchSourceTimer?

    /// Yields the timeout used for each successive negotiation attempt.
    /// Always assigned in `init`, so it does not need to be optional.
    private var keyExchangeRetriesIterator: AnyIterator<Duration>

    /// Creates a fresh timeout sequence; used at init time and by `reset()`.
    private let iteratorProvider: () -> AnyIterator<Duration>

    /// Type used to instantiate a new negotiator for every negotiation attempt.
    private let negotiationProvider: EphemeralPeerNegotiating.Type

    // Callback in the event of the negotiation failing on startup
    var onFailure: () -> Void

    /// Creates a new exchange actor.
    ///
    /// - Parameters:
    ///   - packetTunnel: The tunnel provider; stored `unowned`.
    ///   - onFailure: Invoked when a negotiation fails to start.
    ///   - negotiationProvider: The negotiator type to instantiate for each attempt.
    ///   - iteratorProvider: Produces the sequence of timeouts used for successive attempts.
    public init(
        packetTunnel: any TunnelProvider,
        onFailure: @escaping (() -> Void),
        negotiationProvider: EphemeralPeerNegotiating.Type = EphemeralPeerNegotiator.self,
        iteratorProvider: @escaping () -> AnyIterator<Duration>
    ) {
        self.packetTunnel = packetTunnel
        self.onFailure = onFailure
        self.negotiationProvider = negotiationProvider
        self.iteratorProvider = iteratorProvider
        self.keyExchangeRetriesIterator = iteratorProvider()
    }

    /// Starts a new key exchange.
    ///
    /// Any ongoing key negotiation is stopped before starting a new one.
    /// An exponential backoff timer is used to stop the exchange if it takes too long,
    /// or if the TCP connection takes too long to become ready.
    /// It is reset after every successful key exchange.
    ///
    /// If the negotiator fails to start, it is cancelled, no negotiation is stored
    /// and `onFailure` is invoked.
    ///
    /// - Parameters:
    ///   - privateKey: The device's current private key
    ///   - enablePostQuantum: Forwarded to the negotiator as `enable_post_quantum`.
    ///   - enableDaita: Forwarded to the negotiator as `enable_daita`.
    public func startNegotiation(with privateKey: PrivateKey, enablePostQuantum: Bool, enableDaita: Bool) {
        endCurrentNegotiation()
        let negotiator = negotiationProvider.init()

        // This will become the new private key of the device
        let ephemeralSharedKey = PrivateKey()

        // If the connection never becomes viable, force a reconnection after 10 seconds.
        // The same 10 second value is the fallback once the iterator has no more timeouts to offer.
        let tcpConnectionTimeout = keyExchangeRetriesIterator.next() ?? .seconds(10)

        let peerParameters = EphemeralPeerParameters(
            peer_exchange_timeout: UInt64(tcpConnectionTimeout.timeInterval),
            enable_post_quantum: enablePostQuantum,
            enable_daita: enableDaita,
            funcs: mapWgFunctions(functions: packetTunnel.wgFunctions())
        )

        let didStart = negotiator.startNegotiation(
            devicePublicKey: privateKey.publicKey,
            presharedKey: ephemeralSharedKey,
            peerReceiver: packetTunnel,
            ephemeralPeerParams: peerParameters
        )

        guard didStart else {
            // Cancel the negotiation to shut down any remaining use of the TCP connection on the Rust side.
            // The failed negotiator has not been stored in `negotiation`, so it is cancelled directly.
            negotiator.cancelKeyNegotiation()
            endCurrentNegotiation()
            onFailure()
            // Do not keep the failed negotiator around, and do not overwrite anything
            // `onFailure` may have started.
            return
        }

        negotiation = Negotiation(negotiator: negotiator)
    }

    /// Copies the tunnel provider's WireGuard function pointers into the structure passed to the negotiator.
    private func mapWgFunctions(functions: WgFunctionPointers) -> WgTcpConnectionFunctions {
        var mappedFunctions = WgTcpConnectionFunctions()
        mappedFunctions.close_fn = functions.close
        mappedFunctions.open_fn = functions.open
        mappedFunctions.send_fn = functions.send
        mappedFunctions.recv_fn = functions.receive
        return mappedFunctions
    }

    /// Cancels the ongoing key exchange.
    public func endCurrentNegotiation() {
        negotiation?.cancel()
        negotiation = nil
    }

    /// Resets the exponential timeout for successful key exchanges, and ends the current key exchange.
    public func reset() {
        keyExchangeRetriesIterator = iteratorProvider()
        endCurrentNegotiation()
    }
}