Skip to content

Commit 748728d

Browse files
committed
requests: Make possible to override headers and allow raw data upload.
1 parent 50ed36f commit 748728d

File tree

2 files changed

+184
-18
lines changed

2 files changed

+184
-18
lines changed

python-ecosys/requests/requests/__init__.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ def request(
3838
url,
3939
data=None,
4040
json=None,
41-
headers={},
41+
headers=None,
4242
stream=None,
4343
auth=None,
4444
timeout=None,
4545
parse_headers=True,
4646
):
47+
if headers is None:
48+
headers = {}
49+
4750
redirect = None # redirection url, None means no redirection
4851
chunked_data = data and getattr(data, "__next__", None) and not getattr(data, "__len__", None)
4952

@@ -94,33 +97,49 @@ def request(
9497
context.verify_mode = tls.CERT_NONE
9598
s = context.wrap_socket(s, server_hostname=host)
9699
s.write(b"%s /%s HTTP/1.0\r\n" % (method, path))
100+
97101
if "Host" not in headers:
98-
s.write(b"Host: %s\r\n" % host)
99-
# Iterate over keys to avoid tuple alloc
100-
for k in headers:
101-
s.write(k)
102-
s.write(b": ")
103-
s.write(headers[k])
104-
s.write(b"\r\n")
102+
headers["Host"] = host
103+
105104
if json is not None:
106105
assert data is None
107106
import ujson
108107

109108
data = ujson.dumps(json)
110-
s.write(b"Content-Type: application/json\r\n")
109+
110+
if "Content-Type" not in headers:
111+
headers["Content-Type"] = "application/json"
112+
111113
if data:
112114
if chunked_data:
113-
s.write(b"Transfer-Encoding: chunked\r\n")
114-
else:
115-
s.write(b"Content-Length: %d\r\n" % len(data))
116-
s.write(b"Connection: close\r\n\r\n")
115+
if "Transfer-Encoding" not in headers and "Content-Length" not in headers:
116+
headers["Transfer-Encoding"] = "chunked"
117+
elif "Content-Length" not in headers:
118+
headers["Content-Length"] = str(len(data))
119+
120+
if "Connection" not in headers:
121+
headers["Connection"] = "close"
122+
123+
# Iterate over keys to avoid tuple alloc
124+
for k in headers:
125+
s.write(k)
126+
s.write(b": ")
127+
s.write(headers[k])
128+
s.write(b"\r\n")
129+
130+
s.write(b"\r\n")
131+
117132
if data:
118133
if chunked_data:
119-
for chunk in data:
120-
s.write(b"%x\r\n" % len(chunk))
121-
s.write(chunk)
122-
s.write(b"\r\n")
123-
s.write("0\r\n\r\n")
134+
if headers.get("Transfer-Encoding", None) == "chunked":
135+
for chunk in data:
136+
s.write(b"%x\r\n" % len(chunk))
137+
s.write(chunk)
138+
s.write(b"\r\n")
139+
s.write("0\r\n\r\n")
140+
else:
141+
for chunk in data:
142+
s.write(chunk)
124143
else:
125144
s.write(data)
126145

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import io
2+
import sys
3+
4+
5+
class Socket:
6+
def __init__(self):
7+
self._write_buffer = io.BytesIO()
8+
self._read_buffer = io.BytesIO(b"HTTP/1.0 200 OK\r\n\r\n")
9+
10+
def connect(self, address):
11+
pass
12+
13+
def write(self, buf):
14+
self._write_buffer.write(buf)
15+
16+
def readline(self):
17+
return self._read_buffer.readline()
18+
19+
20+
class usocket:
21+
AF_INET = 2
22+
SOCK_STREAM = 1
23+
IPPROTO_TCP = 6
24+
25+
@staticmethod
26+
def getaddrinfo(host, port, af=0, type=0, flags=0):
27+
return [(usocket.AF_INET, usocket.SOCK_STREAM, usocket.IPPROTO_TCP, "", ("127.0.0.1", 80))]
28+
29+
def socket(af=AF_INET, type=SOCK_STREAM, proto=IPPROTO_TCP):
30+
return Socket()
31+
32+
33+
sys.modules["usocket"] = usocket
34+
# ruff: noqa: E402
35+
import requests
36+
37+
38+
def format_message(response):
39+
return response.raw._write_buffer.getvalue().decode('utf8')
40+
41+
42+
def test_simple_get():
43+
response = requests.request("GET", "http://example.com")
44+
45+
assert response.raw._write_buffer.getvalue() == (
46+
b"GET / HTTP/1.0\r\n"
47+
+ b"Connection: close\r\n"
48+
+ b"Host: example.com\r\n\r\n"), format_message(response)
49+
50+
51+
def test_get_auth():
52+
response = requests.request("GET", "http://example.com",
53+
auth=("test-username", "test-password"))
54+
55+
assert response.raw._write_buffer.getvalue() == (
56+
b"GET / HTTP/1.0\r\n"
57+
+ b"Host: example.com\r\n"
58+
+ b"Authorization: Basic dGVzdC11c2VybmFtZTp0ZXN0LXBhc3N3b3Jk\r\n"
59+
+ b"Connection: close\r\n\r\n"), format_message(response)
60+
61+
62+
def test_get_custom_header():
63+
response = requests.request("GET", "http://example.com",
64+
headers={"User-Agent": "test-agent"})
65+
66+
assert response.raw._write_buffer.getvalue() == (
67+
b"GET / HTTP/1.0\r\n"
68+
+ b"User-Agent: test-agent\r\n"
69+
+ b"Host: example.com\r\n"
70+
+ b"Connection: close\r\n\r\n"), format_message(response)
71+
72+
73+
def test_post_json():
74+
response = requests.request("GET", "http://example.com", json="test")
75+
76+
assert response.raw._write_buffer.getvalue() == (
77+
b"GET / HTTP/1.0\r\n"
78+
+ b"Connection: close\r\n"
79+
+ b"Content-Type: application/json\r\n"
80+
+ b"Host: example.com\r\n"
81+
+ b"Content-Length: 6\r\n\r\n"
82+
+ b"\"test\""), format_message(response)
83+
84+
85+
def test_post_chunked_data():
86+
def chunks():
87+
yield "test"
88+
89+
response = requests.request("GET", "http://example.com", data=chunks())
90+
91+
assert response.raw._write_buffer.getvalue() == (
92+
b"GET / HTTP/1.0\r\n"
93+
+ b"Transfer-Encoding: chunked\r\n"
94+
+ b"Host: example.com\r\n"
95+
+ b"Connection: close\r\n\r\n"
96+
+ b"4\r\ntest\r\n"
97+
+ b"0\r\n\r\n"), format_message(response)
98+
99+
100+
def test_overwrite_get_headers():
101+
response = requests.request("GET", "http://example.com",
102+
headers={"Connection": "keep-alive", "Host": "test.com"})
103+
104+
assert response.raw._write_buffer.getvalue() == (
105+
b"GET / HTTP/1.0\r\n"
106+
+ b"Host: test.com\r\n"
107+
+ b"Connection: keep-alive\r\n\r\n"), format_message(response)
108+
109+
110+
def test_overwrite_post_json_headers():
111+
response = requests.request("GET", "http://example.com",
112+
json="test",
113+
headers={"Content-Type": "text/plain", "Content-Length": "10"})
114+
115+
assert response.raw._write_buffer.getvalue() == (
116+
b"GET / HTTP/1.0\r\n"
117+
+ b"Connection: close\r\n"
118+
+ b"Content-Length: 10\r\n"
119+
+ b"Content-Type: text/plain\r\n"
120+
+ b"Host: example.com\r\n\r\n"
121+
+ b"\"test\""), format_message(response)
122+
123+
124+
def test_overwrite_post_chunked_data_headers():
125+
def chunks():
126+
yield "test"
127+
128+
response = requests.request("GET", "http://example.com",
129+
data=chunks(),
130+
headers={"Content-Length": "4"})
131+
132+
assert response.raw._write_buffer.getvalue() == (
133+
b"GET / HTTP/1.0\r\n"
134+
+ b"Host: example.com\r\n"
135+
+ b"Content-Length: 4\r\n"
136+
+ b"Connection: close\r\n\r\n"
137+
+ b"test"), format_message(response)
138+
139+
140+
test_simple_get()
141+
test_get_auth()
142+
test_get_custom_header()
143+
test_post_json()
144+
test_post_chunked_data()
145+
test_overwrite_get_headers()
146+
test_overwrite_post_json_headers()
147+
test_overwrite_post_chunked_data_headers()

0 commit comments

Comments
 (0)