Skip to content

Commit bc4b8f2

Browse files
committed
Add option to force sending text or binary frames.
Fix #1515.
1 parent 21987f9 commit bc4b8f2

File tree

2 files changed

+134
-62
lines changed

2 files changed

+134
-62
lines changed

src/websockets/asyncio/connection.py

+70-54
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data:
251251
252252
You may override this behavior with the ``decode`` argument:
253253
254-
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
255-
and return a bytestring (:class:`bytes`). This may be useful to
256-
optimize performance when decoding isn't needed.
254+
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
255+
return a bytestring (:class:`bytes`). This improves performance
256+
when decoding isn't needed, for example if the message contains
257+
JSON and you're using a JSON library that expects a bytestring.
257258
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
258-
and return a string (:class:`str`). This is useful for servers
259-
that send binary frames instead of text frames.
259+
and return a string (:class:`str`). This may be useful for
260+
servers that send binary frames instead of text frames.
260261
261262
Raises:
262263
ConnectionClosed: When the connection is closed.
@@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
333334
"is already running recv or recv_streaming"
334335
) from None
335336

336-
async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None:
337+
async def send(
338+
self,
339+
message: Data | Iterable[Data] | AsyncIterable[Data],
340+
text: bool | None = None,
341+
) -> None:
337342
"""
338343
Send a message.
339344
@@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
344349
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
345350
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
346351
352+
You may override this behavior with the ``text`` argument:
353+
354+
* Set ``text=True`` to send a bytestring or bytes-like object
355+
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
356+
Text_ frame. This improves performance when the message is already
357+
UTF-8 encoded, for example if the message contains JSON and you're
358+
using a JSON library that produces a bytestring.
359+
* Set ``text=False`` to send a string (:class:`str`) in a Binary_
360+
frame. This may be useful for servers that expect binary frames
361+
instead of text frames.
362+
347363
:meth:`send` also accepts an iterable or an asynchronous iterable of
348364
strings, bytestrings, or bytes-like objects to enable fragmentation_.
349365
Each item is treated as a message fragment and sent in its own frame.
@@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
393409
# strings and bytes-like objects are iterable.
394410

395411
if isinstance(message, str):
396-
async with self.send_context():
397-
self.protocol.send_text(message.encode())
412+
if text is False:
413+
async with self.send_context():
414+
self.protocol.send_binary(message.encode())
415+
else:
416+
async with self.send_context():
417+
self.protocol.send_text(message.encode())
398418

399419
elif isinstance(message, BytesLike):
400-
async with self.send_context():
401-
self.protocol.send_binary(message)
420+
if text is True:
421+
async with self.send_context():
422+
self.protocol.send_text(message)
423+
else:
424+
async with self.send_context():
425+
self.protocol.send_binary(message)
402426

403427
# Catch a common mistake -- passing a dict to send().
404428

@@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
419443
try:
420444
# First fragment.
421445
if isinstance(chunk, str):
422-
text = True
423-
async with self.send_context():
424-
self.protocol.send_text(
425-
chunk.encode(),
426-
fin=False,
427-
)
446+
if text is False:
447+
async with self.send_context():
448+
self.protocol.send_binary(chunk.encode(), fin=False)
449+
else:
450+
async with self.send_context():
451+
self.protocol.send_text(chunk.encode(), fin=False)
452+
encode = True
428453
elif isinstance(chunk, BytesLike):
429-
text = False
430-
async with self.send_context():
431-
self.protocol.send_binary(
432-
chunk,
433-
fin=False,
434-
)
454+
if text is True:
455+
async with self.send_context():
456+
self.protocol.send_text(chunk, fin=False)
457+
else:
458+
async with self.send_context():
459+
self.protocol.send_binary(chunk, fin=False)
460+
encode = False
435461
else:
436462
raise TypeError("iterable must contain bytes or str")
437463

438464
# Other fragments
439465
for chunk in chunks:
440-
if isinstance(chunk, str) and text:
466+
if isinstance(chunk, str) and encode:
441467
async with self.send_context():
442-
self.protocol.send_continuation(
443-
chunk.encode(),
444-
fin=False,
445-
)
446-
elif isinstance(chunk, BytesLike) and not text:
468+
self.protocol.send_continuation(chunk.encode(), fin=False)
469+
elif isinstance(chunk, BytesLike) and not encode:
447470
async with self.send_context():
448-
self.protocol.send_continuation(
449-
chunk,
450-
fin=False,
451-
)
471+
self.protocol.send_continuation(chunk, fin=False)
452472
else:
453473
raise TypeError("iterable must contain uniform types")
454474

@@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
481501
try:
482502
# First fragment.
483503
if isinstance(chunk, str):
484-
text = True
485-
async with self.send_context():
486-
self.protocol.send_text(
487-
chunk.encode(),
488-
fin=False,
489-
)
504+
if text is False:
505+
async with self.send_context():
506+
self.protocol.send_binary(chunk.encode(), fin=False)
507+
else:
508+
async with self.send_context():
509+
self.protocol.send_text(chunk.encode(), fin=False)
510+
encode = True
490511
elif isinstance(chunk, BytesLike):
491-
text = False
492-
async with self.send_context():
493-
self.protocol.send_binary(
494-
chunk,
495-
fin=False,
496-
)
512+
if text is True:
513+
async with self.send_context():
514+
self.protocol.send_text(chunk, fin=False)
515+
else:
516+
async with self.send_context():
517+
self.protocol.send_binary(chunk, fin=False)
518+
encode = False
497519
else:
498520
raise TypeError("async iterable must contain bytes or str")
499521

500522
# Other fragments
501523
async for chunk in achunks:
502-
if isinstance(chunk, str) and text:
524+
if isinstance(chunk, str) and encode:
503525
async with self.send_context():
504-
self.protocol.send_continuation(
505-
chunk.encode(),
506-
fin=False,
507-
)
508-
elif isinstance(chunk, BytesLike) and not text:
526+
self.protocol.send_continuation(chunk.encode(), fin=False)
527+
elif isinstance(chunk, BytesLike) and not encode:
509528
async with self.send_context():
510-
self.protocol.send_continuation(
511-
chunk,
512-
fin=False,
513-
)
529+
self.protocol.send_continuation(chunk, fin=False)
514530
else:
515531
raise TypeError("async iterable must contain uniform types")
516532

tests/asyncio/test_connection.py

+64-8
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,13 @@ async def test_recv_binary(self):
190190
await self.remote_connection.send(b"\x01\x02\xfe\xff")
191191
self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff")
192192

193-
async def test_recv_encoded_text(self):
194-
"""recv receives an UTF-8 encoded text message."""
193+
async def test_recv_text_as_bytes(self):
194+
"""recv receives a text message as bytes."""
195195
await self.remote_connection.send("😀")
196196
self.assertEqual(await self.connection.recv(decode=False), "😀".encode())
197197

198-
async def test_recv_decoded_binary(self):
199-
"""recv receives an UTF-8 decoded binary message."""
198+
async def test_recv_binary_as_text(self):
199+
"""recv receives a binary message as a str."""
200200
await self.remote_connection.send("😀".encode())
201201
self.assertEqual(await self.connection.recv(decode=True), "😀")
202202

@@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self):
304304
[b"\x01\x02\xfe\xff"],
305305
)
306306

307-
async def test_recv_streaming_encoded_text(self):
308-
"""recv_streaming receives an UTF-8 encoded text message."""
307+
async def test_recv_streaming_text_as_bytes(self):
308+
"""recv_streaming receives a text message as bytes."""
309309
await self.remote_connection.send("😀")
310310
self.assertEqual(
311311
await alist(self.connection.recv_streaming(decode=False)),
312312
["😀".encode()],
313313
)
314314

315-
async def test_recv_streaming_decoded_binary(self):
316-
"""recv_streaming receives a UTF-8 decoded binary message."""
315+
async def test_recv_streaming_binary_as_str(self):
316+
"""recv_streaming receives a binary message as a str."""
317317
await self.remote_connection.send("😀".encode())
318318
self.assertEqual(
319319
await alist(self.connection.recv_streaming(decode=True)),
@@ -438,6 +438,16 @@ async def test_send_binary(self):
438438
await self.connection.send(b"\x01\x02\xfe\xff")
439439
self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff")
440440

441+
async def test_send_binary_from_str(self):
442+
"""send sends a binary message from a str."""
443+
await self.connection.send("😀", text=False)
444+
self.assertEqual(await self.remote_connection.recv(), "😀".encode())
445+
446+
async def test_send_text_from_bytes(self):
447+
"""send sends a text message from bytes."""
448+
await self.connection.send("😀".encode(), text=True)
449+
self.assertEqual(await self.remote_connection.recv(), "😀")
450+
441451
async def test_send_fragmented_text(self):
442452
"""send sends a fragmented text message."""
443453
await self.connection.send(["😀", "😀"])
@@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self):
456466
[b"\x01\x02", b"\xfe\xff", b""],
457467
)
458468

469+
async def test_send_fragmented_binary_from_str(self):
470+
"""send sends a fragmented binary message from a str."""
471+
await self.connection.send(["😀", "😀"], text=False)
472+
# websockets sends an trailing empty fragment. That's an implementation detail.
473+
self.assertEqual(
474+
await alist(self.remote_connection.recv_streaming()),
475+
["😀".encode(), "😀".encode(), b""],
476+
)
477+
478+
async def test_send_fragmented_text_from_bytes(self):
479+
"""send sends a fragmented text message from bytes."""
480+
await self.connection.send(["😀".encode(), "😀".encode()], text=True)
481+
# websockets sends an trailing empty fragment. That's an implementation detail.
482+
self.assertEqual(
483+
await alist(self.remote_connection.recv_streaming()),
484+
["😀", "😀", ""],
485+
)
486+
459487
async def test_send_async_fragmented_text(self):
460488
"""send sends a fragmented text message asynchronously."""
461489

@@ -484,6 +512,34 @@ async def fragments():
484512
[b"\x01\x02", b"\xfe\xff", b""],
485513
)
486514

515+
async def test_send_async_fragmented_binary_from_str(self):
516+
"""send sends a fragmented binary message from a str asynchronously."""
517+
518+
async def fragments():
519+
yield "😀"
520+
yield "😀"
521+
522+
await self.connection.send(fragments(), text=False)
523+
# websockets sends an trailing empty fragment. That's an implementation detail.
524+
self.assertEqual(
525+
await alist(self.remote_connection.recv_streaming()),
526+
["😀".encode(), "😀".encode(), b""],
527+
)
528+
529+
async def test_send_async_fragmented_text_from_bytes(self):
530+
"""send sends a fragmented text message from bytes asynchronously."""
531+
532+
async def fragments():
533+
yield "😀".encode()
534+
yield "😀".encode()
535+
536+
await self.connection.send(fragments(), text=True)
537+
# websockets sends an trailing empty fragment. That's an implementation detail.
538+
self.assertEqual(
539+
await alist(self.remote_connection.recv_streaming()),
540+
["😀", "😀", ""],
541+
)
542+
487543
async def test_send_connection_closed_ok(self):
488544
"""send raises ConnectionClosedOK after a normal closure."""
489545
await self.remote_connection.close()

0 commit comments

Comments
 (0)