Skip to content

Commit 53fe063

Browse files
tkrajacic0xTim
andauthored
Pass a copy of the control frame buffer to ping/pong callbacks (#116)
* Pass a copy of the control frame buffer to callbacks * Add back old API Since the new and old methods are only overloads and share the same name, the 'renamed' parameter of the deprecation warning doesn't help. * Allow specifying payload when sending ping * Remove default value in favor of method forwarding This preserves the signature of the original method and doesn't break the API * Apply suggestions from code review New APIs should use safe code --------- Co-authored-by: Tim Condon <0xTim@users.noreply.github.com>
1 parent 2ec1450 commit 53fe063

File tree

3 files changed

+57
-14
lines changed

3 files changed

+57
-14
lines changed

Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift

+29-3
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ extension WebSocket {
1919
}
2020

2121
public func sendPing() async throws {
22+
try await sendPing(Data())
23+
}
24+
25+
public func sendPing(_ data: Data) async throws {
2226
let promise = eventLoop.makePromise(of: Void.self)
23-
sendPing(promise: promise)
27+
sendPing(data, promise: promise)
2428
return try await promise.futureResult.get()
2529
}
2630

@@ -60,19 +64,41 @@ extension WebSocket {
6064
}
6165
}
6266

67+
public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) {
68+
self.eventLoop.execute {
69+
self.onPong { socket, data in
70+
Task {
71+
await callback(socket, data)
72+
}
73+
}
74+
}
75+
}
76+
77+
@available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.")
6378
@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) async -> ()) {
6479
self.eventLoop.execute {
65-
self.onPong { socket in
80+
self.onPong { socket, _ in
6681
Task {
6782
await callback(socket)
6883
}
6984
}
7085
}
7186
}
7287

88+
public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) {
89+
self.eventLoop.execute {
90+
self.onPing { socket, data in
91+
Task {
92+
await callback(socket, data)
93+
}
94+
}
95+
}
96+
}
97+
98+
@available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.")
7399
@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) async -> ()) {
74100
self.eventLoop.execute {
75-
self.onPing { socket in
101+
self.onPing { socket, _ in
76102
Task {
77103
await callback(socket)
78104
}

Sources/WebSocketKit/WebSocket.swift

+23-9
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ public final class WebSocket: Sendable {
3434
internal let channel: Channel
3535
private let onTextCallback: NIOLoopBoundBox<@Sendable (WebSocket, String) -> ()>
3636
private let onBinaryCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
37-
private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()>
38-
private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()>
37+
private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
38+
private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
3939
private let type: PeerType
4040
private let waitingForPong: NIOLockedValueBox<Bool>
4141
private let waitingForClose: NIOLockedValueBox<Bool>
@@ -48,8 +48,8 @@ public final class WebSocket: Sendable {
4848
self.type = type
4949
self.onTextCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
5050
self.onBinaryCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
51-
self.onPongCallback = .init({ _ in }, eventLoop: channel.eventLoop)
52-
self.onPingCallback = .init({ _ in }, eventLoop: channel.eventLoop)
51+
self.onPongCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
52+
self.onPingCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
5353
self.waitingForPong = .init(false)
5454
self.waitingForClose = .init(false)
5555
self.scheduledTimeoutTask = .init(nil)
@@ -66,13 +66,23 @@ public final class WebSocket: Sendable {
6666
self.onBinaryCallback.value = callback
6767
}
6868

69-
@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) {
69+
public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) {
7070
self.onPongCallback.value = callback
7171
}
72+
73+
@available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.")
74+
@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) {
75+
self.onPongCallback.value = { ws, _ in callback(ws) }
76+
}
7277

73-
@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) {
78+
public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) {
7479
self.onPingCallback.value = callback
7580
}
81+
82+
@available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.")
83+
@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) {
84+
self.onPingCallback.value = { ws, _ in callback(ws) }
85+
}
7686

7787
/// If set, this will trigger automatic pings on the connection. If ping is not answered before
7888
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
@@ -112,8 +122,12 @@ public final class WebSocket: Sendable {
112122
}
113123

114124
public func sendPing(promise: EventLoopPromise<Void>? = nil) {
125+
sendPing(Data(), promise: promise)
126+
}
127+
128+
public func sendPing(_ data: Data, promise: EventLoopPromise<Void>? = nil) {
115129
self.send(
116-
raw: Data(),
130+
raw: data,
117131
opcode: .ping,
118132
fin: true,
119133
promise: promise
@@ -236,7 +250,7 @@ public final class WebSocket: Sendable {
236250
if let maskingKey = maskingKey {
237251
frameData.webSocketUnmask(maskingKey)
238252
}
239-
self.onPingCallback.value(self)
253+
self.onPingCallback.value(self, ByteBuffer(buffer: frameData))
240254
self.send(
241255
raw: frameData.readableBytesView,
242256
opcode: .pong,
@@ -254,7 +268,7 @@ public final class WebSocket: Sendable {
254268
frameData.webSocketUnmask(maskingKey)
255269
}
256270
self.waitingForPong.withLockedValue { $0 = false }
257-
self.onPongCallback.value(self)
271+
self.onPongCallback.value(self, ByteBuffer(buffer: frameData))
258272
} else {
259273
self.close(code: .protocolError, promise: nil)
260274
}

Tests/WebSocketKitTests/WebSocketKitTests.swift

+5-2
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ final class WebSocketKitTests: XCTestCase {
133133
let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8)
134134

135135
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
136-
ws.onPing { ws in
136+
ws.onPing { ws, data in
137+
XCTAssertEqual(pingPongData, data)
137138
pingPromise.succeed("ping")
138139
}
139140
}.bind(host: "localhost", port: 0).wait()
@@ -144,7 +145,9 @@ final class WebSocketKitTests: XCTestCase {
144145
}
145146

146147
WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
147-
ws.onPong { ws in
148+
ws.sendPing(Data(pingPongData.readableBytesView))
149+
ws.onPong { ws, data in
150+
XCTAssertEqual(pingPongData, data)
148151
pongPromise.succeed("pong")
149152
ws.close(promise: nil)
150153
}

0 commit comments

Comments
 (0)