Skip to content

Commit 1f019f9

Browse files
kapetandpgeorge
authored andcommitted
requests: Make possible to override headers and allow raw data upload.
This removes all the hard-coded request headers from the requests module so they can be overridden by user provided headers dict. Furthermore allow streaming request data without chunk encoding in those cases where content length is known but it's not desirable to load the whole content into memory. Also some servers (e.g. nginx) reject HTTP/1.0 requests with the Transfer-Encoding header set. The change should be backwards compatible as long as the user hasn't provided any of the previously hard-coded headers. Signed-off-by: Mirza Kapetanovic <mirza.kapetanovic@gmail.com>
1 parent 50ed36f commit 1f019f9

File tree

3 files changed

+193
-19
lines changed

3 files changed

+193
-19
lines changed

python-ecosys/requests/manifest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
metadata(version="0.9.0", pypi="requests")
1+
metadata(version="0.10.0", pypi="requests")
22

33
package("requests")

python-ecosys/requests/requests/__init__.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ def request(
3838
url,
3939
data=None,
4040
json=None,
41-
headers={},
41+
headers=None,
4242
stream=None,
4343
auth=None,
4444
timeout=None,
4545
parse_headers=True,
4646
):
47+
if headers is None:
48+
headers = {}
49+
4750
redirect = None # redirection url, None means no redirection
4851
chunked_data = data and getattr(data, "__next__", None) and not getattr(data, "__len__", None)
4952

@@ -94,33 +97,49 @@ def request(
9497
context.verify_mode = tls.CERT_NONE
9598
s = context.wrap_socket(s, server_hostname=host)
9699
s.write(b"%s /%s HTTP/1.0\r\n" % (method, path))
100+
97101
if "Host" not in headers:
98-
s.write(b"Host: %s\r\n" % host)
99-
# Iterate over keys to avoid tuple alloc
100-
for k in headers:
101-
s.write(k)
102-
s.write(b": ")
103-
s.write(headers[k])
104-
s.write(b"\r\n")
102+
headers["Host"] = host
103+
105104
if json is not None:
106105
assert data is None
107106
import ujson
108107

109108
data = ujson.dumps(json)
110-
s.write(b"Content-Type: application/json\r\n")
109+
110+
if "Content-Type" not in headers:
111+
headers["Content-Type"] = "application/json"
112+
111113
if data:
112114
if chunked_data:
113-
s.write(b"Transfer-Encoding: chunked\r\n")
114-
else:
115-
s.write(b"Content-Length: %d\r\n" % len(data))
116-
s.write(b"Connection: close\r\n\r\n")
115+
if "Transfer-Encoding" not in headers and "Content-Length" not in headers:
116+
headers["Transfer-Encoding"] = "chunked"
117+
elif "Content-Length" not in headers:
118+
headers["Content-Length"] = str(len(data))
119+
120+
if "Connection" not in headers:
121+
headers["Connection"] = "close"
122+
123+
# Iterate over keys to avoid tuple alloc
124+
for k in headers:
125+
s.write(k)
126+
s.write(b": ")
127+
s.write(headers[k])
128+
s.write(b"\r\n")
129+
130+
s.write(b"\r\n")
131+
117132
if data:
118133
if chunked_data:
119-
for chunk in data:
120-
s.write(b"%x\r\n" % len(chunk))
121-
s.write(chunk)
122-
s.write(b"\r\n")
123-
s.write("0\r\n\r\n")
134+
if headers.get("Transfer-Encoding", None) == "chunked":
135+
for chunk in data:
136+
s.write(b"%x\r\n" % len(chunk))
137+
s.write(chunk)
138+
s.write(b"\r\n")
139+
s.write("0\r\n\r\n")
140+
else:
141+
for chunk in data:
142+
s.write(chunk)
124143
else:
125144
s.write(data)
126145

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import io
2+
import sys
3+
4+
5+
class Socket:
6+
def __init__(self):
7+
self._write_buffer = io.BytesIO()
8+
self._read_buffer = io.BytesIO(b"HTTP/1.0 200 OK\r\n\r\n")
9+
10+
def connect(self, address):
11+
pass
12+
13+
def write(self, buf):
14+
self._write_buffer.write(buf)
15+
16+
def readline(self):
17+
return self._read_buffer.readline()
18+
19+
20+
class usocket:
21+
AF_INET = 2
22+
SOCK_STREAM = 1
23+
IPPROTO_TCP = 6
24+
25+
@staticmethod
26+
def getaddrinfo(host, port, af=0, type=0, flags=0):
27+
return [(usocket.AF_INET, usocket.SOCK_STREAM, usocket.IPPROTO_TCP, "", ("127.0.0.1", 80))]
28+
29+
def socket(af=AF_INET, type=SOCK_STREAM, proto=IPPROTO_TCP):
30+
return Socket()
31+
32+
33+
sys.modules["usocket"] = usocket
34+
# ruff: noqa: E402
35+
import requests
36+
37+
38+
def format_message(response):
39+
return response.raw._write_buffer.getvalue().decode("utf8")
40+
41+
42+
def test_simple_get():
43+
response = requests.request("GET", "http://example.com")
44+
45+
assert response.raw._write_buffer.getvalue() == (
46+
b"GET / HTTP/1.0\r\n" + b"Connection: close\r\n" + b"Host: example.com\r\n\r\n"
47+
), format_message(response)
48+
49+
50+
def test_get_auth():
51+
response = requests.request(
52+
"GET", "http://example.com", auth=("test-username", "test-password")
53+
)
54+
55+
assert response.raw._write_buffer.getvalue() == (
56+
b"GET / HTTP/1.0\r\n"
57+
+ b"Host: example.com\r\n"
58+
+ b"Authorization: Basic dGVzdC11c2VybmFtZTp0ZXN0LXBhc3N3b3Jk\r\n"
59+
+ b"Connection: close\r\n\r\n"
60+
), format_message(response)
61+
62+
63+
def test_get_custom_header():
64+
response = requests.request("GET", "http://example.com", headers={"User-Agent": "test-agent"})
65+
66+
assert response.raw._write_buffer.getvalue() == (
67+
b"GET / HTTP/1.0\r\n"
68+
+ b"User-Agent: test-agent\r\n"
69+
+ b"Host: example.com\r\n"
70+
+ b"Connection: close\r\n\r\n"
71+
), format_message(response)
72+
73+
74+
def test_post_json():
75+
response = requests.request("GET", "http://example.com", json="test")
76+
77+
assert response.raw._write_buffer.getvalue() == (
78+
b"GET / HTTP/1.0\r\n"
79+
+ b"Connection: close\r\n"
80+
+ b"Content-Type: application/json\r\n"
81+
+ b"Host: example.com\r\n"
82+
+ b"Content-Length: 6\r\n\r\n"
83+
+ b'"test"'
84+
), format_message(response)
85+
86+
87+
def test_post_chunked_data():
88+
def chunks():
89+
yield "test"
90+
91+
response = requests.request("GET", "http://example.com", data=chunks())
92+
93+
assert response.raw._write_buffer.getvalue() == (
94+
b"GET / HTTP/1.0\r\n"
95+
+ b"Transfer-Encoding: chunked\r\n"
96+
+ b"Host: example.com\r\n"
97+
+ b"Connection: close\r\n\r\n"
98+
+ b"4\r\ntest\r\n"
99+
+ b"0\r\n\r\n"
100+
), format_message(response)
101+
102+
103+
def test_overwrite_get_headers():
104+
response = requests.request(
105+
"GET", "http://example.com", headers={"Connection": "keep-alive", "Host": "test.com"}
106+
)
107+
108+
assert response.raw._write_buffer.getvalue() == (
109+
b"GET / HTTP/1.0\r\n" + b"Host: test.com\r\n" + b"Connection: keep-alive\r\n\r\n"
110+
), format_message(response)
111+
112+
113+
def test_overwrite_post_json_headers():
114+
response = requests.request(
115+
"GET",
116+
"http://example.com",
117+
json="test",
118+
headers={"Content-Type": "text/plain", "Content-Length": "10"},
119+
)
120+
121+
assert response.raw._write_buffer.getvalue() == (
122+
b"GET / HTTP/1.0\r\n"
123+
+ b"Connection: close\r\n"
124+
+ b"Content-Length: 10\r\n"
125+
+ b"Content-Type: text/plain\r\n"
126+
+ b"Host: example.com\r\n\r\n"
127+
+ b'"test"'
128+
), format_message(response)
129+
130+
131+
def test_overwrite_post_chunked_data_headers():
132+
def chunks():
133+
yield "test"
134+
135+
response = requests.request(
136+
"GET", "http://example.com", data=chunks(), headers={"Content-Length": "4"}
137+
)
138+
139+
assert response.raw._write_buffer.getvalue() == (
140+
b"GET / HTTP/1.0\r\n"
141+
+ b"Host: example.com\r\n"
142+
+ b"Content-Length: 4\r\n"
143+
+ b"Connection: close\r\n\r\n"
144+
+ b"test"
145+
), format_message(response)
146+
147+
148+
test_simple_get()
149+
test_get_auth()
150+
test_get_custom_header()
151+
test_post_json()
152+
test_post_chunked_data()
153+
test_overwrite_get_headers()
154+
test_overwrite_post_json_headers()
155+
test_overwrite_post_chunked_data_headers()

0 commit comments

Comments
 (0)