Skip to content

Commit a064a49

Browse files
committed
Add options to customize Thrift transport and requests kwargs (#140)
Closes #122, closes #132, closes #135
1 parent 2b0fbd6 commit a064a49

File tree

4 files changed

+156
-47
lines changed

4 files changed

+156
-47
lines changed

pyhive/hive.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,25 @@ def connect(*args, **kwargs):
6868
class Connection(object):
6969
"""Wraps a Thrift session"""
7070

71-
def __init__(self, host, port=10000, username=None, database='default', auth='NONE',
72-
configuration=None, kerberos_service_name=None, password=None):
71+
def __init__(self, host=None, port=None, username=None, database='default', auth=None,
72+
configuration=None, kerberos_service_name=None, password=None,
73+
thrift_transport=None):
7374
"""Connect to HiveServer2
7475
75-
:param auth: The value of hive.server2.authentication used by HiveServer2
76+
:param host: What host HiveServer2 runs on
77+
:param port: What port HiveServer2 runs on. Defaults to 10000.
78+
:param auth: The value of hive.server2.authentication used by HiveServer2.
79+
Defaults to ``NONE``.
7680
:param configuration: A dictionary of Hive settings (functionally same as the `set` command)
7781
:param kerberos_service_name: Use with auth='KERBEROS' only
7882
:param password: Use with auth='LDAP' only
83+
:param thrift_transport: A ``TTransportBase`` for custom advanced usage.
84+
Incompatible with host, port, auth, kerberos_service_name, and password.
7985
8086
The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
8187
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
8288
/impala/_thrift_api.py#L152-L160
8389
"""
84-
socket = thrift.transport.TSocket.TSocket(host, port)
8590
username = username or getpass.getuser()
8691
configuration = configuration or {}
8792

@@ -90,37 +95,56 @@ def __init__(self, host, port=10000, username=None, database='default', auth='NO
9095
"Remove password or add auth='LDAP'")
9196
if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
9297
raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
98+
if thrift_transport is not None:
99+
has_incompatible_arg = (
100+
host is not None
101+
or port is not None
102+
or auth is not None
103+
or kerberos_service_name is not None
104+
or password is not None
105+
)
106+
if has_incompatible_arg:
107+
raise ValueError("thrift_transport cannot be used with "
108+
"host/port/auth/kerberos_service_name/password")
93109

94-
if auth == 'NOSASL':
95-
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
96-
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
97-
elif auth in ('LDAP', 'KERBEROS', 'NONE'):
98-
if auth == 'KERBEROS':
99-
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
100-
sasl_auth = 'GSSAPI'
101-
else:
102-
sasl_auth = 'PLAIN'
103-
if password is None:
104-
# Password doesn't matter in NONE mode, just needs to be nonempty.
105-
password = 'x'
106-
107-
def sasl_factory():
108-
sasl_client = sasl.Client()
109-
sasl_client.setAttr('host', host)
110-
if sasl_auth == 'GSSAPI':
111-
sasl_client.setAttr('service', kerberos_service_name)
112-
elif sasl_auth == 'PLAIN':
113-
sasl_client.setAttr('username', username)
114-
sasl_client.setAttr('password', password)
115-
else:
116-
raise AssertionError
117-
sasl_client.init()
118-
return sasl_client
119-
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
110+
if thrift_transport is not None:
111+
self._transport = thrift_transport
120112
else:
121-
raise NotImplementedError(
122-
"Only NONE, NOSASL, LDAP, KERBEROS "
123-
"authentication are supported, got {}".format(auth))
113+
if port is None:
114+
port = 10000
115+
if auth is None:
116+
auth = 'NONE'
117+
socket = thrift.transport.TSocket.TSocket(host, port)
118+
if auth == 'NOSASL':
119+
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
120+
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
121+
elif auth in ('LDAP', 'KERBEROS', 'NONE'):
122+
if auth == 'KERBEROS':
123+
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
124+
sasl_auth = 'GSSAPI'
125+
else:
126+
sasl_auth = 'PLAIN'
127+
if password is None:
128+
# Password doesn't matter in NONE mode, just needs to be nonempty.
129+
password = 'x'
130+
131+
def sasl_factory():
132+
sasl_client = sasl.Client()
133+
sasl_client.setAttr('host', host)
134+
if sasl_auth == 'GSSAPI':
135+
sasl_client.setAttr('service', kerberos_service_name)
136+
elif sasl_auth == 'PLAIN':
137+
sasl_client.setAttr('username', username)
138+
sasl_client.setAttr('password', password)
139+
else:
140+
raise AssertionError
141+
sasl_client.init()
142+
return sasl_client
143+
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
144+
else:
145+
raise NotImplementedError(
146+
"Only NONE, NOSASL, LDAP, KERBEROS "
147+
"authentication are supported, got {}".format(auth))
124148

125149
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
126150
self._client = TCLIService.Client(protocol)

pyhive/presto.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import absolute_import
99
from __future__ import unicode_literals
10+
1011
from builtins import object
1112
from pyhive import common
1213
from pyhive.common import DBAPITypeObject
@@ -79,7 +80,7 @@ class Cursor(common.DBAPICursor):
7980

8081
def __init__(self, host, port='8080', username=None, catalog='hive',
8182
schema='default', poll_interval=1, source='pyhive', session_props=None,
82-
protocol='http', password=None):
83+
protocol='http', password=None, requests_session=None, requests_kwargs=None):
8384
"""
8485
:param host: hostname to connect to, e.g. ``presto.example.com``
8586
:param port: int -- port, defaults to 8080
@@ -91,29 +92,45 @@ def __init__(self, host, port='8080', username=None, catalog='hive',
9192
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
9293
:param protocol: string -- network protocol, valid options are ``http`` and ``https``.
9394
defaults to ``http``
94-
:param password: string -- defaults to ``None``, using BasicAuth, requires ``https``
95+
:param password: string -- Deprecated. Defaults to ``None``.
96+
Using BasicAuth, requires ``https``.
97+
Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``.
98+
May not be specified with ``requests_kwargs``.
99+
:param requests_session: a ``requests.Session`` object for advanced usage. If absent, this
100+
class will use the default requests behavior of making a new session per HTTP request.
101+
Caller is responsible for closing session.
102+
:param requests_kwargs: Additional ``**kwargs`` to pass to requests
95103
"""
96104
super(Cursor, self).__init__(poll_interval)
97105
# Config
98106
self._host = host
99107
self._port = port
100108
self._username = username or getpass.getuser()
101-
self._password = password
102109
self._catalog = catalog
103110
self._schema = schema
104111
self._arraysize = 1
105112
self._poll_interval = poll_interval
106113
self._source = source
107114
self._session_props = session_props if session_props is not None else {}
115+
108116
if protocol not in ('http', 'https'):
109117
raise ValueError("Protocol must be http/https, was {!r}".format(protocol))
110118
self._protocol = protocol
111-
if password is None:
112-
self._auth = None
113-
else:
114-
self._auth = HTTPBasicAuth(username, self._password)
119+
120+
self._requests_session = requests_session or requests
121+
122+
if password is not None and requests_kwargs is not None:
123+
raise ValueError("Cannot use both password and requests_kwargs")
124+
requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {}
125+
for k in ('method', 'url', 'data', 'headers'):
126+
if k in requests_kwargs:
127+
raise ValueError("Cannot override requests argument {}".format(k))
128+
if password is not None:
129+
requests_kwargs['auth'] = HTTPBasicAuth(username, password)
115130
if protocol != 'https':
116131
raise ValueError("Protocol must be https when passing a password")
132+
self._requests_kwargs = requests_kwargs
133+
117134
self._reset_state()
118135

119136
def _reset_state(self):
@@ -184,7 +201,8 @@ def execute(self, operation, parameters=None):
184201
'{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None))
185202
_logger.info('%s', sql)
186203
_logger.debug("Headers: %s", headers)
187-
response = requests.post(url, data=sql.encode('utf-8'), headers=headers, auth=self._auth)
204+
response = self._requests_session.post(
205+
url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
188206
self._process_response(response)
189207

190208
def cancel(self):
@@ -194,7 +212,7 @@ def cancel(self):
194212
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
195213
return
196214

197-
response = requests.delete(self._nextUri, auth=self._auth)
215+
response = self._requests_session.delete(self._nextUri, **self._requests_kwargs)
198216
if response.status_code != requests.codes.no_content:
199217
fmt = "Unexpected status code after cancel {}\n{}"
200218
raise OperationalError(fmt.format(response.status_code, response.content))
@@ -216,13 +234,13 @@ def poll(self):
216234
if self._nextUri is None:
217235
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
218236
return None
219-
response = requests.get(self._nextUri, auth=self._auth)
237+
response = self._requests_session.get(self._nextUri, **self._requests_kwargs)
220238
self._process_response(response)
221239
return response.json()
222240

223241
def _fetch_more(self):
224242
"""Fetch the next URI and update state"""
225-
self._process_response(requests.get(self._nextUri, auth=self._auth))
243+
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))
226244

227245
def _decode_binary(self, rows):
228246
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format

pyhive/tests/test_hive.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
import mock
1616
import os
17-
from TCLIService import ttypes
18-
from pyhive import hive
17+
import sasl
18+
import thrift.transport.TSocket
19+
import thrift.transport.TTransport
20+
import thrift_sasl
1921
from thrift.transport.TTransport import TTransportException
2022

23+
from TCLIService import ttypes
24+
from pyhive import hive
2125
from pyhive.tests.dbapi_test_case import DBAPITestCase
2226
from pyhive.tests.dbapi_test_case import with_cursor
2327

@@ -185,3 +189,30 @@ def test_invalid_kerberos_config(self):
185189
lambda: hive.connect(_HOST, kerberos_service_name=''))
186190
self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
187191
lambda: hive.connect(_HOST, auth='KERBEROS'))
192+
193+
def test_invalid_transport(self):
194+
"""transport and auth are incompatible"""
195+
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
196+
transport = thrift.transport.TTransport.TBufferedTransport(socket)
197+
self.assertRaisesRegexp(
198+
ValueError, 'thrift_transport cannot be used with',
199+
lambda: hive.connect(_HOST, thrift_transport=transport)
200+
)
201+
202+
def test_custom_transport(self):
203+
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
204+
sasl_auth = 'PLAIN'
205+
206+
def sasl_factory():
207+
sasl_client = sasl.Client()
208+
sasl_client.setAttr('host', 'localhost')
209+
sasl_client.setAttr('username', 'test_username')
210+
sasl_client.setAttr('password', 'x')
211+
sasl_client.init()
212+
return sasl_client
213+
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
214+
conn = hive.connect(thrift_transport=transport)
215+
with contextlib.closing(conn):
216+
with contextlib.closing(conn.cursor()) as cursor:
217+
cursor.execute('SELECT * FROM one_row')
218+
self.assertEqual(cursor.fetchall(), [(1,)])

pyhive/tests/test_presto.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import contextlib
1111
import os
12+
import requests
1213

1314
from pyhive import exc
1415
from pyhive import presto
@@ -157,7 +158,7 @@ def test_set_session(self, cursor):
157158
session_prop = rows[0]
158159
assert session_prop[1] != '1234m'
159160

160-
def test_set_session_in_consructor(self):
161+
def test_set_session_in_constructor(self):
161162
conn = presto.connect(
162163
host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'}
163164
)
@@ -184,3 +185,38 @@ def test_invalid_protocol_config(self):
184185
ValueError, 'Protocol.*https.*password', lambda: presto.connect(
185186
host=_HOST, username='user', password='secret', protocol='http').cursor()
186187
)
188+
189+
def test_invalid_password_and_kwargs(self):
190+
"""password and requests_kwargs are incompatible"""
191+
self.assertRaisesRegexp(
192+
ValueError, 'Cannot use both', lambda: presto.connect(
193+
host=_HOST, username='user', password='secret', protocol='https',
194+
requests_kwargs={}
195+
).cursor()
196+
)
197+
198+
def test_invalid_kwargs(self):
199+
"""some kwargs are reserved"""
200+
self.assertRaisesRegexp(
201+
ValueError, 'Cannot override', lambda: presto.connect(
202+
host=_HOST, username='user', requests_kwargs={'url': 'test'}
203+
).cursor()
204+
)
205+
206+
def test_requests_kwargs(self):
207+
connection = presto.connect(
208+
host=_HOST, port=_PORT, source=self.id(),
209+
requests_kwargs={'proxies': {'http': 'localhost:99999'}},
210+
)
211+
cursor = connection.cursor()
212+
self.assertRaises(requests.exceptions.ProxyError,
213+
lambda: cursor.execute('SELECT * FROM one_row'))
214+
215+
def test_requests_session(self):
216+
with requests.Session() as session:
217+
connection = presto.connect(
218+
host=_HOST, port=_PORT, source=self.id(), requests_session=session
219+
)
220+
cursor = connection.cursor()
221+
cursor.execute('SELECT * FROM one_row')
222+
self.assertEqual(cursor.fetchall(), [(1,)])

0 commit comments

Comments
 (0)