Skip to content

Commit a4fe16f

Browse files
committed
Add new asyncio-based implementation.
1 parent e21ca9e commit a4fe16f

File tree

13 files changed

+947
-20
lines changed

13 files changed

+947
-20
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ tests:
1717
python -m unittest
1818

1919
coverage:
20-
coverage run --source src/websockets,tests -m unittest
20+
coverage run --source src/websockets,tests -m unittest tests/asyncio/test_messages.py
2121
coverage html
2222
coverage report --show-missing --fail-under=100
2323

docs/reference/index.rst

+12
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ clients concurrently.
2626
asyncio/server
2727
asyncio/client
2828

29+
:mod:`asyncio` (new)
30+
--------------------
31+
32+
This is a rewrite of the :mod:`asyncio` implementation. It will become the
33+
default implementation.
34+
35+
.. toctree::
36+
:titlesonly:
37+
38+
new-asyncio/server
39+
new-asyncio/client
40+
2941
:mod:`threading`
3042
----------------
3143

src/websockets/asyncio/client.py

+324
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import socket
5+
import ssl
6+
from typing import Any, Optional, Sequence, Type, Union
7+
8+
from ..client import ClientProtocol
9+
from ..datastructures import HeadersLike
10+
from ..extensions.base import ClientExtensionFactory
11+
from ..extensions.permessage_deflate import enable_client_permessage_deflate
12+
from ..headers import validate_subprotocols
13+
from ..http import USER_AGENT
14+
from ..http11 import Response
15+
from ..protocol import CONNECTING, OPEN, Event
16+
from ..typing import LoggerLike, Origin, Subprotocol
17+
from ..uri import parse_uri
18+
from .compatibility import asyncio_timeout
19+
from .connection import Connection
20+
21+
22+
__all__ = ["connect", "unix_connect", "ClientConnection"]
23+
24+
25+
class ClientConnection(Connection):
26+
"""
27+
:mod:`asyncio` implementation of a WebSocket client connection.
28+
29+
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
30+
for receiving and sending messages.
31+
32+
It supports asynchronous iteration to receive messages::
33+
34+
async for message in websocket:
35+
await process(message)
36+
37+
The iterator exits normally when the connection is closed with close code
38+
1000 (OK) or 1001 (going away) or without a close code. It raises a
39+
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
40+
closed with any other code.
41+
42+
Args:
43+
protocol: Sans-I/O connection.
44+
close_timeout: Timeout for closing the connection in seconds.
45+
46+
"""
47+
48+
def __init__(
49+
self,
50+
socket: socket.socket,
51+
protocol: ClientProtocol,
52+
*,
53+
close_timeout: Optional[float] = 10,
54+
) -> None:
55+
self.protocol: ClientProtocol
56+
self.response_rcvd = asyncio.Event()
57+
super().__init__(
58+
protocol,
59+
close_timeout=close_timeout,
60+
)
61+
62+
async def handshake(
63+
self,
64+
additional_headers: Optional[HeadersLike] = None,
65+
user_agent_header: Optional[str] = USER_AGENT,
66+
) -> None:
67+
"""
68+
Perform the opening handshake.
69+
70+
"""
71+
async with self.send_context(expected_state=CONNECTING):
72+
self.request = self.protocol.connect()
73+
if additional_headers is not None:
74+
self.request.headers.update(additional_headers)
75+
if user_agent_header is not None:
76+
self.request.headers["User-Agent"] = user_agent_header
77+
self.protocol.send_request(self.request)
78+
79+
try:
80+
await self.response_rcvd.wait()
81+
except asyncio.CancelledError:
82+
self.close_transport()
83+
await self.recv_events_task
84+
raise
85+
86+
if self.response is None:
87+
self.close_transport()
88+
await self.recv_events_task
89+
raise ConnectionError("connection closed during handshake")
90+
91+
if self.protocol.state is not OPEN:
92+
try:
93+
async with asyncio_timeout(self.close_timeout):
94+
await self.recv_events_task
95+
except TimeoutError:
96+
pass
97+
self.close_transport()
98+
await self.recv_events_task
99+
100+
if self.protocol.handshake_exc is not None:
101+
raise self.protocol.handshake_exc
102+
103+
def process_event(self, event: Event) -> None:
104+
"""
105+
Process one incoming event.
106+
107+
"""
108+
# First event - handshake response.
109+
if self.response is None:
110+
assert isinstance(event, Response)
111+
self.response = event
112+
self.response_rcvd.set()
113+
# Later events - frames.
114+
else:
115+
super().process_event(event)
116+
117+
def recv_events(self) -> None:
118+
"""
119+
Read incoming data from the socket and process events.
120+
121+
"""
122+
try:
123+
super().recv_events()
124+
finally:
125+
# If the connection is closed during the handshake, unblock it.
126+
self.response_rcvd.set()
127+
128+
129+
async def connect(
130+
uri: str,
131+
*,
132+
# TCP/TLS — unix and path are only for unix_connect()
133+
sock: Optional[socket.socket] = None,
134+
ssl_context: Optional[ssl.SSLContext] = None,
135+
server_hostname: Optional[str] = None,
136+
unix: bool = False,
137+
path: Optional[str] = None,
138+
# WebSocket
139+
origin: Optional[Origin] = None,
140+
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
141+
subprotocols: Optional[Sequence[Subprotocol]] = None,
142+
additional_headers: Optional[HeadersLike] = None,
143+
user_agent_header: Optional[str] = USER_AGENT,
144+
compression: Optional[str] = "deflate",
145+
# Timeouts
146+
open_timeout: Optional[float] = 10,
147+
close_timeout: Optional[float] = 10,
148+
# Limits
149+
max_size: Optional[int] = 2**20,
150+
# Logging
151+
logger: Optional[LoggerLike] = None,
152+
# Escape hatch for advanced customization
153+
create_connection: Optional[Type[ClientConnection]] = None,
154+
# Other keyword arguments are passed to loop.create_connection
155+
**kwargs: Any,
156+
) -> ClientConnection:
157+
"""
158+
Connect to the WebSocket server at ``uri``.
159+
160+
This function returns a :class:`ClientConnection` instance, which you can
161+
use to send and receive messages.
162+
163+
:func:`connect` may be used as a context manager::
164+
165+
async with websockets.asyncio.client.connect(...) as websocket:
166+
...
167+
168+
The connection is closed automatically when exiting the context.
169+
170+
Args:
171+
uri: URI of the WebSocket server.
172+
sock: Preexisting TCP socket. ``sock`` overrides the host and port
173+
from ``uri``. You may call :func:`socket.create_connection` (not
174+
:func:`asyncio.create_connection`) to create a suitable TCP socket.
175+
ssl_context: Configuration for enabling TLS on the connection.
176+
server_hostname: Host name for the TLS handshake. ``server_hostname``
177+
overrides the host name from ``uri``.
178+
origin: Value of the ``Origin`` header, for servers that require it.
179+
extensions: List of supported extensions, in order in which they
180+
should be negotiated and run.
181+
subprotocols: List of supported subprotocols, in order of decreasing
182+
preference.
183+
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
184+
to the handshake request.
185+
user_agent_header: Value of the ``User-Agent`` request header.
186+
It defaults to ``"Python/x.y.z websockets/X.Y"``.
187+
Setting it to :obj:`None` removes the header.
188+
compression: The "permessage-deflate" extension is enabled by default.
189+
Set ``compression`` to :obj:`None` to disable it. See the
190+
:doc:`compression guide <../../topics/compression>` for details.
191+
open_timeout: Timeout for opening the connection in seconds.
192+
:obj:`None` disables the timeout.
193+
close_timeout: Timeout for closing the connection in seconds.
194+
:obj:`None` disables the timeout.
195+
max_size: Maximum size of incoming messages in bytes.
196+
:obj:`None` disables the limit.
197+
logger: Logger for this client.
198+
It defaults to ``logging.getLogger("websockets.client")``.
199+
See the :doc:`logging guide <../../topics/logging>` for details.
200+
create_connection: Factory for the :class:`ClientConnection` managing
201+
the connection. Set it to a wrapper or a subclass to customize
202+
connection handling.
203+
204+
Any other keyword arguments are passed the event loop's
205+
:meth:`~asyncio.loop.create_connection` method.
206+
207+
For example, you can set ``host`` and ``port`` to connect to a different
208+
host and port from those found in ``uri``. This only changes the destination
209+
of the TCP connection. The host name from ``uri`` is still used in the TLS
210+
handshake for secure connections and in the ``Host`` header.
211+
212+
Raises:
213+
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
214+
OSError: If the TCP connection fails.
215+
InvalidHandshake: If the opening handshake fails.
216+
TimeoutError: If the opening handshake times out.
217+
218+
"""
219+
220+
# Process parameters
221+
222+
wsuri = parse_uri(uri)
223+
if not wsuri.secure and ssl_context is not None:
224+
raise TypeError("ssl_context argument is incompatible with a ws:// URI")
225+
226+
ssl: Union[bool, Optional[ssl.SSLContext]] = ssl_context
227+
if wsuri.secure:
228+
if ssl is None:
229+
ssl = True
230+
if server_hostname is None:
231+
server_hostname = wsuri.host
232+
233+
if unix:
234+
if path is None and sock is None:
235+
raise TypeError("missing path argument")
236+
elif path is not None and sock is not None:
237+
raise TypeError("path and sock arguments are incompatible")
238+
else:
239+
assert path is None # private argument, only set by unix_connect()
240+
241+
if subprotocols is not None:
242+
validate_subprotocols(subprotocols)
243+
244+
if compression == "deflate":
245+
extensions = enable_client_permessage_deflate(extensions)
246+
elif compression is not None:
247+
raise ValueError(f"unsupported compression: {compression}")
248+
249+
protocol = ClientProtocol(
250+
wsuri,
251+
origin=origin,
252+
extensions=extensions,
253+
subprotocols=subprotocols,
254+
max_size=max_size,
255+
logger=logger,
256+
)
257+
258+
if create_connection is None:
259+
create_connection = ClientConnection
260+
261+
try:
262+
async with asyncio_timeout(open_timeout):
263+
if unix:
264+
_, connection = await asyncio.get_event_loop().create_unix_connection(
265+
lambda: create_connection(protocol, close_timeout=close_timeout),
266+
path=path,
267+
ssl=ssl,
268+
sock=sock,
269+
server_hostname=server_hostname,
270+
**kwargs,
271+
)
272+
else:
273+
_, connection = await asyncio.get_event_loop().create_connection(
274+
lambda: create_connection(protocol, close_timeout=close_timeout),
275+
ssl=ssl,
276+
sock=sock,
277+
server_hostname=server_hostname,
278+
**kwargs,
279+
)
280+
281+
# On failure, handshake() closes the transport and raises an exception.
282+
await connection.handshake(
283+
additional_headers,
284+
user_agent_header,
285+
)
286+
287+
except Exception:
288+
try:
289+
connection
290+
except NameError:
291+
pass
292+
else:
293+
connection.close_transport()
294+
raise
295+
296+
return connection
297+
298+
299+
async def unix_connect(
300+
path: Optional[str] = None,
301+
uri: Optional[str] = None,
302+
**kwargs: Any,
303+
) -> ClientConnection:
304+
"""
305+
Connect to a WebSocket server listening on a Unix socket.
306+
307+
This function is identical to :func:`connect`, except for the additional
308+
``path`` argument. It's only available on Unix.
309+
310+
It's mainly useful for debugging servers listening on Unix sockets.
311+
312+
Args:
313+
path: File system path to the Unix socket.
314+
uri: URI of the WebSocket server. ``uri`` defaults to
315+
``ws://localhost/`` or, when a ``ssl_context`` is provided, to
316+
``wss://localhost/``.
317+
318+
"""
319+
if uri is None:
320+
if kwargs.get("ssl_context") is None:
321+
uri = "ws://localhost/"
322+
else:
323+
uri = "wss://localhost/"
324+
return await connect(uri=uri, unix=True, path=path, **kwargs)

0 commit comments

Comments
 (0)