7
7
8
8
from __future__ import absolute_import
9
9
from __future__ import unicode_literals
10
+
10
11
from builtins import object
11
12
from pyhive import common
12
13
from pyhive .common import DBAPITypeObject
@@ -79,7 +80,7 @@ class Cursor(common.DBAPICursor):
79
80
80
81
def __init__ (self , host , port = '8080' , username = None , catalog = 'hive' ,
81
82
schema = 'default' , poll_interval = 1 , source = 'pyhive' , session_props = None ,
82
- protocol = 'http' , password = None ):
83
+ protocol = 'http' , password = None , requests_session = None , requests_kwargs = None ):
83
84
"""
84
85
:param host: hostname to connect to, e.g. ``presto.example.com``
85
86
:param port: int -- port, defaults to 8080
@@ -91,29 +92,45 @@ def __init__(self, host, port='8080', username=None, catalog='hive',
91
92
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
92
93
:param protocol: string -- network protocol, valid options are ``http`` and ``https``.
93
94
defaults to ``http``
94
- :param password: string -- defaults to ``None``, using BasicAuth, requires ``https``
95
+ :param password: string -- Deprecated. Defaults to ``None``.
96
+ Using BasicAuth, requires ``https``.
97
+ Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``.
98
+ May not be specified with ``requests_kwargs``.
99
+ :param requests_session: a ``requests.Session`` object for advanced usage. If absent, this
100
+ class will use the default requests behavior of making a new session per HTTP request.
101
+ Caller is responsible for closing session.
102
+ :param requests_kwargs: Additional ``**kwargs`` to pass to requests
95
103
"""
96
104
super (Cursor , self ).__init__ (poll_interval )
97
105
# Config
98
106
self ._host = host
99
107
self ._port = port
100
108
self ._username = username or getpass .getuser ()
101
- self ._password = password
102
109
self ._catalog = catalog
103
110
self ._schema = schema
104
111
self ._arraysize = 1
105
112
self ._poll_interval = poll_interval
106
113
self ._source = source
107
114
self ._session_props = session_props if session_props is not None else {}
115
+
108
116
if protocol not in ('http' , 'https' ):
109
117
raise ValueError ("Protocol must be http/https, was {!r}" .format (protocol ))
110
118
self ._protocol = protocol
111
- if password is None :
112
- self ._auth = None
113
- else :
114
- self ._auth = HTTPBasicAuth (username , self ._password )
119
+
120
+ self ._requests_session = requests_session or requests
121
+
122
+ if password is not None and requests_kwargs is not None :
123
+ raise ValueError ("Cannot use both password and requests_kwargs" )
124
+ requests_kwargs = dict (requests_kwargs ) if requests_kwargs is not None else {}
125
+ for k in ('method' , 'url' , 'data' , 'headers' ):
126
+ if k in requests_kwargs :
127
+ raise ValueError ("Cannot override requests argument {}" .format (k ))
128
+ if password is not None :
129
+ requests_kwargs ['auth' ] = HTTPBasicAuth (username , password )
115
130
if protocol != 'https' :
116
131
raise ValueError ("Protocol must be https when passing a password" )
132
+ self ._requests_kwargs = requests_kwargs
133
+
117
134
self ._reset_state ()
118
135
119
136
def _reset_state (self ):
@@ -184,7 +201,8 @@ def execute(self, operation, parameters=None):
184
201
'{}:{}' .format (self ._host , self ._port ), '/v1/statement' , None , None , None ))
185
202
_logger .info ('%s' , sql )
186
203
_logger .debug ("Headers: %s" , headers )
187
- response = requests .post (url , data = sql .encode ('utf-8' ), headers = headers , auth = self ._auth )
204
+ response = self ._requests_session .post (
205
+ url , data = sql .encode ('utf-8' ), headers = headers , ** self ._requests_kwargs )
188
206
self ._process_response (response )
189
207
190
208
def cancel (self ):
@@ -194,7 +212,7 @@ def cancel(self):
194
212
assert self ._state == self ._STATE_FINISHED , "Should be finished if nextUri is None"
195
213
return
196
214
197
- response = requests . delete (self ._nextUri , auth = self ._auth )
215
+ response = self . _requests_session . delete (self ._nextUri , ** self ._requests_kwargs )
198
216
if response .status_code != requests .codes .no_content :
199
217
fmt = "Unexpected status code after cancel {}\n {}"
200
218
raise OperationalError (fmt .format (response .status_code , response .content ))
@@ -216,13 +234,13 @@ def poll(self):
216
234
if self ._nextUri is None :
217
235
assert self ._state == self ._STATE_FINISHED , "Should be finished if nextUri is None"
218
236
return None
219
- response = requests . get (self ._nextUri , auth = self ._auth )
237
+ response = self . _requests_session . get (self ._nextUri , ** self ._requests_kwargs )
220
238
self ._process_response (response )
221
239
return response .json ()
222
240
223
241
def _fetch_more (self ):
224
242
"""Fetch the next URI and update state"""
225
- self ._process_response (requests . get (self ._nextUri , auth = self ._auth ))
243
+ self ._process_response (self . _requests_session . get (self ._nextUri , ** self ._requests_kwargs ))
226
244
227
245
def _decode_binary (self , rows ):
228
246
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
0 commit comments