Skip to content

Commit 127d4b3

Browse files
committed
requests: Make possible to override headers and allow raw data upload.
1 parent 50ed36f commit 127d4b3

File tree

2 files changed

+192
-18
lines changed

2 files changed

+192
-18
lines changed

python-ecosys/requests/requests/__init__.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ def request(
3838
url,
3939
data=None,
4040
json=None,
41-
headers={},
41+
headers=None,
4242
stream=None,
4343
auth=None,
4444
timeout=None,
4545
parse_headers=True,
4646
):
47+
if headers is None:
48+
headers = {}
49+
4750
redirect = None # redirection url, None means no redirection
4851
chunked_data = data and getattr(data, "__next__", None) and not getattr(data, "__len__", None)
4952

@@ -94,33 +97,49 @@ def request(
9497
context.verify_mode = tls.CERT_NONE
9598
s = context.wrap_socket(s, server_hostname=host)
9699
s.write(b"%s /%s HTTP/1.0\r\n" % (method, path))
100+
97101
if "Host" not in headers:
98-
s.write(b"Host: %s\r\n" % host)
99-
# Iterate over keys to avoid tuple alloc
100-
for k in headers:
101-
s.write(k)
102-
s.write(b": ")
103-
s.write(headers[k])
104-
s.write(b"\r\n")
102+
headers["Host"] = host
103+
105104
if json is not None:
106105
assert data is None
107106
import ujson
108107

109108
data = ujson.dumps(json)
110-
s.write(b"Content-Type: application/json\r\n")
109+
110+
if "Content-Type" not in headers:
111+
headers["Content-Type"] = "application/json"
112+
111113
if data:
112114
if chunked_data:
113-
s.write(b"Transfer-Encoding: chunked\r\n")
114-
else:
115-
s.write(b"Content-Length: %d\r\n" % len(data))
116-
s.write(b"Connection: close\r\n\r\n")
115+
if "Transfer-Encoding" not in headers and "Content-Length" not in headers:
116+
headers["Transfer-Encoding"] = "chunked"
117+
elif "Content-Length" not in headers:
118+
headers["Content-Length"] = str(len(data))
119+
120+
if "Connection" not in headers:
121+
headers["Connection"] = "close"
122+
123+
# Iterate over keys to avoid tuple alloc
124+
for k in headers:
125+
s.write(k)
126+
s.write(b": ")
127+
s.write(headers[k])
128+
s.write(b"\r\n")
129+
130+
s.write(b"\r\n")
131+
117132
if data:
118133
if chunked_data:
119-
for chunk in data:
120-
s.write(b"%x\r\n" % len(chunk))
121-
s.write(chunk)
122-
s.write(b"\r\n")
123-
s.write("0\r\n\r\n")
134+
if headers.get("Transfer-Encoding", None) == "chunked":
135+
for chunk in data:
136+
s.write(b"%x\r\n" % len(chunk))
137+
s.write(chunk)
138+
s.write(b"\r\n")
139+
s.write("0\r\n\r\n")
140+
else:
141+
for chunk in data:
142+
s.write(chunk)
124143
else:
125144
s.write(data)
126145

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import io
2+
import sys
3+
4+
5+
class Socket:
6+
def __init__(self):
7+
self._write_buffer = io.BytesIO()
8+
self._read_buffer = io.BytesIO(b"HTTP/1.0 200 OK\r\n\r\n")
9+
10+
def connect(self, address):
11+
pass
12+
13+
def write(self, buf):
14+
self._write_buffer.write(buf)
15+
16+
def readline(self):
17+
return self._read_buffer.readline()
18+
19+
20+
class usocket:
21+
AF_INET = 2
22+
SOCK_STREAM = 1
23+
IPPROTO_TCP = 6
24+
25+
@staticmethod
26+
def getaddrinfo(host, port, af=0, type=0, flags=0):
27+
return [(usocket.AF_INET, usocket.SOCK_STREAM, usocket.IPPROTO_TCP, "", ("127.0.0.1", 80))]
28+
29+
def socket(af=AF_INET, type=SOCK_STREAM, proto=IPPROTO_TCP):
30+
return Socket()
31+
32+
33+
sys.modules["usocket"] = usocket
34+
# ruff: noqa: E402
35+
import requests
36+
37+
38+
def format_message(response):
39+
return response.raw._write_buffer.getvalue().decode("utf8")
40+
41+
42+
def test_simple_get():
43+
response = requests.request("GET", "http://example.com")
44+
45+
assert response.raw._write_buffer.getvalue() == (
46+
b"GET / HTTP/1.0\r\n" + b"Connection: close\r\n" + b"Host: example.com\r\n\r\n"
47+
), format_message(response)
48+
49+
50+
def test_get_auth():
51+
response = requests.request(
52+
"GET", "http://example.com", auth=("test-username", "test-password")
53+
)
54+
55+
assert response.raw._write_buffer.getvalue() == (
56+
b"GET / HTTP/1.0\r\n"
57+
+ b"Host: example.com\r\n"
58+
+ b"Authorization: Basic dGVzdC11c2VybmFtZTp0ZXN0LXBhc3N3b3Jk\r\n"
59+
+ b"Connection: close\r\n\r\n"
60+
), format_message(response)
61+
62+
63+
def test_get_custom_header():
64+
response = requests.request("GET", "http://example.com", headers={"User-Agent": "test-agent"})
65+
66+
assert response.raw._write_buffer.getvalue() == (
67+
b"GET / HTTP/1.0\r\n"
68+
+ b"User-Agent: test-agent\r\n"
69+
+ b"Host: example.com\r\n"
70+
+ b"Connection: close\r\n\r\n"
71+
), format_message(response)
72+
73+
74+
def test_post_json():
75+
response = requests.request("GET", "http://example.com", json="test")
76+
77+
assert response.raw._write_buffer.getvalue() == (
78+
b"GET / HTTP/1.0\r\n"
79+
+ b"Connection: close\r\n"
80+
+ b"Content-Type: application/json\r\n"
81+
+ b"Host: example.com\r\n"
82+
+ b"Content-Length: 6\r\n\r\n"
83+
+ b'"test"'
84+
), format_message(response)
85+
86+
87+
def test_post_chunked_data():
88+
def chunks():
89+
yield "test"
90+
91+
response = requests.request("GET", "http://example.com", data=chunks())
92+
93+
assert response.raw._write_buffer.getvalue() == (
94+
b"GET / HTTP/1.0\r\n"
95+
+ b"Transfer-Encoding: chunked\r\n"
96+
+ b"Host: example.com\r\n"
97+
+ b"Connection: close\r\n\r\n"
98+
+ b"4\r\ntest\r\n"
99+
+ b"0\r\n\r\n"
100+
), format_message(response)
101+
102+
103+
def test_overwrite_get_headers():
104+
response = requests.request(
105+
"GET", "http://example.com", headers={"Connection": "keep-alive", "Host": "test.com"}
106+
)
107+
108+
assert response.raw._write_buffer.getvalue() == (
109+
b"GET / HTTP/1.0\r\n" + b"Host: test.com\r\n" + b"Connection: keep-alive\r\n\r\n"
110+
), format_message(response)
111+
112+
113+
def test_overwrite_post_json_headers():
114+
response = requests.request(
115+
"GET",
116+
"http://example.com",
117+
json="test",
118+
headers={"Content-Type": "text/plain", "Content-Length": "10"},
119+
)
120+
121+
assert response.raw._write_buffer.getvalue() == (
122+
b"GET / HTTP/1.0\r\n"
123+
+ b"Connection: close\r\n"
124+
+ b"Content-Length: 10\r\n"
125+
+ b"Content-Type: text/plain\r\n"
126+
+ b"Host: example.com\r\n\r\n"
127+
+ b'"test"'
128+
), format_message(response)
129+
130+
131+
def test_overwrite_post_chunked_data_headers():
132+
def chunks():
133+
yield "test"
134+
135+
response = requests.request(
136+
"GET", "http://example.com", data=chunks(), headers={"Content-Length": "4"}
137+
)
138+
139+
assert response.raw._write_buffer.getvalue() == (
140+
b"GET / HTTP/1.0\r\n"
141+
+ b"Host: example.com\r\n"
142+
+ b"Content-Length: 4\r\n"
143+
+ b"Connection: close\r\n\r\n"
144+
+ b"test"
145+
), format_message(response)
146+
147+
148+
test_simple_get()
149+
test_get_auth()
150+
test_get_custom_header()
151+
test_post_json()
152+
test_post_chunked_data()
153+
test_overwrite_get_headers()
154+
test_overwrite_post_json_headers()
155+
test_overwrite_post_chunked_data_headers()

0 commit comments

Comments
 (0)