Skip to content

Commit d7dafcc

Browse files
committed
Add test coverage for interactive client.
1 parent 3c62503 commit d7dafcc

File tree

4 files changed

+283
-165
lines changed

4 files changed

+283
-165
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi
3535
Tracker = "https://github.com/python-websockets/websockets/issues"
3636

3737
[project.scripts]
38-
websockets = "websockets.__main__:main"
38+
websockets = "websockets.cli:main"
3939

4040
[tool.cibuildwheel]
4141
enable = ["pypy"]
@@ -69,6 +69,7 @@ exclude_lines = [
6969
"pragma: no cover",
7070
"except ImportError:",
7171
"if self.debug:",
72+
"if sys.platform == \"win32\":",
7273
"if sys.platform != \"win32\":",
7374
"if TYPE_CHECKING:",
7475
"raise AssertionError",

src/websockets/__main__.py

+1-164
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,4 @@
1-
from __future__ import annotations
2-
3-
import argparse
4-
import asyncio
5-
import os
6-
import sys
7-
from typing import Generator
8-
9-
from .asyncio.client import ClientConnection, connect
10-
from .asyncio.messages import SimpleQueue
11-
from .exceptions import ConnectionClosed
12-
from .frames import Close
13-
from .streams import StreamReader
14-
from .version import version as websockets_version
15-
16-
17-
def print_during_input(string: str) -> None:
18-
sys.stdout.write(
19-
# Save cursor position
20-
"\N{ESC}7"
21-
# Add a new line
22-
"\N{LINE FEED}"
23-
# Move cursor up
24-
"\N{ESC}[A"
25-
# Insert blank line, scroll last line down
26-
"\N{ESC}[L"
27-
# Print string in the inserted blank line
28-
f"{string}\N{LINE FEED}"
29-
# Restore cursor position
30-
"\N{ESC}8"
31-
# Move cursor down
32-
"\N{ESC}[B"
33-
)
34-
sys.stdout.flush()
35-
36-
37-
def print_over_input(string: str) -> None:
38-
sys.stdout.write(
39-
# Move cursor to beginning of line
40-
"\N{CARRIAGE RETURN}"
41-
# Delete current line
42-
"\N{ESC}[K"
43-
# Print string
44-
f"{string}\N{LINE FEED}"
45-
)
46-
sys.stdout.flush()
47-
48-
49-
class ReadLines(asyncio.Protocol):
50-
def __init__(self) -> None:
51-
self.reader = StreamReader()
52-
self.messages: SimpleQueue[str] = SimpleQueue()
53-
54-
def parse(self) -> Generator[None, None, None]:
55-
while True:
56-
sys.stdout.write("> ")
57-
sys.stdout.flush()
58-
line = yield from self.reader.read_line(sys.maxsize)
59-
self.messages.put(line.decode().rstrip("\r\n"))
60-
61-
def connection_made(self, transport: asyncio.BaseTransport) -> None:
62-
self.parser = self.parse()
63-
next(self.parser)
64-
65-
def data_received(self, data: bytes) -> None:
66-
self.reader.feed_data(data)
67-
next(self.parser)
68-
69-
def eof_received(self) -> None:
70-
self.reader.feed_eof()
71-
# next(self.parser) isn't useful and would raise EOFError.
72-
73-
def connection_lost(self, exc: Exception | None) -> None:
74-
self.reader.discard()
75-
self.messages.abort()
76-
77-
78-
async def print_incoming_messages(websocket: ClientConnection) -> None:
79-
async for message in websocket:
80-
if isinstance(message, str):
81-
print_during_input("< " + message)
82-
else:
83-
print_during_input("< (binary) " + message.hex())
84-
85-
86-
async def send_outgoing_messages(
87-
websocket: ClientConnection,
88-
messages: SimpleQueue[str],
89-
) -> None:
90-
while True:
91-
try:
92-
message = await messages.get()
93-
except EOFError:
94-
break
95-
try:
96-
await websocket.send(message)
97-
except ConnectionClosed:
98-
break
99-
100-
101-
async def interactive_client(uri: str) -> None:
102-
try:
103-
websocket = await connect(uri)
104-
except Exception as exc:
105-
print(f"Failed to connect to {uri}: {exc}.")
106-
sys.exit(1)
107-
else:
108-
print(f"Connected to {uri}.")
109-
110-
loop = asyncio.get_running_loop()
111-
transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin)
112-
incoming = asyncio.create_task(
113-
print_incoming_messages(websocket),
114-
)
115-
outgoing = asyncio.create_task(
116-
send_outgoing_messages(websocket, protocol.messages),
117-
)
118-
try:
119-
await asyncio.wait(
120-
[incoming, outgoing],
121-
return_when=asyncio.FIRST_COMPLETED,
122-
)
123-
except (KeyboardInterrupt, EOFError): # ^C, ^D
124-
pass
125-
finally:
126-
incoming.cancel()
127-
outgoing.cancel()
128-
transport.close()
129-
130-
await websocket.close()
131-
assert websocket.close_code is not None and websocket.close_reason is not None
132-
close_status = Close(websocket.close_code, websocket.close_reason)
133-
print_over_input(f"Connection closed: {close_status}.")
134-
135-
136-
def main() -> None:
137-
parser = argparse.ArgumentParser(
138-
prog="websockets",
139-
description="Interactive WebSocket client.",
140-
add_help=False,
141-
)
142-
group = parser.add_mutually_exclusive_group()
143-
group.add_argument("--version", action="store_true")
144-
group.add_argument("uri", metavar="<uri>", nargs="?")
145-
args = parser.parse_args()
146-
147-
if args.version:
148-
print(f"websockets {websockets_version}")
149-
return
150-
151-
if args.uri is None:
152-
parser.error("the following arguments are required: <uri>")
153-
154-
# Enable VT100 to support ANSI escape codes in Command Prompt on Windows.
155-
# See https://github.com/python/cpython/issues/74261 for why this works.
156-
if sys.platform == "win32": # pragma: no cover
157-
os.system("")
158-
159-
try:
160-
import readline # noqa: F401
161-
except ImportError: # readline isn't available on all platforms
162-
pass
163-
164-
asyncio.run(interactive_client(args.uri))
1+
from .cli import main
1652

1663

1674
if __name__ == "__main__":

src/websockets/cli.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import asyncio
5+
import os
6+
import sys
7+
from typing import Generator
8+
9+
from .asyncio.client import ClientConnection, connect
10+
from .asyncio.messages import SimpleQueue
11+
from .exceptions import ConnectionClosed
12+
from .frames import Close
13+
from .streams import StreamReader
14+
from .version import version as websockets_version
15+
16+
17+
__all__ = ["main"]
18+
19+
20+
def print_during_input(string: str) -> None:
21+
sys.stdout.write(
22+
# Save cursor position
23+
"\N{ESC}7"
24+
# Add a new line
25+
"\N{LINE FEED}"
26+
# Move cursor up
27+
"\N{ESC}[A"
28+
# Insert blank line, scroll last line down
29+
"\N{ESC}[L"
30+
# Print string in the inserted blank line
31+
f"{string}\N{LINE FEED}"
32+
# Restore cursor position
33+
"\N{ESC}8"
34+
# Move cursor down
35+
"\N{ESC}[B"
36+
)
37+
sys.stdout.flush()
38+
39+
40+
def print_over_input(string: str) -> None:
41+
sys.stdout.write(
42+
# Move cursor to beginning of line
43+
"\N{CARRIAGE RETURN}"
44+
# Delete current line
45+
"\N{ESC}[K"
46+
# Print string
47+
f"{string}\N{LINE FEED}"
48+
)
49+
sys.stdout.flush()
50+
51+
52+
class ReadLines(asyncio.Protocol):
53+
def __init__(self) -> None:
54+
self.reader = StreamReader()
55+
self.messages: SimpleQueue[str] = SimpleQueue()
56+
57+
def parse(self) -> Generator[None, None, None]:
58+
while True:
59+
sys.stdout.write("> ")
60+
sys.stdout.flush()
61+
line = yield from self.reader.read_line(sys.maxsize)
62+
self.messages.put(line.decode().rstrip("\r\n"))
63+
64+
def connection_made(self, transport: asyncio.BaseTransport) -> None:
65+
self.parser = self.parse()
66+
next(self.parser)
67+
68+
def data_received(self, data: bytes) -> None:
69+
self.reader.feed_data(data)
70+
next(self.parser)
71+
72+
def eof_received(self) -> None:
73+
self.reader.feed_eof()
74+
# next(self.parser) isn't useful and would raise EOFError.
75+
76+
def connection_lost(self, exc: Exception | None) -> None:
77+
self.reader.discard()
78+
self.messages.abort()
79+
80+
81+
async def print_incoming_messages(websocket: ClientConnection) -> None:
82+
async for message in websocket:
83+
if isinstance(message, str):
84+
print_during_input("< " + message)
85+
else:
86+
print_during_input("< (binary) " + message.hex())
87+
88+
89+
async def send_outgoing_messages(
90+
websocket: ClientConnection,
91+
messages: SimpleQueue[str],
92+
) -> None:
93+
while True:
94+
try:
95+
message = await messages.get()
96+
except EOFError:
97+
break
98+
try:
99+
await websocket.send(message)
100+
except ConnectionClosed: # pragma: no cover
101+
break
102+
103+
104+
async def interactive_client(uri: str) -> None:
105+
try:
106+
websocket = await connect(uri)
107+
except Exception as exc:
108+
print(f"Failed to connect to {uri}: {exc}.")
109+
sys.exit(1)
110+
else:
111+
print(f"Connected to {uri}.")
112+
113+
loop = asyncio.get_running_loop()
114+
transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin)
115+
incoming = asyncio.create_task(
116+
print_incoming_messages(websocket),
117+
)
118+
outgoing = asyncio.create_task(
119+
send_outgoing_messages(websocket, protocol.messages),
120+
)
121+
try:
122+
await asyncio.wait(
123+
[incoming, outgoing],
124+
return_when=asyncio.FIRST_COMPLETED,
125+
)
126+
except (KeyboardInterrupt, EOFError): # ^C, ^D # pragma: no cover
127+
pass
128+
finally:
129+
incoming.cancel()
130+
outgoing.cancel()
131+
transport.close()
132+
133+
await websocket.close()
134+
assert websocket.close_code is not None and websocket.close_reason is not None
135+
close_status = Close(websocket.close_code, websocket.close_reason)
136+
print_over_input(f"Connection closed: {close_status}.")
137+
138+
139+
def main(argv: list[str] | None = None) -> None:
140+
parser = argparse.ArgumentParser(
141+
prog="websockets",
142+
description="Interactive WebSocket client.",
143+
add_help=False,
144+
)
145+
group = parser.add_mutually_exclusive_group()
146+
group.add_argument("--version", action="store_true")
147+
group.add_argument("uri", metavar="<uri>", nargs="?")
148+
args = parser.parse_args(argv)
149+
150+
if args.version:
151+
print(f"websockets {websockets_version}")
152+
return
153+
154+
if args.uri is None:
155+
parser.print_usage()
156+
sys.exit(2)
157+
158+
# Enable VT100 to support ANSI escape codes in Command Prompt on Windows.
159+
# See https://github.com/python/cpython/issues/74261 for why this works.
160+
if sys.platform == "win32":
161+
os.system("")
162+
163+
try:
164+
import readline # noqa: F401
165+
except ImportError: # readline isn't available on all platforms
166+
pass
167+
168+
asyncio.run(interactive_client(args.uri))

0 commit comments

Comments
 (0)