Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-1458 | Add start cmd to session #830

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 111 additions & 13 deletions fedn/cli/session_cmd.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import click
import requests

from .main import main
from .shared import CONTROLLER_DEFAULTS, get_response, print_response
from .shared import CONTROLLER_DEFAULTS, get_api_url, get_response, get_token, print_response


@main.group("session")
@click.pass_context
def session_cmd(ctx):
""":param ctx:"""
"""Session commands."""
pass


Expand All @@ -19,12 +20,7 @@ def session_cmd(ctx):
@session_cmd.command("list")
@click.pass_context
def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
"""Return:
------
- count: number of sessions
- result: list of sessions

"""
"""List sessions."""
headers = {}

if n_max:
Expand All @@ -42,10 +38,112 @@ def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n
@session_cmd.command("get")
@click.pass_context
def get_session(ctx, protocol: str, host: str, port: str, token: str = None, id: str = None):
"""Return:
------
- result: session with given session id

"""
"""Get session by id."""
response = get_response(protocol=protocol, host=host, port=port, endpoint=f"sessions/{id}", token=token, headers={}, usr_api=False, usr_token=False)
print_response(response, "session", id)


@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
@click.option("-t", "--token", required=False, help="Authentication token")
@click.option("-n", "--name", required=False, help="Name of the session")
@click.option("-a", "--aggregator", required=False, default="fedavg", help="The aggregator plugin to use")
@click.option("-ak", "--aggregator_kwargs", required=False, type=dict, help="Aggregator keyword arguments")
@click.option("-m", "--model_id", required=False, help="The id of the initial model")
@click.option("-rt", "--round_timeout", required=False, default=180, type=int, help="The round timeout to use in seconds")
@click.option("-r", "--rounds", required=False, default=5, type=int, help="The number of rounds to perform")
@click.option("-rb", "--round_buffer_size", required=False, default=-1, type=int, help="The round buffer size to use")
@click.option("-d", "--delete_models", required=False, default=True, type=bool, help="Whether to delete models after each round at combiner (save storage)")
@click.option("-v", "--validate", required=False, default=True, type=bool, help="Whether to validate the model after each round")
@click.option("-hp", "--helper", required=False, help="The helper type to use")
@click.option("-mc", "--min_clients", required=False, default=1, type=int, help="The minimum number of clients required")
@click.option("-rc", "--requested_clients", required=False, default=8, type=int, help="The requested number of clients")
@session_cmd.command("start")
@click.pass_context
def start_session(
ctx,
protocol: str,
host: str,
port: str,
token: str,
name: str = None,
aggregator: str = "fedavg",
aggregator_kwargs: dict = None,
model_id: str = None,
round_timeout: int = 180,
rounds: int = 5,
round_buffer_size: int = -1,
delete_models: bool = True,
validate: bool = True,
helper: str = None,
min_clients: int = 1,
requested_clients: int = 8,
):
"""Start a new session."""
headers = {}
_token = get_token(token=token, usr_token=False)
if _token:
headers = {"Authorization": _token}

if model_id is None:
url = get_api_url(protocol, host, port, "models/active", usr_api=False)
response = requests.get(url, headers=headers)
if response.status_code == 200:
model_id = response.json()
else:
click.secho(f"Failed to get active model: {response.json()}", fg="red")
return

if helper is None:
url = get_api_url(protocol, host, port, "helpers/active", usr_api=False)
response = requests.get(url, headers=headers)
if response.status_code == 400:
helper = "numpyhelper"
elif response.status_code == 200:
helper = response.json()
else:
click.secho("An unexpected error occurred when getting the active helper", fg="red")
return

url = get_api_url(protocol, host, port, "sessions/", usr_api=False)
response = requests.post(
url,
json={
"name": name,
"session_config": {
"aggregator": aggregator,
"aggregator_kwargs": aggregator_kwargs,
"round_timeout": round_timeout,
"buffer_size": round_buffer_size,
"model_id": model_id,
"delete_models_storage": delete_models,
"clients_required": min_clients,
"requested_clients": requested_clients,
"validate": validate,
"helper_type": helper,
"server_functions": None,
},
},
headers=headers,
verify=False,
)

if response.status_code == 201:
session_id = response.json()["session_id"]
url = get_api_url(protocol, host, port, "sessions/start", usr_api=False)
response = requests.post(
url,
json={
"session_id": session_id,
"rounds": rounds,
"round_timeout": round_timeout,
},
headers=headers,
verify=False,
)
response_json = response.json()
response_json["session_id"] = session_id
click.secho(f"Session started successfully: {response_json}", fg="green")
else:
click.secho(f"Failed to start session: {response.json()}", fg="red")
Loading