diff --git a/.gitignore b/.gitignore index 8b8fb107a..2c2930785 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ augur_export_env.sh config.yml reports.yml *.pid +*.sock node_modules/ .idea/ diff --git a/augur/api/view/init.py b/augur/api/view/init.py index b0b4b2744..37b1972e8 100644 --- a/augur/api/view/init.py +++ b/augur/api/view/init.py @@ -91,4 +91,4 @@ def write_settings(current_settings): # Initialize logging def init_logging(): global logger - logger = AugurLogger("augur_view", reset_logfiles=True).get_logger() + logger = AugurLogger("augur_view", reset_logfiles=False).get_logger() diff --git a/augur/application/cli/__init__.py b/augur/application/cli/__init__.py index f15758c9c..00f41a553 100644 --- a/augur/application/cli/__init__.py +++ b/augur/application/cli/__init__.py @@ -32,7 +32,7 @@ def new_func(ctx, *args, **kwargs): You are not connected to the internet.\n \ Please connect to the internet to run Augur\n \ Consider setting http_proxy variables for limited access installations.") - sys.exit() + sys.exit(-1) return update_wrapper(new_func, function_internet_connection) @@ -78,7 +78,7 @@ def new_func(ctx, *args, **kwargs): print(f"\n\n{usage} command setup failed\nERROR: connecting to database\nHINT: The {incorrect_values} may be incorrectly specified in {location}\n") engine.dispose() - sys.exit() + sys.exit(-2) return update_wrapper(new_func, function_db_connection) diff --git a/augur/application/cli/_multicommand.py b/augur/application/cli/_multicommand.py index 2a1bfd1c7..19392b274 100644 --- a/augur/application/cli/_multicommand.py +++ b/augur/application/cli/_multicommand.py @@ -30,7 +30,6 @@ def get_command(self, ctx, name): # Check that the command exists before importing if not cmdfile.is_file(): - return # Prefer to raise exception instead of silcencing it diff --git a/augur/application/cli/api.py b/augur/application/cli/api.py index d716957c0..50044de7c 100644 --- a/augur/application/cli/api.py +++ b/augur/application/cli/api.py @@ -14,15 +14,16 @@ from augur.application.db.session import DatabaseSession from augur.application.logs import AugurLogger -from augur.application.cli import test_connection, test_db_connection, with_database +from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext from augur.application.cli._cli_util import _broadcast_signal_to_processes, raise_open_file_limit, clear_redis_caches, clear_rabbitmq_messages from augur.application.db.lib import get_value -logger = AugurLogger("augur", reset_logfiles=True).get_logger() +logger = AugurLogger("augur", reset_logfiles=False).get_logger() @click.group('api', short_help='Commands for controlling the backend API server') -def cli(): - pass +@click.pass_context +def cli(ctx): + ctx.obj = DatabaseContext() @cli.command("start") @click.option("--development", is_flag=True, default=False, help="Enable development mode") diff --git a/augur/application/cli/backend.py b/augur/application/cli/backend.py index 37669725a..f470675d1 100644 --- a/augur/application/cli/backend.py +++ b/augur/application/cli/backend.py @@ -47,8 +47,8 @@ def cli(ctx): @click.pass_context def start(ctx, disable_collection, development, pidfile, port): """Start Augur's backend server.""" - with open(pidfile, "w") as pidfile: - pidfile.write(str(os.getpid())) + with open(pidfile, "w") as pidfile_io: + pidfile_io.write(str(os.getpid())) try: if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": @@ -63,6 +63,8 @@ def start(ctx, disable_collection, development, pidfile, port): if development: os.environ["AUGUR_DEV"] = "1" logger.info("Starting in development mode") + + os.environ["AUGUR_PIDFILE"] = pidfile try: gunicorn_location = os.getcwd() + "/augur/api/gunicorn_conf.py" @@ -74,6 +76,11 @@ def start(ctx, disable_collection, development, pidfile, port): if not port: port = get_value("Server", "port") + os.environ["AUGUR_PORT"] = str(port) + + if disable_collection: + os.environ["AUGUR_DISABLE_COLLECTION"] = "1" + worker_vmem_cap = get_value("Celery", 'worker_process_vmem_cap') gunicorn_command = f"gunicorn -c {gunicorn_location} -b {host}:{port} augur.api.server:app --log-file gunicorn.log" @@ -128,7 +135,7 @@ def start(ctx, disable_collection, development, pidfile, port): augur_collection_monitor.si().apply_async() else: - logger.info("Collection disabled") + logger.info("Collection disabled") try: server.wait() @@ -153,6 +160,8 @@ def start(ctx, disable_collection, development, pidfile, port): cleanup_after_collection_halt(logger, ctx.obj.engine) except RedisConnectionError: pass + + os.unlink(pidfile) def start_celery_worker_processes(vmem_cap_ratio, disable_collection=False): @@ -224,6 +233,54 @@ def stop(ctx): augur_stop(signal.SIGTERM, logger, ctx.obj.engine) +@cli.command('stop-collection-blocking') +@test_connection +@test_db_connection +@with_database +@click.pass_context +def stop_collection(ctx): + """ + Stop collection tasks if they are running, block until complete + """ + processes = get_augur_processes() + + stopped = [] + + p: psutil.Process + for p in processes: + if p.name() == "celery": + stopped.append(p) + p.terminate() + + if not len(stopped): + logger.info("No collection processes found") + return + + _, alive = psutil.wait_procs(stopped, 5, + lambda p: logger.info(f"STOPPED: {p.pid}")) + + killed = [] + while True: + for i in range(len(alive)): + if alive[i].status() == psutil.STATUS_ZOMBIE: + logger.info(f"KILLING ZOMBIE: {alive[i].pid}") + alive[i].kill() + killed.append(i) + elif not alive[i].is_running(): + logger.info(f"STOPPED: {p.pid}") + killed.append(i) + + for i in reversed(killed): + alive.pop(i) + + if not len(alive): + break + + logger.info(f"Waiting on [{', '.join(str(p.pid for p in alive))}]") + time.sleep(0.5) + + cleanup_after_collection_halt(logger, ctx.obj.engine) + @cli.command('kill') @test_connection @test_db_connection @@ -388,7 +445,7 @@ def processes(): Outputs the name/PID of all Augur server & worker processes""" augur_processes = get_augur_processes() for process in augur_processes: - logger.info(f"Found process {process.pid}") + logger.info(f"Found process {process.pid} [{process.name()}] -> Parent: {process.parent().pid}") def get_augur_processes(): augur_processes = [] diff --git a/augur/application/cli/collection.py b/augur/application/cli/collection.py index d4daaf95b..84bbd5cba 100644 --- a/augur/application/cli/collection.py +++ b/augur/application/cli/collection.py @@ -22,14 +22,15 @@ from augur.application.db.session import DatabaseSession from augur.application.logs import AugurLogger from augur.application.db.lib import get_value -from augur.application.cli import test_connection, test_db_connection, with_database +from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext from augur.application.cli._cli_util import _broadcast_signal_to_processes, raise_open_file_limit, clear_redis_caches, clear_rabbitmq_messages -logger = AugurLogger("augur", reset_logfiles=True).get_logger() +logger = AugurLogger("augur", reset_logfiles=False).get_logger() @click.group('server', short_help='Commands for controlling the backend API server & data collection workers') -def cli(): - pass +@click.pass_context +def cli(ctx): + ctx.obj = DatabaseContext() @cli.command("start") @click.option("--development", is_flag=True, default=False, help="Enable development mode, implies --disable-collection") diff --git a/augur/application/cli/jumpstart.py b/augur/application/cli/jumpstart.py new file mode 100644 index 000000000..b65255ec1 --- /dev/null +++ b/augur/application/cli/jumpstart.py @@ -0,0 +1,98 @@ +import psutil +import click +import time +import subprocess +from pathlib import Path +from datetime import datetime + +@click.group(invoke_without_command=True) +@click.pass_context +def cli(ctx): + if ctx.invoked_subcommand is None: + p = check_running() + if not p: + click.echo("Jumpstart is not running. Start it with: augur jumpstart run") + return + + click.echo(f"Connecting to Jumpstart: [{p.pid}]") + + while p.is_running() and not len(p.connections("unix")): + # Waiting for app to open fd socket + time.sleep(0.1) + + if not p.is_running(): + click.echo("Error: Jumpstart server exited abnormally") + return + + from jumpstart.tui import run_app + run_app(ctx=ctx) + +def check_running(pidfile = ".jumpstart.pid") -> psutil.Process: + jumpidf = Path(pidfile) + + try: + jumpid, create_time = jumpidf.read_text().splitlines() + jumpp = psutil.Process(int(jumpid)) + + if create_time != str(jumpp.create_time()): + # PID was reused, not the original + jumpidf.unlink() + return + + return jumpp + except (psutil.NoSuchProcess, FileNotFoundError): + return + except PermissionError: + click.echo(f"Permission denied while reading from or writing to pidfile [{str(jumpidf.resolve())}]") + +@cli.command("status") +def get_status(): + p = check_running() + + if not p: + click.echo("Jumpstart is not running") + else: + since = datetime.fromtimestamp(p.create_time()).astimezone() + delta = datetime.now().astimezone() - since + click.echo(f"Jumpstart is running at: [{p.pid}] since {since.strftime('%a %b %d, %Y %H:%M:%S %z:%Z')} [{delta}]") + +@cli.command("run") +@click.pass_context +def startup(ctx): + p = check_running() + + if not p: + click.echo("Starting") + p = launch(ctx) + else: + click.echo(f"Jumpstart is already running [{p.pid}]") + +@cli.command("processID") +def get_main_ID(): + p = check_running() + + if p: + click.echo(p.pid) + +@cli.command("shutdown") +def shutdown_server(): + p = check_running() + + if not p: + click.echo("Jumpstart is not running") + return + + click.echo("Blocking on shutdown") + p.terminate() + p.wait() + +def launch(ctx, pidfile = ".jumpstart.pid", socketfile = "jumpstart.sock"): + service = subprocess.Popen(f"python -m jumpstart.jumpstart pidfile={pidfile} socketfile={socketfile}".split()) + + # Popen object does not have create_time for some reason + ext_process = psutil.Process(service.pid) + + with open(pidfile, "w") as file: + file.write(f"{ext_process.pid}\n{ext_process.create_time()}") + + return ext_process diff --git a/augur/application/cli/tasks.py b/augur/application/cli/tasks.py index f99e078b6..f760dfdde 100644 --- a/augur/application/cli/tasks.py +++ b/augur/application/cli/tasks.py @@ -17,7 +17,7 @@ from augur.application.cli import test_connection, test_db_connection from augur.application.cli.backend import clear_rabbitmq_messages, raise_open_file_limit -logger = AugurLogger("augur", reset_logfiles=True).get_logger() +logger = AugurLogger("augur", reset_logfiles=False).get_logger() @click.group('celery', short_help='Commands for controlling the backend API server & data collection workers') def cli(): diff --git a/augur/tasks/github/events.py b/augur/tasks/github/events.py index 00789a342..94bfb3250 100644 --- a/augur/tasks/github/events.py +++ b/augur/tasks/github/events.py @@ -155,6 +155,7 @@ def _process_issue_events(self, issue_events, repo_id): issue_event_dicts = [] contributors = [] + issue_url_to_id_map = self._get_map_from_issue_url_to_id(repo_id) for event in issue_events: @@ -199,6 +200,7 @@ def _process_pr_events(self, pr_events, repo_id): try: pull_request_id = pr_url_to_id_map[pr_url] except KeyError: + self._logger.warning(f"{self.repo_identifier} - {self.task_name}: Could not find related pr. We were searching for: {pr_url}") continue @@ -281,6 +283,7 @@ def _collect_and_process_issue_events(self, owner, repo, repo_id, key_auth): event_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/events" try: + for event in github_data_access.paginate_resource(event_url): event, contributor = self._process_github_event_contributors(event) diff --git a/jumpstart/API.py b/jumpstart/API.py new file mode 100644 index 000000000..1020a8043 --- /dev/null +++ b/jumpstart/API.py @@ -0,0 +1,167 @@ +from enum import Enum, EnumMeta, auto + +class Component(Enum): + all = auto() + frontend = auto() + api = auto() + collection = auto() + + @staticmethod + def from_str(s): + if isinstance(s, Component): + return s + + try: + return Component[s] + except: + return None + +class Status(Enum): + error="E" + terminated="T" + information="I" + status="S" + + def __call__(self, msg = None): + response = {"status": self.value} + + if self == Status.error: + response.update({ + "detail": msg or "unspecified" + }) + elif self == Status.terminated: + if msg: + response.update({ + "reason": msg + }) + elif self == Status.information: + response.update({ + "detail": msg or "ack" + }) + elif self == Status.status: + response.update(msg) + + return response + +class Command(Enum): + status=auto() + start=auto() + stop=auto() + restart=auto() + shutdown=auto() + unknown=auto() + + @staticmethod + def of(msg: dict): + cmd = msg.pop("cmd") + + try: + return Command[cmd] + except KeyError: + raise Exception(f"Unknown command: [{cmd}]") + +spec = { + "commands": [ + { + "name": "status", + "desc": "Display the current status of Augur processes", + "args": [] + }, { + "name": "start", + "desc": "Start one or more components of Augur", + "args": [ + { + "name": "component", + "required": True, + "type": Component + }, { + "name": "options", + "required": False, + "type": "list" + } + ] + }, { + "name": "stop", + "desc": "Stop one or more components of Augur", + "args": [ + { + "name": "component", + "required": True, + "type": Component + } + ] + }, { + "name": "restart", + "desc": "restart one or more components of Augur", + "args": [ + { + "name": "component", + "required": True, + "type": Component + } + ] + }, { + "name": "shutdown", + "desc": "Stop all Augur components and shut down Jumpstart", + "args": [] + } + ], + "statuses": [ + { + "ID": "E", + "desc": "An error occurred", + "fields": [ + { + "key": "detail", + "desc": "A detail message about the error", + "required": True + } + ] + }, { + "ID": "T", + "desc": "Connection terminated", + "fields": [ + { + "key": "reason", + "desc": "A message describing the reason for disconnection", + "required": False + } + ] + }, { + "ID": "I", + "desc": "Information from server", + "fields": [ + { + "key": "detail", + "desc": "An informational message from the jumpstart server", + "required": True + } + ] + }, { + "ID": "S", + "desc": "Status of Augur components", + "fields": [ + { + "key": "frontend", + "desc": "The frontend status", + "required": True + }, { + "key": "api", + "desc": "The API status", + "required": True + }, { + "key": "collection", + "desc": "The collection status", + "required": True + } + ] + } + ] +} + +for command in spec["commands"]: + for arg in command["args"]: + if issubclass(type(arg["type"]), (Enum, EnumMeta)): + t = arg["type"] + arg["type"] = "enum" + arg["values"] = [c.name for c in t] diff --git a/jumpstart/Client.py b/jumpstart/Client.py new file mode 100644 index 000000000..9a1e6280a --- /dev/null +++ b/jumpstart/Client.py @@ -0,0 +1,73 @@ +import json +import socket +import threading + +from .Logging import console +from .API import spec, Status, Command +from .utils import synchronized, CallbackCollection + +class JumpstartClient: + def __init__(self, sock: socket.socket, wake_lock: threading.Lock, ID: int, callbacks: CallbackCollection): + self.socket = sock + self.io = socket.SocketIO(sock, "rw") + self.lock = wake_lock + self.ID = ID + self.respond(spec) + self.thread = threading.Thread(target=self.loop, + args=[callbacks], + name=f"client_{ID}") + self.thread.start() + + def loop(self, cbs): + while line := self.io.readline().decode(): + try: + body = json.loads(line) + + if not "cmd" in body: + self.respond(Status.error("Command unspecified")) + + cmd = Command.of(body) + if cmd == Command.status: + status_dict = cbs.status() + self.respond(Status.status(status_dict)) + if cmd == Command.shutdown: + cbs.shutdown(self) + if cmd == Command.start: + cbs.start(body["component"], self, *body.get("options", [])) + if cmd == Command.stop: + cbs.stop(body["component"], self) + except json.JSONDecodeError: + self.respond(Status.error("Invalid JSON")) + except Exception as e: + self.respond(Status.error(str(e))) + console.exception("Exception while handling request: " + line) + + console.info(f"Disconnect") + cbs.disconnect(self) + + @synchronized + def send(self, *args, **kwargs): + if args and kwargs: + kwargs["args"] = args + + if self.io.closed: + return + + self.io.write((json.dumps(kwargs) + "\n").encode()) + self.io.flush() + + @synchronized + def respond(self, msg: Status): + if self.io.closed: + return + # console.info(msg) + self.io.write((json.dumps(msg) + "\n").encode()) + self.io.flush() + + @synchronized + def close(self, **kwargs): + self.send(status="T", **kwargs) + self.io.close() + + if not self.thread is threading.currentThread(): + self.thread.join() diff --git a/jumpstart/Logging.py b/jumpstart/Logging.py new file mode 100644 index 000000000..574a6b52f --- /dev/null +++ b/jumpstart/Logging.py @@ -0,0 +1,25 @@ +import logging +from pathlib import Path + +def init_logging(logger = logging.Logger("jumpstart") , errlog_file = Path("logs/jumpstart.error"), stdout_file = Path("logs/jumpstart.log")) -> logging.Logger: + errlog = logging.FileHandler(errlog_file, "w") + stdout = logging.FileHandler(stdout_file, "w") + + formatter = logging.Formatter("[%(asctime)s] [%(name)s] [%(process)d]->[%(threadName)s] [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S %z") + + errlog.setLevel(logging.WARN) + stdout.setLevel(logging.INFO) + stdout.addFilter(lambda entry: entry.levelno < logging.WARN) + errlog.formatter = stdout.formatter = formatter + + logger.addHandler(errlog) + logger.addHandler(stdout) + logger.setLevel(logging.INFO) + + global console + console = logger + + return logger + +if "console" not in globals(): + init_logging() diff --git a/jumpstart/Server.py b/jumpstart/Server.py new file mode 100644 index 000000000..0cca3a87d --- /dev/null +++ b/jumpstart/Server.py @@ -0,0 +1,76 @@ +import socket +import threading +from pathlib import Path + +from .API import Status +from .Logging import console +from .Client import JumpstartClient as Client +from .utils import synchronized, CallbackCollection + +class JumpstartServer: + def __init__(self, callbacks: CallbackCollection, + socketfile = Path("jumpstart.sock").resolve(), + input_lock = threading.Lock()): + try: + socketfile.unlink(True) + except: + console.critical(f"socket in use: {socketfile}") + exit(1) + + callbacks.register_all(disconnect=self._remove_client) + + self.socketfile = socketfile + self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.server.bind(str(socketfile)) + self.server.setblocking(True) + self.lock = input_lock + callbacks.register("shutdown", self._shutdown_callback) + self.callbacks = callbacks + + self.clients: list[Client] = [] + self.messages = [] + + self.loop = threading.Thread(target=self.accept_loop, name="console") + + def accept_loop(self): + self.server.listen() + + try: + while request := self.server.accept(): + conn, addr = request + console.info(f"Accepted client: {len(self.clients)}") + self.clients.append(Client(conn, self.lock, len(self.clients), callbacks=self.callbacks)) + except OSError: + # OSError is thrown when the socket is closed while blocking on accept + console.info("Server no longer accepting connections") + except Exception as e: + console.error("An exception occurred during accept") + console.error(str(e)) + + @synchronized + def _remove_client(self, client): + self.clients.remove(client) + + @synchronized + def _shutdown_callback(self, client: Client): + console.info(f"Server shutdown requested by {client.thread.name}") + self.close() + + @synchronized + def broadcast(self, msg, level=Status.information): + for client in self.clients: + client.respond(level(f"BROADCAST: {msg}")) + + def start(self): + self.loop.start() + + def closed(self): + return not self.loop.is_alive() + + def close(self): + self.server.shutdown(socket.SHUT_RDWR) + self.server.close() + self.socketfile.unlink() + + for client in self.clients: + client.close(reason="Server shutting down") diff --git a/jumpstart/jumpstart.py b/jumpstart/jumpstart.py new file mode 100644 index 000000000..2474dac02 --- /dev/null +++ b/jumpstart/jumpstart.py @@ -0,0 +1,57 @@ +import os +import sys +import time +import json +import socket +import signal +import logging +import threading + +from pathlib import Path +from subprocess import Popen, PIPE, STDOUT + +from .API import Status +from .API import Component +from .Logging import console +from .procman import ProcessManager +from .Server import JumpstartServer +from .utils import CallbackCollection +from .utils import UniversalPlaceholder + +global server, manager + +def handle_terminate(*args, **kwargs): + console.info("shutting down") + + manager.stop(Component.all, UniversalPlaceholder()) + server.close() + + exit(0) + +if __name__ == "__main__": + signal.signal(signal.SIGTERM, handle_terminate) + signal.signal(signal.SIGINT, handle_terminate) + threading.current_thread().setName("main") + + manager = ProcessManager() + + callbacks = CallbackCollection(start=manager.start, stop=manager.stop, status=manager.status) + server = JumpstartServer(callbacks) + server.start() + + while not server.closed(): + try: + manager.refresh() + except: + console.exception("Exception while refreshing status") + server.broadcast("Exception while refreshing status, going down", Status.error) + break + + if server.lock.acquire(True, 0.1): + # The input thread has notified us of a new message + server.lock.release() + pass + else: + time.sleep(0.1) + + handle_terminate() diff --git a/jumpstart/procman.py b/jumpstart/procman.py new file mode 100644 index 000000000..b15853161 --- /dev/null +++ b/jumpstart/procman.py @@ -0,0 +1,123 @@ +from .API import Status +from .API import Component +from .Logging import console +from .utils import synchronized +from .utils import UniversalPlaceholder +from .Client import JumpstartClient as Client + +import signal + +from time import sleep +from typing import Union +from pathlib import Path +from threading import Lock +from subprocess import Popen, PIPE, STDOUT, run + +class ProcessManager: + def __init__(self): + self.frontend = False + self._frontend = None + self.collection = False + self._collection = None + + # Ensure a read of the component status cannot happen during bringup + self._startlock = Lock() + + self.frontend_stdout = Path("logs/jumpstart_frontend.info") + self.frontend_stderr = Path("logs/jumpstart_frontend.error") + self.collection_stdout = Path("logs/jumpstart_collection.info") + self.collection_stderr = Path("logs/jumpstart_collection.error") + + def status(self): + return { + "frontend": self.frontend, + "api": self.frontend, + "collection": self.collection + } + + @synchronized + def start(self, component: Union[Component, str], client: Client, *options): + if not (c := Component.from_str(component)): + client.respond(Status.error(f"Invalid component for start: {component}")) + return + + check_db = run("augur db test-connection".split()) + + if check_db.returncode != 0: + client.respond(Status.error(f"Could not communicate with the database: {check_db.returncode}")) + return + + if c in (Component.api, Component.frontend, Component.all): + if self.frontend: + client.respond(Status.information(f"The frontend/api is already running")) + else: + with(self._startlock): + self._frontend = { + "stdout": self.frontend_stdout.open("w"), + "stderr": self.frontend_stderr.open("w") + } + self._frontend["process"] = Popen("augur api start".split() + list(options), + stdout=self._frontend["stdout"], + stderr=self._frontend["stderr"]) + if c in (Component.collection, Component.all): + if self.collection: + client.respond(Status.information(f"The collection is already running")) + else: + with(self._startlock): + self._collection = { + "stdout": self.collection_stdout.open("w"), + "stderr": self.collection_stderr.open("w") + } + self._collection["process"] = Popen("augur collection start".split() + list(options), + stdout=self._collection["stdout"], + stderr=self._collection["stderr"]) + + @synchronized + def stop(self, component: Union[Component, str], client: Client): + if not (c := Component.from_str(component)): + client.respond(Status.error(f"Invalid component for stop: {component}")) + return + + if c in (Component.api, Component.frontend, Component.all): + if not self.frontend: + client.respond(Status.information("The frontend/api is not running")) + else: + self._frontend["process"].send_signal(signal.SIGINT) + run("augur api stop".split(), stderr=PIPE, stdout=PIPE) + if c in (Component.collection, Component.all): + if not self.collection: + client.respond(Status.information("The collection is not running")) + else: + self._collection["process"].send_signal(signal.SIGINT) + run("augur collection stop".split(), stderr=PIPE, stdout=PIPE) + + @synchronized + def shutdown(self): + self.stop(Component.all, UniversalPlaceholder()) + + while self.refresh(): + sleep(0.1) + + @synchronized + def refresh(self): + with(self._startlock): + if self._frontend is not None: + if self._frontend["process"].poll() is not None: + self.frontend = False + self._frontend["stderr"].close() + self._frontend["stdout"].close() + self._frontend = None + console.info("Frontend shut down") + else: + self.frontend = True + if self._collection is not None: + if self._collection["process"].poll() is not None: + self.collection = False + self._collection["stderr"].close() + self._collection["stdout"].close() + self._collection = None + console.info("Collection shut down") + else: + self.collection = True + + return self.frontend or self.collection \ No newline at end of file diff --git a/jumpstart/tui.py b/jumpstart/tui.py new file mode 100644 index 000000000..2a0931688 --- /dev/null +++ b/jumpstart/tui.py @@ -0,0 +1,392 @@ +import socket as sockets +from socket import socket + +import json +import asyncio +import logging +from sys import argv +from pathlib import Path +from datetime import datetime + +from textual import work +from textual.app import App, on +from textual.binding import Binding +from textual.message import Message +from textual.events import Load, Mount +from textual.containers import Horizontal, Vertical +from textual.widgets import RichLog, Input, Button, Label + +class JumpstartTUI(App): + # Setup + CSS_PATH = "tui.tcss" + + """ + The name of a binding's action automatically links + to a function of the same name, but prefixed with + "action_". So, by putting "kb_exit" here, I'm + telling the runtime to call a function named + "action_kb_exit" any time the user presses CTRL+D. + + I don't personally like that, but you'll find the + function defined below. + """ + BINDINGS = [ + Binding("ctrl+d", "kb_exit", "Exit", priority=True) + ] + + def compose(self): + # Yield the app's components on load + with Horizontal(): + with Vertical(classes="sidebar"): + with Vertical(classes="status_container"): + with Horizontal(classes="info_container"): + yield Label("Frontend") + yield Label("X", id="frontend_label", classes="status_label") + with Horizontal(classes="info_container"): + yield Label("API") + yield Label("X", id="api_label", classes="status_label") + with Horizontal(classes="info_container"): + yield Label("Collection") + yield Label("X", id="collection_label", classes="status_label") + + yield Button("Start", id="startbtn") + yield Button("Stop", id="stopbtn") + yield Button("Exit", variant="error", id="exitbtn") + + with Vertical(): + yield RichLog( + highlight=True, + auto_scroll=True, + wrap=True, + id="stdout" + ) + yield Input( + placeholder="Enter Command", + id="command_line", + valid_empty=False + ) + + class ServerMessage(Message): + """Add a message to the outgoing queue""" + + def __init__(self, **msg) -> None: + self.msg = msg + super().__init__() + + class DirtyExit(Message): + """Custom event to shut down app without immediately closing""" + + def __init__(self) -> None: + super().__init__() + + class Info(Message): + """Custom event to print an info message""" + + def __init__(self, msg) -> None: + self.msg = msg + super().__init__() + + def __str__(self): + return self.msg + + class Error(Message): + """Custom event to print an error message""" + + def __init__(self, msg) -> None: + self.msg = msg + super().__init__() + + def __str__(self): + return self.msg + + class Status(Message): + """Notify the UI of an updated status""" + + def __init__(self, frontend: bool, api: bool, collection: bool): + self.f = frontend + self.a = api + self.c = collection + super().__init__() + + def action_kb_exit(self, force = False): + """Called when the user requests to exit""" + inbox = self.query(Input).filter("#command_line").only_one() + value = inbox.value + + if not value or force: + self.server.shutdown(sockets.SHUT_RDWR) + self.server_IO.close() + self.app.exit() + + # Logging + @on(Error) + def error(self, msg): + self.out("ERROR", msg) + + @on(Info) + def info(self, msg): + self.out("INFO", msg) + + def out(self, level, msg, source = "console"): + time_str = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %z") + formatted = "[{}] [{}] [{}] {}".format(time_str, source, level, str(msg)) + self.raw(formatted) + + def raw(self, obj, disable_highlighting = False, enable_markup = False): + stdout = self.query_one(RichLog) + if disable_highlighting: + stdout.highlight = False + if enable_markup: + stdout.markup = True + stdout.write(obj) + if disable_highlighting: + stdout.highlight = True + if enable_markup: + stdout.markup = False + + def show_help(self, command = None, *args): + if not command: + self.raw(self.help_str, True, True) + else: + self.info("Granular help not currently available") + + # Events + @on(Mount) + async def startup_process_tasks(self): + # Make sure the cmd input takes focus + inbox = self.query(Input).filter("#command_line").only_one() + inbox.focus() + + self.server_rec() + self.status_ping() + + @on(DirtyExit) + def exit_dirty(self): + self.workers.cancel_all() + inbox = self.query(Input).filter("#command_line").only_one() + inbox.value = "" + inbox.disabled = True + + for btn in self.query(Button).results(): + if not btn.id == "exitbtn": + btn.disabled = True + + self.error("Use CTRL+D or CTRL+C to exit") + + @on(Button.Pressed, "#exitbtn") + def handle_btn_exit(self): + self.action_kb_exit(True) + + @on(Input.Submitted, "#command_line") + def accept_input(self, event: Input.Submitted): + value = event.value + + # The input box only accepts non-empty strings + cmd, *args = value.split() + + if cmd == "exit": + self.action_kb_exit(True) + + self.out("INFO", value, "user") + event.input.value = "" + + if cmd == "help": + self.show_help(*args) + + packet = {} + + for command in self.spec["commands"]: + if command["name"] == cmd: + packet["cmd"] = cmd + for arg in command["args"]: + if arg["required"] and not len(args): + self.error(f'Missing required positional argument: [{arg["name"]}]') + return + elif len(args): + if arg["type"] == "enum": + if args[0] not in arg["values"]: + self.error(f'Invalid value for arg [{arg["name"]}]') + self.error(f'> Must be one of <{", ".join(arg["values"])}>') + return + packet[arg["name"]] = args.pop(0) + elif arg["type"] == "list": + packet[arg["name"]] = args + break + break + if packet: + self.send(**packet) + else: + self.error("Unknown command") + + + @on(Button.Pressed, "#statusbtn") + def get_status(self): + self.send(cmd = "status") + + @on(ServerMessage) + def server_send(self, carrier: ServerMessage): + self.send(**carrier.msg) + + @on(Status) + def status_update(self, status): + labels = self.query(".status_label") + + for label in labels: + if label.id == "frontend_label": + if status.f: + label.add_class("label_up") + else: + label.remove_class("label_up") + elif label.id == "api_label": + if status.a: + label.add_class("label_up") + else: + label.remove_class("label_up") + elif label.id == "collection_label": + if status.c: + label.add_class("label_up") + else: + label.remove_class("label_up") + + # Server communication + def send(self, **kwargs): + try: + self.server_IO.write((json.dumps(kwargs) + "\n").encode()) + self.server_IO.flush() + except Exception: + self.post_message(self.DirtyExit()) + + def sendraw(self, msg: str): + self.server_IO.write((msg + "\n").encode()) + self.server_IO.flush() + + # Worker tasks + @work(name="status") + async def status_ping(self): + while not hasattr(self, "server_IO"): + # Wait for server connection + await asyncio.sleep(0.1) + + await asyncio.sleep(1) + + while not self.server_IO.closed: + self.post_message(self.ServerMessage(cmd = "status")) + await asyncio.sleep(1) + + @work(name="server", thread=True) + def server_rec(self): + stdout = self.query_one(RichLog) + + def log(obj): + self.call_from_thread(stdout.write, obj) + + def info(msg, source = "server"): + time_str = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %z") + formatted = "[{}] [{}] [INFO] {}".format(time_str, source, msg) + log(formatted) + + def error(msg, source = "server"): + time_str = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %z") + formatted = "[{}] [{}] [ERROR] {}".format(time_str, source, msg) + log(formatted) + + self.server = socket(sockets.AF_UNIX, sockets.SOCK_STREAM) + try: + self.server.connect(self.socket_file) + self.server.setblocking(True) + io = sockets.SocketIO(self.server, "rw") + self.server_IO = io + except Exception as e: + error(f"Could not connect to the jumpstart server [{self.socket_file}]") + log(e) + self.post_message(self.DirtyExit()) + return + + # Jumpstart sends the spec as a dict on connection + try: + init = io.readline() + self.spec = json.loads(init) + except json.JSONDecodeError as e: + error("Could not decode handshake", "console") + log(init.decode()) + log(e) + self.post_message(self.DirtyExit()) + return + + info("Connected") + + help_str = "Command Reference:\n" + + for command in self.spec["commands"]: + help_str += f'\t> [bold yellow]{command["name"]}' + for arg in command["args"]: + if arg["required"]: + arg_str = f' {arg["name"]} {{}}' + else: + arg_str = f' \\[{arg["name"]} {{}}]' + + if arg["type"] == "enum": + help_str += arg_str.format(f'<{", ".join(arg["values"])}>') + elif arg["type"] == "list": + help_str += arg_str.format("...") + help_str += "[/bold yellow]\n" + help_str += f'\t\t [green]{command["desc"]}[/green]\n' + + self.help_str = help_str + + while not io.closed: + line = io.readline() + + try: + msg = json.loads(line) + if msg["status"] == "E": + error(msg["detail"]) + elif msg["status"] == "T": + info("Connection closed: " + msg.get("reason") or "Reason unspecified") + self.post_message(self.DirtyExit()) + return + elif msg["status"] == "I": + info(msg["detail"]) + elif msg["status"] == "S": + self.post_message(self.Status( + msg["frontend"], + msg["api"], + msg["collection"] + )) + except json.JSONDecodeError: + error("Bad packet from server", "console") + log(line) + except Exception as e: + error("Exception while reacting to incoming packet", "console") + log(e) + log(msg) + +def run_app(socket_file = Path("jumpstart.sock"), ctx = None): + app = JumpstartTUI() + + if not socket_file.exists(): + raise ValueError(f"Socket file {str(socket_file.resolve())} does not exist") + + app.socket_file = str(socket_file) + app.run() + + if ctx: + import click + click.echo("Exited application") + + if hasattr(app, "server_IO"): + # This is duplicated above in case of partial shutdown + app.server.shutdown(sockets.SHUT_RDWR) + app.server_IO.close() + app.server.close() + + # Clean up any residual workers + app.workers.cancel_all() + +if __name__ == "__main__": + socket_file = Path("jumpstart.sock") + for arg in argv: + if arg.startswith("socketfile="): + socket_file = Path(arg.split("=", 1)[1]) + + run_app(socket_file) diff --git a/jumpstart/tui.tcss b/jumpstart/tui.tcss new file mode 100644 index 000000000..2c431a72a --- /dev/null +++ b/jumpstart/tui.tcss @@ -0,0 +1,46 @@ +#stdout { + width: 100% +} + +#command_line { + margin-top: 1; +} + +.sidebar { + width: auto; + margin: 1; + padding: 1; + background: $boost; + height: 100%; + min-height: 30; +} + +.status_container { + width: 100w; + height: auto; + content-align: center top; + border-bottom: heavy $background; +} + +.info_container { + padding: 0; + margin: 1; + height: auto; +} + +.status_label { + dock: right; + width: 2; + background: red; + color: black; + text-align: center; +} + +.label_up { + background: green; + color: green; +} + +#exitbtn { + dock: bottom; +} diff --git a/jumpstart/utils.py b/jumpstart/utils.py new file mode 100644 index 000000000..e1d9de5cf --- /dev/null +++ b/jumpstart/utils.py @@ -0,0 +1,35 @@ +from functools import wraps +from threading import RLock + +def synchronized(func): + lock = RLock() + @wraps(func) + def call_sync(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return call_sync + +class CallbackCollection: + def __init__(self, **cbs) -> None: + for name, cb in cbs.items(): + if not callable(cb): + raise TypeError("A callback must be callable") + setattr(self, name, cb) + + def register(self, name, cb): + if not callable(cb): + raise TypeError("A callback must be callable") + setattr(self, name, cb) + + def register_all(self, **cbs): + for name, cb in cbs.items(): + if not callable(cb): + raise TypeError("A callback must be callable") + setattr(self, name, cb) + +class UniversalPlaceholder: + def __getattr__(self, name, default=...): + return self + + def __call__(self, *args, **kwargs): + pass diff --git a/setup.py b/setup.py index a27f0463f..8591f483e 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "flask_graphql", "wheel", "sendgrid", + "textual>=0.73.0", "alembic==1.8.1", # 1.8.1 "coloredlogs==15.0", # 15.0.1 "Beaker==1.11.0", # 1.11.0