Skip to content

Commit 17bf76f

Browse files
committed
fix: prevent sending expired tokens
1 parent ccfcbf5 commit 17bf76f

File tree

3 files changed

+124
-8
lines changed

3 files changed

+124
-8
lines changed

packages/realtime_client/lib/src/realtime_client.dart

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class RealtimeCloseEvent {
5454
}
5555

5656
class RealtimeClient {
57+
// This is named `accessTokenValue` in supabase-js
5758
String? accessToken;
5859
List<RealtimeChannel> channels = [];
5960
final String endPoint;
@@ -89,6 +90,8 @@ class RealtimeClient {
8990
};
9091
int longpollerTimeout = 20000;
9192
SocketStates? connState;
93+
// This is called `accessToken` in realtime-js
94+
Future<String> Function()? customAccessToken;
9295

9396
/// Initializes the Socket
9497
///
@@ -403,15 +406,42 @@ class RealtimeClient {
403406
/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
404407
///
405408
/// `token` A JWT strings.
406-
void setAuth(String? token) {
407-
accessToken = token;
409+
Future<void> setAuth(String? token) async {
410+
final tokenToSend =
411+
token ?? (await customAccessToken?.call()) ?? accessToken;
412+
413+
if (tokenToSend != null) {
414+
Map<String, dynamic>? parsed;
415+
try {
416+
final decoded =
417+
utf8.decode(base64Url.decode(tokenToSend.split('.')[1]));
418+
parsed = json.decode(decoded);
419+
} catch (e) {
420+
// ignore parsing errors
421+
}
422+
if (parsed != null && parsed['exp'] != null) {
423+
final now = (DateTime.now().millisecondsSinceEpoch / 1000).floor();
424+
final valid = now - parsed['exp'] < 0;
425+
if (!valid) {
426+
log(
427+
'auth',
428+
'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}',
429+
null,
430+
Level.FINE,
431+
);
432+
throw 'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}';
433+
}
434+
}
435+
}
436+
437+
accessToken = tokenToSend;
408438

409439
for (final channel in channels) {
410-
if (token != null) {
411-
channel.updateJoinPayload({'access_token': token});
440+
if (tokenToSend != null) {
441+
channel.updateJoinPayload({'access_token': tokenToSend});
412442
}
413443
if (channel.joinedOnce && channel.isJoined) {
414-
channel.push(ChannelEvents.accessToken, {'access_token': token});
444+
channel.push(ChannelEvents.accessToken, {'access_token': tokenToSend});
415445
}
416446
}
417447
}

packages/realtime_client/pubspec.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ dev_dependencies:
1919
lints: ^3.0.0
2020
mocktail: ^1.0.0
2121
test: ^1.16.5
22+
crypto: ^3.0.6

packages/realtime_client/test/socket_test.dart

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'dart:convert';
22
import 'dart:io';
33

4+
import 'package:crypto/crypto.dart';
45
import 'package:mocktail/mocktail.dart';
56
import 'package:realtime_client/realtime_client.dart';
67
import 'package:realtime_client/src/constants.dart';
@@ -16,6 +17,31 @@ typedef WebSocketChannelClosure = WebSocketChannel Function(
1617
Map<String, String> headers,
1718
);
1819

20+
/// Generate a JWT token for testing purposes
21+
///
22+
/// [exp] in seconds since Epoch
23+
String generateJwt([int? exp]) {
24+
final header = {'alg': 'HS256', 'typ': 'JWT'};
25+
26+
final now = DateTime.now();
27+
final expiry = exp ??
28+
(now.add(Duration(hours: 1)).millisecondsSinceEpoch / 1000).floor();
29+
30+
final payload = {'exp': expiry};
31+
32+
final key = 'your-256-bit-secret';
33+
34+
final encodedHeader = base64Url.encode(utf8.encode(json.encode(header)));
35+
final encodedPayload = base64Url.encode(utf8.encode(json.encode(payload)));
36+
37+
final signatureInput = '$encodedHeader.$encodedPayload';
38+
final hmac = Hmac(sha256, utf8.encode(key));
39+
final digest = hmac.convert(utf8.encode(signatureInput));
40+
final signature = base64Url.encode(digest.bytes);
41+
42+
return '$encodedHeader.$encodedPayload.$signature';
43+
}
44+
1945
void main() {
2046
const int int64MaxValue = 9223372036854775807;
2147

@@ -427,8 +453,9 @@ void main() {
427453
});
428454

429455
group('setAuth', () {
430-
final updateJoinPayload = {'access_token': 'token123'};
431-
final pushPayload = {'access_token': 'token123'};
456+
final token = generateJwt();
457+
final updateJoinPayload = {'access_token': token};
458+
final pushPayload = {'access_token': token};
432459

433460
test(
434461
"sets access token, updates channels' join payload, and pushes token to channels",
@@ -457,7 +484,9 @@ void main() {
457484
final channel1 = mockedSocket.channel(tTopic1);
458485
final channel2 = mockedSocket.channel(tTopic2);
459486

460-
mockedSocket.setAuth('token123');
487+
mockedSocket.setAuth(token);
488+
489+
expect(mockedSocket.accessToken, token);
461490

462491
verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
463492
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
@@ -466,6 +495,62 @@ void main() {
466495
verify(() => channel2.push(ChannelEvents.accessToken, pushPayload))
467496
.called(1);
468497
});
498+
499+
test(
500+
"sets access token, updates channels' join payload, and pushes token to channels if is not a jwt",
501+
() {
502+
final mockedChannel1 = MockChannel();
503+
final mockedChannel2 = MockChannel();
504+
final mockedChannel3 = MockChannel();
505+
506+
when(() => mockedChannel1.joinedOnce).thenReturn(true);
507+
when(() => mockedChannel1.isJoined).thenReturn(true);
508+
when(() => mockedChannel1.push(ChannelEvents.accessToken, any()))
509+
.thenReturn(MockPush());
510+
511+
when(() => mockedChannel2.joinedOnce).thenReturn(false);
512+
when(() => mockedChannel2.isJoined).thenReturn(false);
513+
when(() => mockedChannel2.push(ChannelEvents.accessToken, any()))
514+
.thenReturn(MockPush());
515+
516+
when(() => mockedChannel3.joinedOnce).thenReturn(true);
517+
when(() => mockedChannel3.isJoined).thenReturn(true);
518+
when(() => mockedChannel3.push(ChannelEvents.accessToken, any()))
519+
.thenReturn(MockPush());
520+
521+
const tTopic1 = 'test-topic1';
522+
const tTopic2 = 'test-topic2';
523+
const tTopic3 = 'test-topic3';
524+
525+
final mockedSocket = SocketWithMockedChannel(socketEndpoint);
526+
mockedSocket.mockedChannelLooker.addAll(<String, RealtimeChannel>{
527+
tTopic1: mockedChannel1,
528+
tTopic2: mockedChannel2,
529+
tTopic3: mockedChannel3,
530+
});
531+
532+
final channel1 = mockedSocket.channel(tTopic1);
533+
final channel2 = mockedSocket.channel(tTopic2);
534+
final channel3 = mockedSocket.channel(tTopic3);
535+
536+
const token = 'sb-key';
537+
final pushPayload = {'access_token': token};
538+
final updateJoinPayload = {'access_token': token};
539+
540+
mockedSocket.setAuth(token);
541+
542+
expect(mockedSocket.accessToken, token);
543+
544+
verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
545+
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
546+
verify(() => channel3.updateJoinPayload(updateJoinPayload)).called(1);
547+
548+
verify(() => channel1.push(ChannelEvents.accessToken, pushPayload))
549+
.called(1);
550+
verifyNever(() => channel2.push(ChannelEvents.accessToken, pushPayload));
551+
verify(() => channel3.push(ChannelEvents.accessToken, pushPayload))
552+
.called(1);
553+
});
469554
});
470555

471556
group('sendHeartbeat', () {

0 commit comments

Comments
 (0)