Skip to content

Allow HTTPS requests to use tls_ciphers #20179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions airflow/tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def test_service_checks_healthy_exp(aggregator, json_resp, expected_healthy_stat
check = AirflowCheck('airflow', common.FULL_CONFIG, [instance])

with mock.patch('datadog_checks.airflow.airflow.AirflowCheck._get_version', return_value=None):

with mock.patch('datadog_checks.base.utils.http.requests') as req:
mock_session = mock.MagicMock()
with mock.patch('datadog_checks.base.utils.http.requests.Session', return_value=mock_session):
mock_resp = mock.MagicMock(status_code=200)
mock_resp.json.side_effect = [json_resp]
req.get.return_value = mock_resp
mock_session.get.return_value = mock_resp

check.check(None)

Expand All @@ -60,14 +60,14 @@ def test_service_checks_healthy_stable(
check = AirflowCheck('airflow', common.FULL_CONFIG, [instance])

with mock.patch('datadog_checks.airflow.airflow.AirflowCheck._get_version', return_value='2.6.2'):

with mock.patch('datadog_checks.base.utils.http.requests') as req:
mock_session = mock.MagicMock()
with mock.patch('datadog_checks.base.utils.http.requests.Session', return_value=mock_session):
mock_resp = mock.MagicMock(status_code=200)
mock_resp.json.side_effect = [
{'metadatabase': {'status': metadb_status}, 'scheduler': {'status': scheduler_status}},
{'status': 'OK'},
]
req.get.return_value = mock_resp
mock_session.get.return_value = mock_resp

check.check(None)

Expand Down
2 changes: 1 addition & 1 deletion amazon_msk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def mock_requests_get(url, *args, **kwargs):

@pytest.fixture
def mock_data():
with mock.patch('requests.get', side_effect=mock_requests_get, autospec=True):
with mock.patch('requests.Session.get', side_effect=mock_requests_get, autospec=True):
yield


Expand Down
2 changes: 1 addition & 1 deletion arangodb/tests/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def mock_requests_get(url, *args, **kwargs):
fixture = url.rsplit('/', 1)[-1]
return MockResponse(file_path=os.path.join(os.path.dirname(__file__), 'fixtures', tag_condition, fixture))

with mock.patch('requests.get', side_effect=mock_requests_get, autospec=True):
with mock.patch('requests.Session.get', side_effect=mock_requests_get, autospec=True):
dd_run_check(check)

aggregator.assert_service_check(
Expand Down
4 changes: 2 additions & 2 deletions cert_manager/tests/test_cert_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@pytest.fixture()
def error_metrics():
with mock.patch(
'requests.get',
'requests.Session.get',
return_value=mock.MagicMock(status_code=502, headers={'Content-Type': "text/plain"}),
):
yield
Expand All @@ -34,7 +34,7 @@ def test_check(aggregator, dd_run_check):
def mock_requests_get(url, *args, **kwargs):
return MockResponse(file_path=os.path.join(os.path.dirname(__file__), 'fixtures', 'cert_manager.txt'))

with mock.patch('requests.get', side_effect=mock_requests_get, autospec=True):
with mock.patch('requests.Session.get', side_effect=mock_requests_get, autospec=True):
dd_run_check(check)

expected_metrics = dict(CERT_METRICS)
Expand Down
2 changes: 1 addition & 1 deletion citrix_hypervisor/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def mock_requests_get(url, *args, **kwargs):

@pytest.fixture
def mock_responses():
with mock.patch('requests.get', side_effect=mock_requests_get):
with mock.patch('requests.Session.get', side_effect=mock_requests_get):
yield
115 changes: 85 additions & 30 deletions datadog_checks_base/datadog_checks/base/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@

from datadog_checks.base.agent import datadog_agent
from datadog_checks.base.utils import _http_utils
from datadog_checks.base.utils.tls import TlsContextWrapper

from ..config import is_affirmative
from ..errors import ConfigurationError
from .common import ensure_bytes, ensure_unicode
from .headers import get_default_headers, update_headers
from .network import CertAdapter, create_socket_connection
from .network import create_socket_connection
from .time import get_timestamp

# See Performance Optimizations in this package's README.md.
Expand Down Expand Up @@ -85,12 +86,14 @@
'skip_proxy': False,
'tls_ca_cert': None,
'tls_cert': None,
'tls_ciphers': 'ALL',
'tls_use_host_header': False,
'tls_ignore_warning': False,
'tls_private_key': None,
'tls_private_key_password': None,
'tls_protocols_allowed': DEFAULT_PROTOCOL_VERSIONS,
'tls_validate_hostname': True,
'tls_verify': True,
'tls_ciphers': 'ALL',
'timeout': DEFAULT_TIMEOUT,
'use_legacy_auth_encoding': True,
'username': None,
Expand All @@ -115,6 +118,22 @@
UDS_SCHEME = 'unix'


class HTTPAdapterWrapper(requests.adapters.HTTPAdapter):
"""
HTTPS adapter that uses TlsContextWrapper to create SSL contexts.
This ensures consistent TLS configuration across all HTTPS requests.
"""

def __init__(self, tls_context_wrapper, **kwargs):
self.tls_context_wrapper = tls_context_wrapper
super(HTTPAdapterWrapper, self).__init__()

def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
# Use the TLS context from TlsContextWrapper
pool_kwargs['ssl_context'] = self.tls_context_wrapper.tls_context
return super(HTTPAdapterWrapper, self).init_poolmanager(connections, maxsize, block=block, **pool_kwargs)


class ResponseWrapper(ObjectProxy):
def __init__(self, response, default_chunk_size):
super(ResponseWrapper, self).__init__(response)
Expand Down Expand Up @@ -152,7 +171,7 @@ class RequestsWrapper(object):
'auth_token_handler',
'request_size',
'tls_protocols_allowed',
'tls_ciphers_allowed',
'tls_context_wrapper',
)

def __init__(self, instance, init_config, remapper=None, logger=None, session=None):
Expand Down Expand Up @@ -254,7 +273,8 @@ def __init__(self, instance, init_config, remapper=None, logger=None, session=No

allow_redirects = is_affirmative(config['allow_redirects'])

# https://requests.readthedocs.io/en/latest/user/advanced/#ssl-cert-verification
# For TLS verification, we now rely on the TLS context wrapper
# but still need to set verify for requests compatibility
verify = True
if isinstance(config['tls_ca_cert'], str):
verify = config['tls_ca_cert']
Expand Down Expand Up @@ -347,13 +367,8 @@ def __init__(self, instance, init_config, remapper=None, logger=None, session=No
if config['kerberos_cache']:
self.request_hooks.append(lambda: handle_kerberos_cache(config['kerberos_cache']))

ciphers = config.get('tls_ciphers')
if ciphers:
if 'ALL' in ciphers:
updated_ciphers = "ALL"
else:
updated_ciphers = ":".join(ciphers)
self.tls_ciphers_allowed = updated_ciphers
# Create TLS context wrapper for consistent TLS configuration
self.tls_context_wrapper = TlsContextWrapper(config, remapper)

def get(self, url, **options):
return self._request('get', url, options)
Expand Down Expand Up @@ -397,6 +412,11 @@ def _request(self, method, url, options):
new_options['headers'] = new_options['headers'].copy()
new_options['headers'].update(extra_headers)

if new_options['verify'] != self.tls_context_wrapper.config['tls_verify']:
# The verify option needs to be synchronized
self.tls_context_wrapper.config['tls_verify'] = new_options['verify']
self.tls_context_wrapper.refresh_tls_context()

if is_uds_url(url):
persist = True # UDS support is only enabled on the shared session.
url = quote_uds_url(url)
Expand All @@ -409,7 +429,8 @@ def _request(self, method, url, options):
if persist:
request_method = getattr(self.session, method)
else:
request_method = getattr(requests, method)
# Create a new session for non-persistent requests
request_method = getattr(self.create_session(), method)

if self.auth_token_handler:
try:
Expand All @@ -435,16 +456,18 @@ def make_request_aia_chasing(self, request_method, method, url, new_options, per
certs = self.fetch_intermediate_certs(hostname, port)
if not certs:
raise e
self.tls_context_wrapper.config['tls_ca_cert'] = certs
self.tls_context_wrapper.refresh_tls_context()
# retry the connection via session object
certadapter = CertAdapter(certs=certs)
if not persist:
session = requests.Session()
for option, value in self.options.items():
setattr(session, option, value)
certadapter = HTTPAdapterWrapper(self.tls_context_wrapper)
session.mount(url, certadapter)
else:
session = self.session
request_method = getattr(session, method)
session.mount(url, certadapter)
response = request_method(url, **new_options)
return response

Expand Down Expand Up @@ -472,8 +495,17 @@ def fetch_intermediate_certs(self, hostname, port=443):
with sock:
try:
context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS)
# Override verify mode for intermediate cert discovery
context.verify_mode = ssl.CERT_NONE
context.set_ciphers(self.tls_ciphers_allowed)
# Set the ciphers
ciphers = self.tls_context_wrapper.config.get('tls_ciphers')
if ciphers:
if 'ALL' in ciphers:
updated_ciphers = "ALL"
else:
updated_ciphers = ":".join(ciphers)

context.set_ciphers(updated_ciphers)

with context.wrap_socket(sock, server_hostname=hostname) as secure_sock:
der_cert = secure_sock.getpeercert(binary_form=True)
Expand Down Expand Up @@ -521,7 +553,7 @@ def load_intermediate_certs(self, der_cert, certs):

# Assume HTTP for now
try:
response = requests.get(uri) # SKIP_HTTP_VALIDATION
response = self.get(uri, verify=False) # SKIP_HTTP_VALIDATION
except Exception as e:
self.logger.error('Error fetching intermediate certificate from `%s`: %s', uri, e)
continue
Expand All @@ -532,23 +564,46 @@ def load_intermediate_certs(self, der_cert, certs):
self.load_intermediate_certs(intermediate_cert, certs)
return certs

@property
def session(self):
if self._session is None:
self._session = requests.Session()
def create_session(self):
session = requests.Session()

# Use TlsContextHTTPSAdapter for consistent TLS configuration
https_adapter = HTTPAdapterWrapper(self.tls_context_wrapper)

# Enables HostHeaderSSLAdapter if needed
# https://toolbelt.readthedocs.io/en/latest/adapters.html#hostheaderssladapter
if self.tls_use_host_header:
# Create a combined adapter that supports both TLS context and host headers
class TlsContextHostHeaderAdapter(HTTPAdapterWrapper, _http_utils.HostHeaderSSLAdapter):
def __init__(self, tls_context_wrapper, **kwargs):
HTTPAdapterWrapper.__init__(self, tls_context_wrapper, **kwargs)
_http_utils.HostHeaderSSLAdapter.__init__(self, **kwargs)

def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
# Use TLS context from wrapper
pool_kwargs['ssl_context'] = self.tls_context_wrapper.tls_context
return _http_utils.HostHeaderSSLAdapter.init_poolmanager(
self, connections, maxsize, block=block, **pool_kwargs
)

# Enables HostHeaderSSLAdapter
# https://toolbelt.readthedocs.io/en/latest/adapters.html#hostheaderssladapter
if self.tls_use_host_header:
self._session.mount('https://', _http_utils.HostHeaderSSLAdapter())
# Enable Unix Domain Socket (UDS) support.
# See: https://github.com/msabramo/requests-unixsocket
self._session.mount('{}://'.format(UDS_SCHEME), requests_unixsocket.UnixAdapter())
https_adapter = TlsContextHostHeaderAdapter(self.tls_context_wrapper)

# Attributes can't be passed to the constructor
for option, value in self.options.items():
setattr(self._session, option, value)
session.mount('https://', https_adapter)

# Enable Unix Domain Socket (UDS) support.
# See: https://github.com/msabramo/requests-unixsocket
session.mount('{}://'.format(UDS_SCHEME), requests_unixsocket.UnixAdapter())

# Attributes can't be passed to the constructor
for option, value in self.options.items():
setattr(session, option, value)
return session

@property
def session(self):
if self._session is None:
# Create a new session if it doesn't exist
self._session = self.create_session()
return self._session

def handle_auth_token(self, **request):
Expand Down
16 changes: 0 additions & 16 deletions datadog_checks_base/datadog_checks/base/utils/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import socket
import ssl

from requests.adapters import HTTPAdapter, PoolManager


def create_socket_connection(hostname, port=443, sock_type=socket.SOCK_STREAM, timeout=10):
Expand Down Expand Up @@ -34,16 +31,3 @@ def create_socket_connection(hostname, port=443, sock_type=socket.SOCK_STREAM, t
raise socket.error('Unable to resolve host, check your DNS: {}'.format(message)) # noqa: G

raise


class CertAdapter(HTTPAdapter):
def __init__(self, **kwargs):
self.certs = kwargs['certs']
super(CertAdapter, self).__init__()

def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
context = ssl.create_default_context()
for cert in self.certs:
context.load_verify_locations(cadata=cert)
pool_kwargs['ssl_context'] = context
self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, block=block, strict=True, **pool_kwargs)
15 changes: 7 additions & 8 deletions datadog_checks_base/datadog_checks/base/utils/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,17 @@ def _create_tls_context(self):
# type: () -> ssl.SSLContext

# https://docs.python.org/3/library/ssl.html#ssl.SSLContext
# https://docs.python.org/3/library/ssl.html#ssl.PROTOCOL_TLS
context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS)
# https://docs.python.org/3/library/ssl.html#ssl.PROTOCOL_TLS_CLIENT
context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)

# https://docs.python.org/3/library/ssl.html#ssl.SSLContext.check_hostname
context.check_hostname = (
False if not self.config['tls_verify'] else self.config.get('tls_validate_hostname', True)
)

# https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode
context.verify_mode = ssl.CERT_REQUIRED if self.config['tls_verify'] else ssl.CERT_NONE

# https://docs.python.org/3/library/ssl.html#ssl.SSLContext.check_hostname
if context.verify_mode == ssl.CERT_REQUIRED:
context.check_hostname = self.config.get('tls_validate_hostname', True)
else:
context.check_hostname = False

ciphers = self.config.get('tls_ciphers')
if ciphers:
if 'ALL' in ciphers:
Expand Down
Loading
Loading