Skip to content

Commit

Permalink
add jwt decorators and roles
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Mar 11, 2024
1 parent eeb82ce commit 250971c
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 2 deletions.
63 changes: 63 additions & 0 deletions fedn/fedn/network/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
from functools import wraps

import jwt
from flask import jsonify, request

# Define your secret key for JWT
SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False)
FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False)
FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False)
FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer')

# Fuction to check additional claims in the token
def check_role_claims(payload, role):
if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE:
if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE:
return False
if 'role' not in payload:
return False
if payload['role'] != role:
return False

return True

# Fuction to check additional cliams in the token
def check_custom_claims(payload):
if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE:
if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE:
return False
return True

# Define the authentication decorator, with role as an argument
def jwt_auth_required(role=None):
def actual_decorator(func):
if not SECRET_KEY:
return func
@wraps(func)
def decorated(*args, **kwargs):
token = request.headers.get('Authorization')
# Get token from the header Bearer
if token and token.startswith(FEDN_AUTH_SCHEME):
token = token.split(' ')[1]

if not token:
return jsonify({'message': 'Missing token'}), 401

try:
payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256'])
if not check_role_claims(payload, role):
return jsonify({'message': 'Invalid token'}), 401
if not check_custom_claims(payload):
return jsonify({'message': 'Invalid token'}), 401

except jwt.ExpiredSignatureError:
return jsonify({'message': 'Token expired'}), 401

except jwt.InvalidTokenError:
return jsonify({'message': 'Invalid token'}), 401

return func(*args, **kwargs)

return decorated
return actual_decorator
2 changes: 1 addition & 1 deletion fedn/fedn/network/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth
# Auth scheme passed as argument overrides environment variable.
# "Token" is the default auth scheme.
if not auth_scheme:
auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token")
auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer")
# Override potential env variable if token is passed as argument.
if not token:
token = os.environ.get("FEDN_AUTH_TOKEN", False)
Expand Down
34 changes: 34 additions & 0 deletions fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fedn.common.config import (get_controller_config, get_modelstorage_config,
get_network_config, get_statestore_config)
from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.interface import API
from fedn.network.api.v1.client_routes import bp as client_bp
from fedn.network.api.v1.combiner_routes import bp as combiner_bp
Expand Down Expand Up @@ -45,6 +46,7 @@


@app.route("/get_model_trail", methods=["GET"])
@jwt_auth_required(role="admin")
def get_model_trail():
"""Get the model trail for a given session.
param: session: The session id to get the model trail for.
Expand All @@ -56,6 +58,7 @@ def get_model_trail():


@app.route("/get_model_ancestors", methods=["GET"])
@jwt_auth_required(role="admin")
def get_model_ancestors():
"""Get the ancestors of a model.
param: model: The model id to get the ancestors for.
Expand All @@ -72,6 +75,7 @@ def get_model_ancestors():


@app.route("/get_model_descendants", methods=["GET"])
@jwt_auth_required(role="admin")
def get_model_descendants():
"""Get the ancestors of a model.
param: model: The model id to get the child for.
Expand All @@ -88,6 +92,7 @@ def get_model_descendants():


@app.route("/list_models", methods=["GET"])
@jwt_auth_required(role="admin")
def list_models():
"""Get models from the statestore.
param:
Expand All @@ -109,6 +114,7 @@ def list_models():


@app.route("/get_model", methods=["GET"])
@jwt_auth_required(role="admin")
def get_model():
"""Get a model from the statestore.
param: model: The model id to get.
Expand All @@ -124,6 +130,7 @@ def get_model():


@app.route("/delete_model_trail", methods=["GET", "POST"])
@jwt_auth_required(role="admin")
def delete_model_trail():
"""Delete the model trail for a given session.
param: session: The session id to delete the model trail for.
Expand All @@ -135,6 +142,7 @@ def delete_model_trail():


@app.route("/list_clients", methods=["GET"])
@jwt_auth_required(role="admin")
def list_clients():
"""Get all clients from the statestore.
return: All clients as a json object.
Expand All @@ -149,6 +157,7 @@ def list_clients():


@app.route("/get_active_clients", methods=["GET"])
@jwt_auth_required(role="admin")
def get_active_clients():
"""Get all active clients from the statestore.
param: combiner_id: The combiner id to get active clients for.
Expand All @@ -166,6 +175,7 @@ def get_active_clients():


@app.route("/list_combiners", methods=["GET"])
@jwt_auth_required(role="admin")
def list_combiners():
"""Get all combiners in the network.
return: All combiners as a json object.
Expand All @@ -179,6 +189,7 @@ def list_combiners():


@app.route("/get_combiner", methods=["GET"])
@jwt_auth_required(role="admin")
def get_combiner():
"""Get a combiner from the statestore.
param: combiner_id: The combiner id to get.
Expand All @@ -196,6 +207,7 @@ def get_combiner():


@app.route("/list_rounds", methods=["GET"])
@jwt_auth_required(role="admin")
def list_rounds():
"""Get all rounds from the statestore.
return: All rounds as a json object.
Expand All @@ -205,6 +217,7 @@ def list_rounds():


@app.route("/get_round", methods=["GET"])
@jwt_auth_required(role="admin")
def get_round():
"""Get a round from the statestore.
param: round_id: The round id to get.
Expand All @@ -219,6 +232,7 @@ def get_round():


@app.route("/start_session", methods=["GET", "POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new session.
return: The response from control.
Expand All @@ -229,6 +243,7 @@ def start_session():


@app.route("/list_sessions", methods=["GET"])
@jwt_auth_required(role="admin")
def list_sessions():
"""Get all sessions from the statestore.
return: All sessions as a json object.
Expand All @@ -241,6 +256,7 @@ def list_sessions():


@app.route("/get_session", methods=["GET"])
@jwt_auth_required(role="admin")
def get_session():
"""Get a session from the statestore.
param: session_id: The session id to get.
Expand All @@ -258,12 +274,14 @@ def get_session():


@app.route("/set_active_package", methods=["PUT"])
@jwt_auth_required(role="admin")
def set_active_package():
id = request.args.get("id", None)
return api.set_active_compute_package(id)


@app.route("/set_package", methods=["POST"])
@jwt_auth_required(role="admin")
def set_package():
""" Set the compute package in the statestore.
Usage with curl:
Expand Down Expand Up @@ -296,6 +314,7 @@ def set_package():


@app.route("/get_package", methods=["GET"])
@jwt_auth_required(role="admin")
def get_package():
"""Get the compute package from the statestore.
return: The compute package as a json object.
Expand All @@ -305,6 +324,7 @@ def get_package():


@app.route("/list_compute_packages", methods=["GET"])
@jwt_auth_required(role="admin")
def list_compute_packages():
"""Get the compute package from the statestore.
return: The compute package as a json object.
Expand All @@ -321,6 +341,7 @@ def list_compute_packages():


@app.route("/download_package", methods=["GET"])
@jwt_auth_required(role="client")
def download_package():
"""Download the compute package.
return: The compute package as a json object.
Expand All @@ -331,12 +352,14 @@ def download_package():


@app.route("/get_package_checksum", methods=["GET"])
@jwt_auth_required(role="admin")
def get_package_checksum():
name = request.args.get("name", None)
return api.get_checksum(name)


@app.route("/get_latest_model", methods=["GET"])
@jwt_auth_required(role="admin")
def get_latest_model():
"""Get the latest model from the statestore.
return: The initial model as a json object.
Expand All @@ -346,6 +369,7 @@ def get_latest_model():


@app.route("/set_current_model", methods=["PUT"])
@jwt_auth_required(role="admin")
def set_current_model():
"""Set the initial model in the statestore and upload to model repository.
Usage with curl:
Expand All @@ -368,6 +392,7 @@ def set_current_model():


@app.route("/get_initial_model", methods=["GET"])
@jwt_auth_required(role="admin")
def get_initial_model():
"""Get the initial model from the statestore.
return: The initial model as a json object.
Expand All @@ -377,6 +402,7 @@ def get_initial_model():


@app.route("/set_initial_model", methods=["POST"])
@jwt_auth_required(role="admin")
def set_initial_model():
"""Set the initial model in the statestore and upload to model repository.
Usage with curl:
Expand All @@ -397,6 +423,7 @@ def set_initial_model():


@app.route("/get_controller_status", methods=["GET"])
@jwt_auth_required(role="admin")
def get_controller_status():
"""Get the status of the controller.
return: The status as a json object.
Expand All @@ -406,6 +433,7 @@ def get_controller_status():


@app.route("/get_client_config", methods=["GET"])
@jwt_auth_required(role="admin")
def get_client_config():
"""Get the client configuration.
return: The client configuration as a json object.
Expand All @@ -416,6 +444,7 @@ def get_client_config():


@app.route("/get_events", methods=["GET"])
@jwt_auth_required(role="admin")
def get_events():
"""Get the events from the statestore.
return: The events as a json object.
Expand All @@ -428,6 +457,7 @@ def get_events():


@app.route("/list_validations", methods=["GET"])
@jwt_auth_required(role="admin")
def list_validations():
"""Get all validations from the statestore.
return: All validations as a json object.
Expand All @@ -439,6 +469,7 @@ def list_validations():


@app.route("/add_combiner", methods=["POST"])
@jwt_auth_required(role="combiner")
def add_combiner():
"""Add a combiner to the network.
return: The response from the statestore.
Expand All @@ -454,6 +485,7 @@ def add_combiner():


@app.route("/add_client", methods=["POST"])
@jwt_auth_required(role="client")
def add_client():
"""Add a client to the network.
return: The response from control.
Expand All @@ -470,6 +502,7 @@ def add_client():


@app.route("/list_combiners_data", methods=["POST"])
@jwt_auth_required(role="admin")
def list_combiners_data():
"""List data from combiners.
return: The response from control.
Expand All @@ -489,6 +522,7 @@ def list_combiners_data():


@app.route("/get_plot_data", methods=["GET"])
@jwt_auth_required(role="admin")
def get_plot_data():
"""Get plot data from the statestore.
rtype: json
Expand Down
6 changes: 6 additions & 0 deletions fedn/fedn/network/api/v1/client_routes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from flask import Blueprint, jsonify, request

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs,
get_typed_list_headers, mdb)
from fedn.network.storage.statestore.stores.client_store import ClientStore
Expand All @@ -11,6 +12,7 @@


@bp.route("/", methods=["GET"])
@jwt_auth_required(role="admin")
def get_clients():
"""Get clients
Retrieves a list of clients based on the provided parameters.
Expand Down Expand Up @@ -127,6 +129,7 @@ def get_clients():


@bp.route("/list", methods=["POST"])
@jwt_auth_required(role="admin")
def list_clients():
"""List clients
Retrieves a list of clients based on the provided parameters.
Expand Down Expand Up @@ -213,6 +216,7 @@ def list_clients():


@bp.route("/count", methods=["GET"])
@jwt_auth_required(role="admin")
def get_clients_count():
"""Clients count
Retrieves the total number of clients based on the provided parameters.
Expand Down Expand Up @@ -273,6 +277,7 @@ def get_clients_count():


@bp.route("/count", methods=["POST"])
@jwt_auth_required(role="admin")
def clients_count():
"""Clients count
Retrieves the total number of clients based on the provided parameters.
Expand Down Expand Up @@ -325,6 +330,7 @@ def clients_count():


@bp.route("/<string:id>", methods=["GET"])
@jwt_auth_required(role="admin")
def get_client(id: str):
"""Get client
Retrieves a client based on the provided id.
Expand Down
Loading

0 comments on commit 250971c

Please sign in to comment.