Skip to content

Add oauth2 client #270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions inventree/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from requests.auth import HTTPBasicAuth
from requests.exceptions import Timeout

from . import oAuthClient as oauth

logger = logging.getLogger('inventree')


Expand Down Expand Up @@ -45,6 +47,9 @@ def __init__(self, host=None, **kwargs):
token - Authentication token (if provided, username/password are ignored)
token-name - Name of the token to use (default = 'inventree-python-client')
use_token_auth - Use token authentication? (default = True)
use_oidc_auth - Use OIDC authentication? (default = False)
oidc_client_id - OIDC client ID (defaults to InvenTree public client)
oidc_scopes - OIDC scopes (default = ['openid', 'g:read'])
verbose - Print extra debug messages (default = False)
strict - Enforce strict HTTPS certificate checking (default = True)
timeout - Set timeout to use (in seconds). Default: 10
Expand All @@ -56,6 +61,9 @@ def __init__(self, host=None, **kwargs):
INVENTREE_API_PASSWORD - Password
INVENTREE_API_TOKEN - User access token
INVENTREE_API_TIMEOUT - Timeout value, in seconds
INVENTREE_API_OIDC - Use OIDC
INVENTREE_API_OIDC_CLIENT_ID - OIDC client ID
INVENTREE_API_OIDC_SCOPES - OIDC scopes
"""

self.setHostName(host or os.environ.get('INVENTREE_API_HOST', None))
Expand All @@ -68,8 +76,13 @@ def __init__(self, host=None, **kwargs):
self.timeout = kwargs.get('timeout', os.environ.get('INVENTREE_API_TIMEOUT', 10))
self.proxies = kwargs.get('proxies', dict())
self.strict = bool(kwargs.get('strict', True))
self.oidc_client_id = kwargs.get('oidc_client_id', os.environ.get('INVENTREE_API_OIDC_CLIENT_ID', 'zDFnsiRheJIOKNx6aCQ0quBxECg1QBHtVFDPloJ6'))
self.oidc_scopes = kwargs.get('oidc_scopes', os.environ.get('INVENTREE_API_OIDC_SCOPES', ['openid', 'g:read']))

self.use_token_auth = kwargs.get('use_token_auth', True)
self.use_oidc_auth = kwargs.get('use_oidc_auth', os.environ.get('INVENTREE_API_OIDC', False))
if self.use_oidc_auth and self.use_token_auth:
self.use_token_auth = False
self.verbose = kwargs.get('verbose', False)

self.auth = None
Expand Down Expand Up @@ -126,15 +139,18 @@ def connect(self):
except Exception:
raise ConnectionRefusedError("Could not connect to InvenTree server")

if self.use_oidc_auth:
self.requestOidcToken()
return

# Basic authentication
self.auth = HTTPBasicAuth(self.username, self.password)

if not self.testAuth():
raise ConnectionError("Authentication at InvenTree server failed")

if self.use_token_auth:
if not self.token:
self.requestToken()
if self.use_token_auth and not self.token:
self.requestToken()

def constructApiUrl(self, endpoint_url):
"""Construct an API endpoint URL based on the provided API URL.
Expand Down Expand Up @@ -273,6 +289,13 @@ def requestToken(self):

return self.token

def requestOidcToken(self):
"""Return authentication token from the server using OIDC."""
client = oauth.OAuthClient(self.base_url, self.oidc_client_id, self.oidc_scopes)
self.token = client._access_token

return self.token

def request(self, api_url, **kwargs):
""" Perform a URL request to the Inventree API """

Expand Down Expand Up @@ -319,6 +342,9 @@ def request(self, api_url, **kwargs):
if self.use_token_auth and self.token:
headers['AUTHORIZATION'] = f'Token {self.token}'
auth = None
elif self.use_oidc_auth and self.token:
headers['AUTHORIZATION'] = f'Bearer {self.token}'
auth = None
else:
auth = self.auth

Expand Down Expand Up @@ -579,8 +605,9 @@ def downloadFile(self, url, destination, overwrite=False, params=None, proxies=d
raise FileExistsError(f"Destination file '{destination}' already exists")

if self.token:
headername = 'Token' if self.use_token_auth else 'Bearer'
headers = {
'AUTHORIZATION': f"Token {self.token}"
'AUTHORIZATION': f"{headername} {self.token}"
}
auth = None
else:
Expand Down
105 changes: 105 additions & 0 deletions inventree/oAuthClient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import urllib.parse as urlparse
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer

from requests_oauthlib import OAuth2Session

# Environment setup
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
USABLE_PORT_RANGE = (29170, 292180)


class OAuthClient:
def __init__(self, server_url: str = "http://localhost:8000", client_id: str = '', scopes: list = None) -> None:
self.server_url = server_url
self.client_id = client_id
self.scopes = scopes if scopes is not None else []

self._handler_wrapper = RequestHandlerWrapper(self)
self._setup_callback()
self._poll_user()

def get_url(self, path: str) -> str:
"""Get the authorization URL."""
return urlparse.urljoin(self.server_url, path)

def _setup_callback(self):
for port in range(*USABLE_PORT_RANGE):
try:
self.server = HTTPServer(("127.0.0.1", port), self._handler_wrapper.request_handler)
self._port = port
break
except OSError:
continue
else:
raise Exception("No port found.")

def _poll_user(self):
self._session = OAuth2Session(
self.client_id, scope=self.scopes, redirect_uri=f"http://localhost:{self._port}", pkce="S256"
)
auth_url, state = self._session.authorization_url(self.get_url('/o/authorize/'), access_type="offline")
self._state = state
webbrowser.open_new_tab(auth_url)

while not self._handler_wrapper.done:
self.server.handle_request()
if self._handler_wrapper.error:
raise Exception(self._handler_wrapper.error)

def callback(self, callback_url: str):
self._session.fetch_token(self.get_url("/o/token/"), authorization_response=callback_url, include_client_id=True)
self._access_token = self._session.access_token


class RequestHandlerWrapper:
"""Provides callback for OIDC endpoint."""
def __init__(self, oauth_client) -> None:
self.done = False
self.error = None
self.client: OAuthClient = oauth_client

@property
def request_handler(self):
wrapper = self

class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self):
parsed_url = urlparse.urlparse(self.path)
if parsed_url.path == "/":
error = urlparse.parse_qs(parsed_url.query).get("error", [None])[0]
if error:
wrapper.error = error
self.send(200)
else:
try:
wrapper.client.callback(self.path)
except OAuthError as e:
wrapper.error = e.message
self.send(400)
else:
self.send(200, 'Success! You can close this window.')
wrapper.done = True
else:
self.send(404)

def send(self, status_code, content=None):
self.send_response(status_code)
if content:
self.wfile.write(content.encode("utf-8"))
else:
self.wfile.write(b"")
self.send_header("Content-type", "text/html")
self.end_headers()

def log_message(self, *args):
pass # Suppress logging

return RequestHandler


class OAuthError(Exception):
"""Exception raised during the OAuth process."""
def __init__(self, message: str) -> None:
self.message = message
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"pip-system-certs>=4.0",
"requests>=2.27.0",
"urllib3>=2.3.0",
"requests-oauthlib",
]

[project.urls]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ invoke>=1.4.0
coverage>=6.4.1 # Run tests, measure coverage
coveralls>=3.3.1
Pillow>=9.1.1
requests-oauthlib # Modern auth experience