From bcf7f55e6d97a51aaae70fa378a75bc7be169070 Mon Sep 17 00:00:00 2001 From: PrivacyGo-PETPlatform Date: Thu, 26 Sep 2024 18:23:53 +0800 Subject: [PATCH] feat: version 0.1.1 --- .pre-commit-config.yaml | 14 - .pylintrc | 569 ------------------ CHANGELOG.md | 25 + README.md | 127 ++-- client/client/__init__.py | 17 + client/client/cli.py | 120 ++++ client/client/client.py | 85 +++ .../client/utils}/__init__.py | 0 client/client/utils/request_utils.py | 106 ++++ client/setup.py | 28 + requirements.txt | 1 + src/app.py | 11 +- src/config/config_manager.py | 21 +- src/config/global_config.py | 17 +- src/config/job_context.py | 96 ++- src/config/mission_context.py | 65 +- src/constants.py | 8 +- src/decorators/__init__.py | 13 + src/decorators/decorators.py | 121 ++++ src/exceptions/__init__.py | 13 + src/exceptions/exceptions.py | 39 ++ src/extensions.py | 23 +- src/initialize_database.py | 164 +++-- src/initialize_jwt.py | 39 ++ src/job_manager/core.py | 379 ++++++------ src/job_manager/dag.py | 26 +- .../core.py => job_manager/task.py} | 55 +- src/models/base.py | 1 + src/models/job.py | 5 + src/models/mission.py | 1 + src/models/mission_context.py | 1 + src/models/task.py | 37 ++ src/models/user.py | 55 ++ src/network/request.py | 85 ++- src/settings.py | 12 +- src/utils/request_utils.py | 106 ++++ src/views/__init__.py | 13 + src/{views.py => views/default_views.py} | 26 +- src/views/v1.py | 102 ++++ test/operators/test_executor.py | 2 +- 40 files changed, 1540 insertions(+), 1088 deletions(-) delete mode 100644 .pylintrc create mode 100644 CHANGELOG.md create mode 100644 client/client/__init__.py create mode 100644 client/client/cli.py create mode 100644 client/client/client.py rename {src/task_executor => client/client/utils}/__init__.py (100%) create mode 100644 client/client/utils/request_utils.py create mode 100644 client/setup.py create mode 100644 src/decorators/__init__.py create mode 100644 src/decorators/decorators.py create mode 100644 src/exceptions/__init__.py create mode 100644 src/exceptions/exceptions.py create mode 100644 src/initialize_jwt.py rename src/{task_executor/core.py => job_manager/task.py} (66%) create mode 100644 src/models/user.py create mode 100644 src/utils/request_utils.py create mode 100644 src/views/__init__.py rename src/{views.py => views/default_views.py} (86%) create mode 100644 src/views/v1.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 98c27a5..7d37a3d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,17 +23,3 @@ repos: entry: yapf args: ["--style=.style.yapf", "-i"] types: [python] - - repo: https://github.com/pylint-dev/pylint/ - rev: v3.0.2 - hooks: - - id: pylint - name: pylint - entry: pylint - language: python - types: [python] - args: - [ - "-rn", # Only display messages - "-sn", # Don't display the score - "--rcfile=.pylintrc" - ] diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 60e7f8e..0000000 --- a/.pylintrc +++ /dev/null @@ -1,569 +0,0 @@ -[MAIN] - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Files or directories to be skipped. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the ignore-list. The -# regex matches against paths and can be in Posix or Windows format. -ignore-paths= - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. 
-ignore-patterns=^\.# - -# Pickle collected data for later comparisons. -persistent=yes - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - pylint.extensions.check_elif, - pylint.extensions.bad_builtin, - pylint.extensions.docparams, - pylint.extensions.for_any_all, - pylint.extensions.set_membership, - pylint.extensions.code_style, - pylint.extensions.overlapping_exceptions, - pylint.extensions.typing, - pylint.extensions.redefined_variable_type, - pylint.extensions.comparison_placement, - pylint.extensions.broad_try_clause, - pylint.extensions.dict_init_mutate, - pylint.extensions.consider_refactoring_into_while_condition, - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-allow-list= - -# Minimum supported python version -py-version = 3.8.0 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# Specify a score threshold under which the program will exit with error. -fail-under=10.0 - -# Return non-zero exit code if any of these messages/categories are detected, -# even if score is above --fail-under value. Syntax same as enable. Messages -# specified are enabled, while categories only check already-enabled messages. -fail-on= - -# Clear in-memory caches upon conclusion of linting. Useful if running pylint in -# a server-like mode. -clear-cache-post-run=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -# confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable= - use-symbolic-message-instead, - useless-suppression, - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then re-enable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". 
If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" - -disable= - attribute-defined-outside-init, - invalid-name, - missing-docstring, - protected-access, - too-few-public-methods, - # handled by black - format, - # We anticipate #3512 where it will become optional - fixme, - consider-using-assignment-expr, - logging-fstring-interpolation, - unspecified-encoding, - redefined-variable-type, - too-many-locals, - too-many-instance-attributes, - too-many-public-methods, - import-error, - no-name-in-module, - unsupported-binary-operation, - redefined-outer-name, - redefined-builtin, - consider-alternative-union-syntax, - consider-using-alias, - broad-exception-caught, - too-many-try-statements, - import-outside-toplevel, - broad-exception-raised, - misplaced-bare-raise, - no-else-break, - no-else-return, - cyclic-import - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables 'fatal', 'error', 'warning', 'refactor', 'convention' -# and 'info', which contain the number of messages in each category, as -# well as 'statement', which is the total number of statements analyzed. This -# score is used by the global evaluation report (RP0004). -evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - -# Activate the evaluation score. -score=yes - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME,XXX,TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=6 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=yes - -# Signatures are removed from the similarity computation -ignore-signatures=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of names allowed to shadow builtins -allowed-redefined-builtins= - -# List of qualified module names which can have objects that can redefine -# builtins. 
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=120 - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Maximum number of lines in a module -max-module-lines=2000 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=i,j,k,ex,Run,_ - -# Good variable names regexes, separated by a comma. If names match any regex, -# they will always be accepted -good-names-rgxs= - -# Bad variable names which should always be refused, separated by a comma -bad-names=foo,bar,baz,toto,tutu,tata - -# Bad variable names regexes, separated by a comma. If names match any regex, -# they will always be refused -bad-names-rgxs= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names -function-rgx=[a-z_][a-z0-9_]{2,30}$ - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names -variable-rgx=[a-z_][a-z0-9_]{2,30}$ - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names -const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names -attr-rgx=[a-z_][a-z0-9_]{2,}$ - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names -argument-rgx=[a-z_][a-z0-9_]{2,30}$ - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names -class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ - -# Naming style matching correct class constant names. -class-const-naming-style=UPPER_CASE - -# Regular expression matching correct class constant names. Overrides class- -# const-naming-style. -#class-const-rgx= - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names -inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names -class-rgx=[A-Z_][a-zA-Z0-9]+$ - - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names -module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - - -# Naming style matching correct method names. 
-method-naming-style=snake_case - -# Regular expression matching correct method names -method-rgx=[a-z_][a-z0-9_]{2,}$ - -# Regular expression matching correct type variable names -#typevar-rgx= - -# Regular expression which should only match function or class names that do -# not require a docstring. Use ^(?!__init__$)_ to also check __init__. -no-docstring-rgx=__.*__ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# List of decorators that define properties, such as abc.abstractproperty. -property-classes=abc.abstractproperty - - -[TYPECHECK] - -# Regex pattern to define which classes are considered mixins if ignore-mixin- -# members is set to 'yes' -mixin-class-rgx=.*MixIn - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=REQUEST,acl_users,aq_parent,argparse.Namespace - -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators=contextlib.contextmanager - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# List of comma separated words that should be considered directives if they -# appear and the beginning of a comment and should not be checked. -spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pragma:,# noinspection - -# A path to a file that contains private dictionary; one word per line. 
-spelling-private-dict-file=.pyenchant_pylint_custom_dict.txt - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=2 - - -[DESIGN] - -# Maximum number of arguments for function / method -max-args = 9 - -# Maximum number of locals for function / method body -max-locals = 19 - -# Maximum number of return / yield for function / method body -max-returns=11 - -# Maximum number of branch for function / method body -max-branches = 20 - -# Maximum number of statements in function / method body -max-statements = 50 - -# Maximum number of attributes for a class (see R0902). -max-attributes=11 - -# Maximum number of statements in a try-block -max-try-statements = 7 - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__,__new__,setUp,__post_init__ - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict,_fields,_replace,_source,_make - -# Warn about protected attribute access inside special methods -check-protected-access-in-special-methods=no - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Allow explicit reexports by alias from a package __init__. -allow-reexport-from-package=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub,TERMIOS,Bastion,rexec - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=builtins.Exception - - -[TYPING] - -# Set to ``no`` if the app / library does **NOT** need to support runtime -# introspection of type annotations. If you use type annotations -# **exclusively** for type checking of an application, you're probably fine. -# For libraries, evaluate if some users what to access the type hints at -# runtime first, e.g., through ``typing.get_type_hints``. 
Applies to Python -# versions 3.7 - 3.9 -runtime-typing = no - - -[DEPRECATED_BUILTINS] - -# List of builtins function names that should not be used, separated by a comma -bad-functions=map,input - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit,argparse.parse_error - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[CODE_STYLE] - -# Max line length for which to sill emit suggestions. Used to prevent optional -# suggestions which would get split by a code formatter (e.g., black). Will -# default to the setting for ``max-line-length``. -#max-line-length-suggestions= diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4d766b2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,25 @@ +# List of Changes + +## Version 0.1.0 + +### Added + +- User APIs and schedule abilities for job management. +- Internal support & integration with PETML, PETSQL and PETAce. + +## Version 0.1.1 + +### Added + +- Implemented user authentication and multi-user isolation at the API level, enhancing overall system security and stability. +- Introduced a batch querying feature for jobs, enabling users to check the status of multiple jobs simultaneously. +- Launched support for Python SDK and command-line tools, providing users with more convenient ways to interact with our service. + +### Changed + +- Refined our user-facing API interfaces to adhere more closely to RESTful API naming conventions, thereby improving API readability and usability. +- Improved the output of the job detail query interface, making the returned results more clear and understandable. + +### Fixed + +- Resolved a previous issue that could potentially lead to database connection timeouts, thereby enhancing system reliability. diff --git a/README.md b/README.md index a10d2b2..b7ac519 100644 --- a/README.md +++ b/README.md @@ -198,83 +198,100 @@ docker-compose up -d ## User Manual -### How to Submit a Job +In version v0.1.0, task lifecycle management was managed via HTTP interface calls, requiring users to have specific background knowledge and input numerous parameters during operation. In the updated v0.1.1 version, we have introduced support for command-line interactive tools, added new interfaces, and optimized the output of some interfaces. These enhancements aim to improve user experience and facilitate better task management. -First you need prepare your request parameters. Here is a template for running a psi job. You can find more templates in [mission_templates](test/request). 
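+As in v0.1.0, a job is described by `mission_name`, `mission_version`, and `mission_params`. For reference when using the `submit` command introduced below, a minimal PSI job description looks like this (data paths are illustrative):
+
+```json
+{
+    "mission_name": "psi",
+    "mission_version": 1,
+    "mission_params": {
+        "party_a": {
+            "column_name": "id",
+            "inputs": {"data": "data/breast_hetero_mini_server.csv"},
+            "outputs": {"data": "data/psi_result.csv"}
+        },
+        "party_b": {
+            "column_name": "id",
+            "inputs": {"data": "data/breast_hetero_mini_client.csv"},
+            "outputs": {"data": "data/psi_result.csv"}
+        }
+    }
+}
+```
+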
+### Installation

-```json
-{
-    "mission_name": "psi",
-    "mission_version": 1,
-    "mission_params": {
-        "party_a": {
-            "column_name": "id",
-            "inputs": {
-                "data": "data/breast_hetero_mini_server.csv"
-            },
-            "outputs": {
-                "data": "data/psi_result.csv"
-            }
-        },
-        "party_b": {
-            "column_name": "id",
-            "inputs": {
-                "data": "data/breast_hetero_mini_client.csv"
-            },
-            "outputs": {
-                "data": "data/psi_result.csv"
-            }
-        }
-    }
-}
-```
+| System    | Toolchain                  |
+|-----------|----------------------------|
+| Linux/Mac | Python 3.x, pip (>=23.3.1) |

-Then, you can submit a job by executing:
+### How to Install
+
+First you will need the wheel package of the petplatform client. Then you can install it with the following commands:

 ```bash
-curl http://${HOSTNAME}/job/submit -H 'Content-Type: application/json' -d@${PATH_TO_YOUR_CONFIG_FILE}
+# enter the operating directory
+cd my_project
+# create a python virtual environment
+python3 -m venv env
+# activate it
+source env/bin/activate
+# install the petplatform client wheel
+pip install petplatform_client-0.1.0-py3-none-any.whl
 ```

-Normally you will receive a response like the following.
-You need to record your `job_id` for later to query the job status.
+You can run the following command to check whether the installation succeeded:
+```bash
+petplatform-cli
+```

-```json
-{
-    "success": true,
-    "job_id": "job_id"
-}
+### Initialization
+We have added a JWT-token-based authentication mechanism for better service quality and security. Users need to initialize the command-line tool before using it to manage their jobs. Run the following command:
+```bash
+petplatform-cli init
 ```
+When prompted, enter the platform's URL and your JWT token.
+
+Notes:
+- Initialization only needs to be performed once. The configuration is saved to the .env file in the current directory; as long as that directory and its .env file are unchanged, you do not need to initialize again.
+- To change the configuration, re-run the init command and enter the new values; the original configuration will be overwritten. You can also view and modify the configuration directly in the .env file.
+- If the command-line tool is not correctly initialized, its commands may fail. Please contact us for technical support if you have any questions.
+
+### Job Management via Command-Line Tool
+#### Get Help Message
+```bash
+# get the command-line tool's help message
+petplatform-cli --help

-### How to Query Job Status
+# get help messages for subcommands
+petplatform-cli [subcommand name] --help
+```

-You can query job status by executing the following command. According to the job status, your response may look different.
+#### Submit a Job

 ```bash
-curl http://${HOSTNAME}/job/status?job_id=${YOUR_JOB_ID}
+# If your job parameters are in a JSON file, e.g. /tmp/params.json
+petplatform-cli submit --json-file ${YOUR_JSON_FILE}
+
+# If your job parameters are a JSON string, e.g. {"mission_name": "psi"}
+petplatform-cli submit --json-string ${YOUR_JSON_STRING}
 ```

+#### List History Jobs
+
+```bash
+# By default, jobs submitted in the past 24 hours are shown (10 at most)
+petplatform-cli get-jobs
+
+# Filter by job status, e.g. if you only want to see running jobs:
+petplatform-cli get-jobs --status RUNNING
+
+# Show earlier jobs, e.g. if you want to see successful jobs in the last 48 hours:
+petplatform-cli get-jobs --status SUCCESS --hours 48
+
+# Raise the limit on the number of jobs shown, e.g. if you have 10-20 successful jobs and want to show them all:
+petplatform-cli get-jobs --status SUCCESS --hours 48 --limit 20
+```

-### User APIs
-| Endpoint      | Method | Description             | Request Params | Response | Example |
-|---------------|--------|-------------------------|----------------|----------|---------|
-| `/job/submit` | `POST` | Submit a computing job  | `mission_name`: (string) The name of the mission (e.g., "psi").<br>
`mission_version`: (integer) The version of the mission.
`mission_params`: (object) The parameters of the mission. | Success: `success` (boolean), `job_id` (string)
Failure: `success` (boolean), `error_message` (string) | `{ "mission_name": "psi", "mission_version": 1, "mission_params": {...} }` | -| `/job/status` | `GET` | Query job status | `job_id`: (string) The ID of the job. | `status` (string): The status of the job ("RUNNING", "FAILED", "SUCCESS").
`progress` (string): The progress of the job.
`task_status` (object): The status of the task. | `{ "job_id": "xxx" }` | -| `/job/kill` | `POST` | Terminate a running job | `job_id`: (string) The ID of the job. | Success: `success` (boolean)
Failure: `success` (boolean), `error_message` (string) | `{ "job_id": "xxx" }` | -| `/job/rerun` | `POST` | Rerun a failed job | `job_id`: (string) The ID of the job. | Success: `success` (boolean)
Failure: `success` (boolean), `error_message` (string) | `{ "job_id": "xxx" }` |

+#### Show Single Job Details
+```bash
+petplatform-cli get-job ${YOUR_JOB_ID}
+```

-### Trouble Shooting
+#### Stop a Running Job

-If you encounter problems while using the PETPlatform service, you can follow the steps below for self-check.
-If you still have more questions, you may report your bug to the community.
+```bash
+petplatform-cli cancel ${YOUR_JOB_ID}
+```

-1. Ensure that the PETPlatform Docker image you are using is up-to-date.
-2. Check if your configuration files (e.g., `party.json`) are correct. Ensure all keys and values are as expected, with no spelling errors or missing items.
-3. Check if your `docker-compose.yml` file is correct. Ensure all services, environment variables, volumes, and port mappings are configured correctly.
-4. Check if your containers are running.
-5. Refer to the error message for more details about the specific error.

+#### Rerun a Failed/Canceled Job
+```bash
+petplatform-cli rerun ${YOUR_JOB_ID}
+```

 ## Contribution
diff --git a/client/client/__init__.py b/client/client/__init__.py
new file mode 100644
index 0000000..ea0b5f2
--- /dev/null
+++ b/client/client/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .cli import cli
+from .client import PlatformClient
+
+__all__ = ["cli", "PlatformClient"]
diff --git a/client/client/cli.py b/client/client/cli.py
new file mode 100644
index 0000000..2e3d3aa
--- /dev/null
+++ b/client/client/cli.py
@@ -0,0 +1,120 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+
+import click
+from dotenv import load_dotenv, set_key
+
+from client.client import PlatformClient
+
+load_dotenv()
+
+
+@click.group()
+@click.pass_context
+def cli(ctx):
+    server_url = os.getenv("SERVER_URL", None)
+    jwt_token = os.getenv("JWT_TOKEN", None)
+    if ctx.invoked_subcommand != "init":
+        if not server_url or not jwt_token:
+            click.echo("Server URL or JWT token not set. 
Please run the 'init' command to set them.") + ctx.exit() + else: + ctx.ensure_object(dict) + ctx.obj["client"] = PlatformClient(server_url, jwt_token) + + +@cli.command(help="Initialize server_url and jwt_token") +@click.option("--server-url", prompt="Please enter server url", help="Server URL.") +@click.option("--jwt-token", prompt="Please enter jwt token", help="Client validation token.") +@click.pass_context +def init(ctx, server_url, jwt_token): + # Save the values to .env file + set_key(".env", "SERVER_URL", server_url) + set_key(".env", "JWT_TOKEN", jwt_token) + click.echo("Server URL and JWT token have been set.") + + ctx.ensure_object(dict) + ctx.obj["client"] = PlatformClient(server_url, jwt_token) + + +@cli.command(help="submit a new job") +@click.option("--json-file", type=click.Path(exists=True), default=None, help="path to a json file") +@click.option("--json-string", type=str, default=None, help="parameters as a json string") +@click.pass_context +def submit(ctx, json_file, json_string): + client = ctx.obj["client"] + if json_file is not None: + try: + with open(json_file, "r") as f: + params = json.load(f) + except Exception as e: + raise click.UsageError(f"not a valid json format file: {e}") + elif json_string is not None: + try: + params = json.loads(json_string) + except Exception as e: + raise click.UsageError(f"not a valid json format string: {e}") + else: + raise click.UsageError("you must provide either --json-file or --json-string.") + + success = client.submit(params) + click.echo(success) + + +@cli.command(help="cancel a running job") +@click.argument("job-id") +@click.pass_context +def cancel(ctx, job_id): + client = ctx.obj["client"] + success = client.cancel(job_id) + click.echo(success) + + +@cli.command(help="rerun a failed/canceled job") +@click.argument("job-id") +@click.pass_context +def rerun(ctx, job_id): + client = ctx.obj["client"] + success = client.rerun(job_id) + click.echo(success) + + +@cli.command(help="get job info") +@click.argument("job-id") +@click.pass_context +def get_job(ctx, job_id): + client = ctx.obj["client"] + job = client.get(job_id) + click.echo(job) + + +@cli.command(help="list a limited number of jobs submitted in the past hours with given status") +@click.option( + "--status", + type=str, + default=None, + help="only jobs with the given status will be shown, accept values e.g. [RUNNING, SUCCESS, FAILED, CANCELED]") +@click.option("--hours", type=int, default=24, help="only the jobs submitted within the past given hours will be shown") +@click.option("--limit", type=int, default=10, help="only show jobs within the given limit") +@click.pass_context +def get_jobs(ctx, status, hours, limit): + client = ctx.obj["client"] + response = client.get_all(status, hours, limit) + click.echo(response) + + +if __name__ == "__main__": + cli(obj={}) diff --git a/client/client/client.py b/client/client/client.py new file mode 100644 index 0000000..4204f9d --- /dev/null +++ b/client/client/client.py @@ -0,0 +1,85 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List + +from .utils.request_utils import get, post + + +class PlatformClient: + + def __init__(self, server_url: str, jwt_token: str): + self._server_url = server_url + self._jwt_token = jwt_token + + def __str__(self): + return f"PlatformClient instance with server_url={self._server_url} and jwt_token={self._jwt_token}" + + def __repr__(self): + return f"PlatformClient({self._server_url}, {self._jwt_token})" + + def _get_address(self) -> str: + return self._server_url + + def _get_headers(self) -> Dict: + return {"Authorization": f"Bearer {self._jwt_token}"} + + def submit(self, params: Dict) -> bool: + address = self._get_address() + headers = self._get_headers() + response = post(address, "api/v1/jobs", json=params, headers=headers) + if response.get("success") is not True: + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") + return True + + def rerun(self, job_id: str) -> bool: + address = self._get_address() + headers = self._get_headers() + response = post(address, f"api/v1/jobs/{job_id}/rerun", headers=headers) + if response.get("success") is not True: + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") + return True + + def cancel(self, job_id: str) -> bool: + address = self._get_address() + headers = self._get_headers() + response = post(address, f"api/v1/jobs/{job_id}/cancel", headers=headers) + if response.get("success") is not True: + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") + return True + + def get(self, job_id: str) -> Dict: + address = self._get_address() + headers = self._get_headers() + response = get(address, f"api/v1/jobs/{job_id}", headers=headers, return_json=True) + if response.get("success") is not True: + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") + return response["job"] + + def get_all(self, status: str = None, hours: int = 24, limit: int = 10) -> List: + address = self._get_address() + headers = self._get_headers() + params = {"hours": hours, "limit": limit} + if status is not None: + if not isinstance(status, str): + raise ValueError("status must be a string") + params["status"] = status + response = get(address, f"api/v1/jobs", headers=headers, params=params, return_json=True) + if response.get("success") is not True: + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") + return response["jobs"] diff --git a/src/task_executor/__init__.py b/client/client/utils/__init__.py similarity index 100% rename from src/task_executor/__init__.py rename to client/client/utils/__init__.py diff --git a/client/client/utils/request_utils.py b/client/client/utils/request_utils.py new file mode 100644 index 0000000..eb35d4f --- /dev/null +++ b/client/client/utils/request_utils.py @@ -0,0 +1,106 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict + +import requests +from requests.exceptions import Timeout + + +def send_request(method: str, + address: str, + endpoint: str, + params: Dict = None, + headers: Dict = None, + json: Dict = None, + data=None, + timeout=10, + return_json=True): + url = "{address}/{endpoint}".format(address=address, endpoint=endpoint.lstrip('/')) + request_headers = {"Content-Type": "application/json"} + if headers: + request_headers.update(headers) + try: + logging.debug(f"send {method} request to {url}, params={params}, json={json}, data={data}, headers={headers}") + response = requests.request(method, + url, + params=params, + json=json, + data=data, + headers=request_headers, + timeout=timeout) + except Timeout: + logging.error(f"{method} request timed out: {address}/{endpoint}, headers={headers}, json={json}, data={data}") + raise + except Exception: + logging.exception(f"{method} fail: {address}/{endpoint}, headers={headers}, json={json}, data={data}") + raise + + if response.status_code >= 400: + logging.error(f"{method} request error status {response.status_code}: " + f"{address}/{endpoint}, " + f"headers={headers}, " + f"json={json}, " + f"data={data}") + response.raise_for_status() + + logging.debug(f"response: {response.json()}") + return response.json() if return_json else response + + +def delete(address: str, + endpoint: str, + params: Dict = None, + headers: Dict = None, + json: Dict = None, + timeout=10, + return_json=True): + return send_request("DELETE", address, endpoint, params, headers, json, None, timeout, return_json) + + +def get(address: str, endpoint: str, params: Dict = None, headers: Dict = None, timeout=10, return_json=False): + return send_request("GET", address, endpoint, params, headers, None, None, timeout, return_json=return_json) + + +def patch(address: str, + endpoint: str, + params: Dict = None, + headers: Dict = None, + json: Dict = None, + data=None, + timeout=10, + return_json=True): + return send_request("PATCH", address, endpoint, params, headers, json, data, timeout, return_json) + + +def post(address: str, + endpoint: str, + params: Dict = None, + headers: Dict = None, + json: Dict = None, + data=None, + timeout=10, + return_json=True): + return send_request("POST", address, endpoint, params, headers, json, data, timeout, return_json) + + +def put(address: str, + endpoint: str, + params: Dict = None, + headers: Dict = None, + json: Dict = None, + data=None, + timeout=10, + return_json=True): + return send_request("PUT", address, endpoint, params, headers, json, data, timeout, return_json) diff --git a/client/setup.py b/client/setup.py new file mode 100644 index 0000000..f9089cd --- /dev/null +++ b/client/setup.py @@ -0,0 +1,28 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from setuptools import setup, find_packages + +setup( + name='petplatform-client', + version='0.1.0', + description='SDK and Commandline Tools for PETPlatform', + author="PrivacyGo-PETPlatform", + author_email="privacygo-petplatform@tiktok.com", + packages=find_packages('.', exclude=['tests']), + package_dir={'': '.'}, + install_requires=['click', 'python-dotenv', 'requests'], + entry_points={ + 'console_scripts': ['petplatform-cli=client.cli:cli'], + }, +) diff --git a/requirements.txt b/requirements.txt index f408f48..83399a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ PyMySQL~=1.1.1 PyYAML~=6.0.1 requests~=2.32.0 SQLAlchemy~=2.0.29 +setuptools>=73.0.1 diff --git a/src/app.py b/src/app.py index bec6293..a93b4b1 100644 --- a/src/app.py +++ b/src/app.py @@ -14,16 +14,19 @@ import logging.config import flask +from flask_sqlalchemy import SQLAlchemy -from extensions import db -from views import views +from models.base import Base import settings +from views.default_views import default_views +from views.v1 import v1 logging.config.dictConfig(settings.LOGGING_CONFIG) app = flask.Flask(__name__) -app.register_blueprint(views) +app.register_blueprint(default_views) +app.register_blueprint(v1) app.config['SQLALCHEMY_DATABASE_URI'] = settings.PLATFORM_DB_URI -db.init_app(app) +db = SQLAlchemy(app=app, model_class=Base) if __name__ == '__main__': # Never run debug mode in production environment! diff --git a/src/config/config_manager.py b/src/config/config_manager.py index 196734f..406005c 100644 --- a/src/config/config_manager.py +++ b/src/config/config_manager.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging - -from extensions import get_session_ins from .global_config import GlobalConfig from .job_context import JobContext from .mission_context import MissionContext @@ -41,36 +38,24 @@ def __init__(self, mission_name: str, job_id: str): self.mission_name = mission_name self.job_id = job_id - self._session = None self._global_config = None self._mission_context = None self._job_context = None - @property - def session(self): - if self._session is None: - self._session = get_session_ins() - return self._session - @property def global_config(self): if self._global_config is None: - self._global_config = GlobalConfig(self.session) + self._global_config = GlobalConfig() return self._global_config @property def mission_context(self): if self._mission_context is None: - self._mission_context = MissionContext(self.session, self.mission_name) + self._mission_context = MissionContext(self.mission_name) return self._mission_context @property def job_context(self): if self._job_context is None: - self._job_context = JobContext(self.session, self.job_id) + self._job_context = JobContext(self.job_id) return self._job_context - - def close(self): - if self._session: - self._session.close() - logging.info("config manager session closed") diff --git a/src/config/global_config.py b/src/config/global_config.py index 6509d63..f8b3900 100644 --- a/src/config/global_config.py +++ b/src/config/global_config.py @@ -18,16 +18,19 @@ class GlobalConfig: - def __init__(self, session): - self.session = session + def __init__(self): + from extensions import get_session_maker + self.session_maker = get_session_maker() def get(self, key: str) -> Union[str, None]: - record = self.session.query(GlobalConfigTable).filter_by(config_key=key).first() - return record.config_value if record else None + with self.session_maker() as session: + record = session.query(GlobalConfigTable).filter_by(config_key=key).first() + return record.config_value if record else None def get_all(self, keys: Iterable[str]) -> Dict[str, str]: ret = {} - for key in keys: - record = self.session.query(GlobalConfigTable).filter_by(config_key=key).first() - ret[key] = record.config_value if record else None + with self.session_maker() as session: + for key in keys: + record = session.query(GlobalConfigTable).filter_by(config_key=key).first() + ret[key] = record.config_value if record else None return ret diff --git a/src/config/job_context.py b/src/config/job_context.py index ffff258..48d6aea 100644 --- a/src/config/job_context.py +++ b/src/config/job_context.py @@ -24,13 +24,16 @@ class JobContext: - def __init__(self, session, job_id: str): - self.session = session + def __init__(self, job_id: str): + from extensions import get_session_maker + self.session_maker = get_session_maker() self.job_id = job_id def get(self, key: str, party: str = None) -> Union[str, Dict, None]: - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - assert job is not None, ValueError(f"{self.job_id} not found") + with self.session_maker() as session: + job = session.query(Job).filter_by(job_id=self.job_id).first() + if job is None: + raise ValueError(f"{self.job_id} not found") context: Dict = json.loads(job.job_context) # select search domain search_domain = [settings.PARTY, "common"] if party is None else [party] @@ -52,52 +55,43 @@ def get(self, key: str, party: str = None) -> Union[str, Dict, None]: return None def set(self, key: str, value: Union[str, Dict, List], party: str, max_retry=3) -> bool: - keys, updated_context = key.split("."), 
{party: {}}
-        cursor = updated_context[party]
-        for k in keys[:-1]:
-            cursor[k] = {}
-            cursor = cursor[k]
-        cursor[keys[-1]] = value
-        for _ in range(max_retry):
-            job = self.session.query(Job).filter_by(job_id=self.job_id).first()
-            assert job is not None, ValueError(f"{self.job_id} not found")
-            job_context: Dict = json.loads(job.job_context)
-            assert party in job_context, ValueError(f"party {party} not found")
-            assert isinstance(job_context[party], dict), ValueError(f"job_context[{party}] is not a dict")
-            deep_merge(job_context, updated_context)
-            job.job_context = json.dumps(job_context)
-            try:
-                session_commit_with_retry(self.session)
-                return True
-            except StaleDataError:
-                # This happens when two jobs try to modify the same record,
-                # this modification fails due to the optimistic lock.
-                # Retry until reach max retry times.
-                self.session.rollback()
-                continue
-        return False
-
-    def get_all(self) -> Dict:
-        job = self.session.query(Job).filter_by(job_id=self.job_id).first()
-        assert job is not None, ValueError(f"{self.job_id} not found")
-        context: Dict = json.loads(job.job_context)
-        return context
+        with self.session_maker() as session:
+            keys, updated_context = key.split("."), {party: {}}
+            cursor = updated_context[party]
+            for k in keys[:-1]:
+                cursor[k] = {}
+                cursor = cursor[k]
+            cursor[keys[-1]] = value
+            for _ in range(max_retry):
+                job = session.query(Job).filter_by(job_id=self.job_id).first()
+                assert job is not None, ValueError(f"{self.job_id} not found")
+                job_context: Dict = json.loads(job.job_context)
+                assert party in job_context, ValueError(f"party {party} not found")
+                assert isinstance(job_context[party], dict), ValueError(f"job_context[{party}] is not a dict")
+                deep_merge(job_context, updated_context)
+                job.job_context = json.dumps(job_context)
+                try:
+                    session_commit_with_retry(session)
+                    return True
+                except StaleDataError:
+                    # Optimistic-lock conflict: roll back so the session is usable, then retry.
+                    session.rollback()
+                    continue
+            return False

     def set_all(self, configs: Dict[str, Union[str, Dict, List]], party: str = "common", max_retry=3):
-        for k in configs:
-            assert "." not in k, ValueError(f"unexpected special character '.' in key {k}")
-        for _ in range(max_retry):
-            job = self.session.query(Job).filter_by(job_id=self.job_id).first()
-            assert job is not None, ValueError(f"{self.job_id} not found")
-            job_context: Dict = json.loads(job.job_context)
-            assert party in job_context, ValueError(f"party {party} not found")
-            assert isinstance(job_context[party], dict), ValueError(f"job_context[{party}] is not a dict")
-            deep_merge(job_context[party], configs)
-            job.job_context = json.dumps(job_context)
-            try:
-                session_commit_with_retry(self.session)
-                return True
-            except StaleDataError:
-                self.session.rollback()
-                continue
-        return False
+        with self.session_maker() as session:
+            for k in configs:
+                assert "." not in k, ValueError(f"unexpected special character '.' in key {k}")
+            for _ in range(max_retry):
+                job = session.query(Job).filter_by(job_id=self.job_id).first()
+                assert job is not None, ValueError(f"{self.job_id} not found")
+                job_context: Dict = json.loads(job.job_context)
+                assert party in job_context, ValueError(f"party {party} not found")
+                assert isinstance(job_context[party], dict), ValueError(f"job_context[{party}] is not a dict")
+                deep_merge(job_context[party], configs)
+                job.job_context = json.dumps(job_context)
+                try:
+                    session_commit_with_retry(session)
+                    return True
+                except StaleDataError:
+                    # Optimistic-lock conflict: roll back so the session is usable, then retry.
+                    session.rollback()
+                    continue
+            return False
diff --git a/src/config/mission_context.py b/src/config/mission_context.py
index f6d25c9..7e99174 100644
--- a/src/config/mission_context.py
+++ b/src/config/mission_context.py
@@ -23,39 +23,42 @@

 class MissionContext:

-    def __init__(self, session, mission_name: str):
-        self.session = session
+    def __init__(self, mission_name: str):
+        from extensions import get_session_maker
+        self.session_maker = get_session_maker()
         self.mission_name = mission_name

     def get(self, key: str) -> Union[str, None]:
-        record = self.session.query(MissionContextTable).filter_by(config_key=key,
-                                                                   mission_name=self.mission_name).first()
-        if record is None:
-            return None
-        if record.expire_time < datetime.utcnow():
-            return None
-        return record.config_value
+        with self.session_maker() as session:
+            record = session.query(MissionContextTable).filter_by(config_key=key,
+                                                                  mission_name=self.mission_name).first()
+            if record is None:
+                return None
+            if record.expire_time < datetime.utcnow():
+                return None
+            return record.config_value

     def set(self, key: str, value: str, expire_time=TimeDuration.DAY) -> bool:
-        utcnow = datetime.utcnow()
-        new_expire_time = utcnow + timedelta(seconds=expire_time)
-        record = self.session.query(MissionContextTable).filter_by(config_key=key,
-                                                                   mission_name=self.mission_name).first()
-        if record is None:  # create
-            record = MissionContextTable(config_key=key,
-                                         mission_name=self.mission_name,
-                                         config_value=value,
-                                         expire_time=new_expire_time)
-            self.session.add(record)
-        else:  # update
-            record.config_value = value
-            record.expire_time = new_expire_time
-        try:
-            session_commit_with_retry(self.session)
-            return True
-        except StaleDataError:
-            # This happens when two jobs try to modify the same record,
-            # this modification fails due to the optimistic lock.
-            # We leave it the caller to handle, he may try to read it again.
-            self.session.rollback()
-            return False
+        with self.session_maker() as session:
+            utcnow = datetime.utcnow()
+            new_expire_time = utcnow + timedelta(seconds=expire_time)
+            record = session.query(MissionContextTable).filter_by(config_key=key,
+                                                                  mission_name=self.mission_name).first()
+            if record is None:  # create
+                record = MissionContextTable(config_key=key,
+                                             mission_name=self.mission_name,
+                                             config_value=value,
+                                             expire_time=new_expire_time)
+                session.add(record)
+            else:  # update
+                record.config_value = value
+                record.expire_time = new_expire_time
+            try:
+                session_commit_with_retry(session)
+                return True
+            except StaleDataError:
+                # This happens when two jobs try to modify the same record,
+                # this modification fails due to the optimistic lock.
+                # We leave it to the caller to handle; they may try to read it again.
+                session.rollback()
+                return False
diff --git a/src/constants.py b/src/constants.py
index 1d7d549..985151e 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -24,15 +24,15 @@ class TimeDuration:


 class Status:
+    CANC = "CANCELED"
+    FAIL = "FAILED"
     INIT = "INIT"
     RUNN = "RUNNING"
-    STOP = "STOPPED"
-    FAIL = "FAILED"
     SUCC = "SUCCESS"

-    allowed_status = [INIT, RUNN, STOP, FAIL, SUCC]
+    status = [CANC, FAIL, INIT, RUNN, SUCC]

     @classmethod
     def validate(cls, status: str):
-        assert status in cls.allowed_status, ValueError(f"invalid status key {status}")
+        assert status in cls.status, ValueError(f"invalid status key {status}")
         return status
diff --git a/src/decorators/__init__.py b/src/decorators/__init__.py
new file mode 100644
index 0000000..4d6166b
--- /dev/null
+++ b/src/decorators/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/decorators/decorators.py b/src/decorators/decorators.py
new file mode 100644
index 0000000..5c4968e
--- /dev/null
+++ b/src/decorators/decorators.py
@@ -0,0 +1,121 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
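+#
+# Note: jwt_required() below accepts HS256 tokens signed with settings.SECRET
+# whose payload carries a "name" claim matching a registered user. For
+# illustration only, a compatible token could be minted with:
+#
+#     import jwt
+#     token = jwt.encode({"name": "user_0"}, settings.SECRET, algorithm="HS256")
+#
+# ("user_0" is a sample user created by initialize_database.py.)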
+from flask import g, request, jsonify +from functools import wraps +import jwt +import logging + +from exceptions.exceptions import BaseError, ValidationError, AuthorizationError +from extensions import get_session_maker +from models.user import User, Role +from models.job import Job +import settings + +session_maker = get_session_maker() + + +def jwt_required(f): + + @wraps(f) + def wrapper(*args, **kwargs): + token = None + headers = request.headers + + if "Authorization" in headers: + token = headers["Authorization"].replace("Bearer ", "") + if not token: + raise ValidationError("JWT token is missing") + try: + payload = jwt.decode(token, settings.SECRET, algorithms=["HS256"]) + if "name" not in payload: + raise ValueError + with session_maker() as session: + user = session.query(User).filter_by(name=payload["name"]).first() + if user and user.validate(): + g.validated_user = user.to_dict() + else: + raise + except Exception: + raise ValidationError("JWT token is invalid") + + return f(*args, **kwargs) + + return wrapper + + +def check_job_permission(f): + + @wraps(f) + def wrapper(*args, **kwargs): + try: + if not hasattr(g, "validated_user") or "name" not in g.validated_user: + raise ValueError + user_name = g.validated_user["name"] + job_id = kwargs.get("job_id") + if not job_id.startswith("j_"): + raise ValueError(f"invalid job_id {job_id} in request path") + with session_maker() as session: + job = session.query(Job).filter_by(job_id=job_id).first() + if job is None or job.user_name != user_name: + raise + except Exception: + raise AuthorizationError("Unauthorized operation") + + return f(*args, **kwargs) + + return wrapper + + +def is_node(f): + + @wraps(f) + def wrapper(*args, **kwargs): + try: + if not hasattr(g, "validated_user") or g.validated_user["role"] != Role.node: + raise ValueError + except Exception: + raise AuthorizationError("Unauthorized operation") + + return f(*args, **kwargs) + + return wrapper + + +def log_and_handle_exceptions(f): + + def log_and_return_error(error, code): + logging.exception(str(error)) + return jsonify({ + "success": False, + "error_message": str(error), + }), code + + @wraps(f) + def wrapper(*args, **kwargs): + + try: + body = request.get_json() + except Exception: + body = "" + try: + logging.info(f"Request to {request.url} with headers {request.headers} and body {body}") + response, code = f(*args, **kwargs) + logging.info(f"Response with status {response.status_code} and body {response.get_json()}") + return response, code + except BaseError as e: + return log_and_return_error(e, e.code) + except Exception as e: + return log_and_return_error(e, 500) + + return wrapper diff --git a/src/exceptions/__init__.py b/src/exceptions/__init__.py new file mode 100644 index 0000000..4d6166b --- /dev/null +++ b/src/exceptions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
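For reference, the `PlatformClient` added in client/client/client.py above can also be driven directly from Python, independent of the command-line tool. A minimal usage sketch follows; the server URL, token, and file name are placeholders:

```python
import json

from client import PlatformClient  # re-exported by client/client/__init__.py

# Placeholder values: use your deployment's URL and a token from your JWT issuer.
client = PlatformClient("http://localhost:5000", "YOUR_JWT_TOKEN")

# Submit a job from a saved parameter file (same schema as the README template);
# submit() returns True on success and raises on an error response.
with open("params.json") as f:
    client.submit(json.load(f))

# Batch-query recent jobs, mirroring `petplatform-cli get-jobs`.
for job in client.get_all(status="SUCCESS", hours=48, limit=20):
    print(job)
```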
diff --git a/src/exceptions/exceptions.py b/src/exceptions/exceptions.py
new file mode 100644
index 0000000..8b0aa2e
--- /dev/null
+++ b/src/exceptions/exceptions.py
@@ -0,0 +1,39 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class BaseError(Exception):
+
+    def __init__(self, message, code=500):
+        super().__init__(message)
+        self.message = message
+        self.code = code
+
+
+class ValidationError(BaseError):
+
+    def __init__(self, message):
+        super().__init__(message, code=401)
+
+
+class AuthorizationError(BaseError):
+
+    def __init__(self, message):
+        super().__init__(message, code=403)
+
+
+class NotFoundError(BaseError):
+
+    def __init__(self, message):
+        super().__init__(message, code=404)
diff --git a/src/extensions.py b/src/extensions.py
index 47790d1..50aa220 100644
--- a/src/extensions.py
+++ b/src/extensions.py
@@ -12,34 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from flask_sqlalchemy import SQLAlchemy
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 
-from models.base import Base
 import settings
 
-# init db
-db = SQLAlchemy(model_class=Base)
-
 
 def get_engine(db_uri=None, create_tables=False):
     db_uri = db_uri or settings.PLATFORM_DB_URI
     engine = create_engine(db_uri)
     if create_tables:
+        from models.base import Base
+        from models.global_config import GlobalConfig
+        from models.job import Job
+        from models.mission import Mission
+        from models.mission_context import MissionContext
+        from models.task import Task
+        from models.user import User
         Base.metadata.create_all(engine)
     return engine
 
 
-def get_session_cls(db_uri=None, create_tables=False):
+def get_session_maker(db_uri=None, create_tables=False):
     engine = get_engine(db_uri, create_tables)
     return sessionmaker(bind=engine)
-
-
-def get_session_ins(db_uri=None, create_tables=False):
-    engine = get_engine(db_uri, create_tables)
-    return sessionmaker(bind=engine)()
-
-
-def get_flask_session():
-    return db.session
diff --git a/src/initialize_database.py b/src/initialize_database.py
index ad16d44..777e7d0 100644
--- a/src/initialize_database.py
+++ b/src/initialize_database.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from datetime import datetime, timedelta import os import glob import json @@ -19,67 +20,146 @@ import yaml from sqlalchemy import MetaData -from extensions import get_session_ins +from extensions import get_session_maker from models.task import Task from models.job import Job from models.mission import Mission from models.global_config import GlobalConfig from models.mission_context import MissionContext +from models.user import User, Status + + +def create_users(): + return [ + User(name="user_0", status=Status.normal, role="OPERATOR"), + User(name="user_1", status=Status.normal, role="OPERATOR") + ] + + +def create_missions(): + missions = [] + mission_template_dir = "./test/missions" if platform.system().lower() == "darwin" else "/app/missions" + for idx, filename in enumerate(glob.glob(os.path.join(mission_template_dir, '*.yml'))): + with open(filename, "r") as rf: + content = yaml.safe_load(rf.read()) + meta = content.get("meta", {}) + name = meta.get("name", os.path.basename(filename)) + version = meta.get("version", 1) + missions.append(Mission(name=name, version=version, dag=json.dumps(content))) + return missions + + +def create_jobs(): + job_context = json.dumps({"party_a": {}, "party_b": {}, "common": {"__user_input": {}, "job_id": "j_1234"}}) + join_parties = json.dumps(["party_a", "party_b"]) + + jobs = [ + Job(job_id="j_1234", + mission_name="psi", + mission_version=1, + job_context=job_context, + main_party="party_a", + join_parties=join_parties, + status="RUNNING", + user_name="user_0"), + Job(job_id="j_1235", + mission_name="psi", + mission_version=1, + job_context=job_context, + main_party="party_a", + join_parties=join_parties, + status="FAILED", + user_name="user_0"), + Job(job_id="j_1236", + mission_name="psi", + mission_version=1, + job_context=job_context, + main_party="party_a", + join_parties=join_parties, + status="SUCCESS", + user_name="user_0"), + Job(job_id="j_1237", + mission_name="psi", + mission_version=1, + job_context=job_context, + main_party="party_a", + join_parties=join_parties, + status="SUCCESS", + user_name="user_1") + ] + + tasks = [ + Task(name="psi_a", job_id="j_1234", party="party_a", status="RUNNING", start_time=datetime.utcnow()), + Task(name="psi_b", job_id="j_1234", party="party_b", status="RUNNING", start_time=datetime.utcnow()), + Task(name="psi_a", + job_id="j_1235", + party="party_a", + status="FAILED", + start_time=datetime.utcnow(), + end_time=datetime.utcnow() + timedelta(seconds=5), + errors="error 0"), + Task(name="psi_b", job_id="j_1235", party="party_b", status="RUNNING", start_time=datetime.utcnow()), + Task( + name="psi_a", + job_id="j_1236", + party="party_a", + status="SUCCESS", + start_time=datetime.utcnow(), + end_time=datetime.utcnow() + timedelta(seconds=5), + ), + Task( + name="psi_b", + job_id="j_1236", + party="party_b", + status="SUCCESS", + start_time=datetime.utcnow(), + end_time=datetime.utcnow() + timedelta(seconds=6), + ), + Task( + name="psi_a", + job_id="j_1237", + party="party_a", + status="SUCCESS", + start_time=datetime.utcnow(), + end_time=datetime.utcnow() + timedelta(seconds=5), + ), + Task( + name="psi_b", + job_id="j_1237", + party="party_b", + status="SUCCESS", + start_time=datetime.utcnow(), + end_time=datetime.utcnow() + timedelta(seconds=6), + ), + ] + return jobs, tasks def initialize_database(url): - with get_session_ins(url, create_tables=True) as session: - # initialize GlobalConfig - party = session.query(GlobalConfig).filter_by(config_key="party").first() - if party is None: - 
session.add(GlobalConfig(config_key="party", config_value="party_a")) - party_address = session.query(GlobalConfig).filter_by(config_key="party_address").first() - if party_address is None: - session.add( - GlobalConfig(config_key="party_address", - config_value=json.dumps({ - "party_a": { - "petplatform": { - "url": "" - } - }, - "party_b": { - "petplatform": { - "url": "" - } - } - }))) - - # initialize Mission - mission_template_dir = "../missions" if platform.system().lower() == "darwin" else "/app/missions" - for _, filename in enumerate(glob.glob(os.path.join(mission_template_dir, '*.yml'))): - with open(filename, "r") as rf: - content = yaml.safe_load(rf.read()) - meta = content.get("meta", {}) - name = meta.get("name", os.path.basename(filename)) - version = meta.get("version", 1) - mission = session.query(Mission).filter_by(name=name, version=version).first() - if mission is None: - session.add(Mission(name=name, version=version, dag=json.dumps(content))) + with get_session_maker(url, create_tables=True)() as session: + users = create_users() + missions = create_missions() + jobs, tasks = create_jobs() + session.add_all(users) + session.add_all(missions) + session.add_all(jobs) + session.add_all(tasks) session.commit() def clear_database(url): - all_tables = [GlobalConfig, MissionContext, Mission, Job, Task] + all_tables = [GlobalConfig, MissionContext, Mission, Job, Task, User] meta = MetaData() - with get_session_ins(url) as session: + with get_session_maker(url)() as session: meta.reflect(bind=session.bind) for table in all_tables: table_name = table.__tablename__ if table_name in meta.tables: - try: - session.query(table).delete() - session.commit() - except Exception as e: - print(e) + session.query(table).delete() + session.commit() if __name__ == '__main__': import settings - # clear_database(SQLALCHEMY_DATABASE_URI) + clear_database(settings.PLATFORM_DB_URI) initialize_database(settings.PLATFORM_DB_URI) diff --git a/src/initialize_jwt.py b/src/initialize_jwt.py new file mode 100644 index 0000000..87fb073 --- /dev/null +++ b/src/initialize_jwt.py @@ -0,0 +1,39 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import secrets + +import jwt + + +def generate_secret(): + return secrets.token_hex(32) + + +def generate_token(secret, payload, algorithm='HS256'): + return jwt.encode(payload, secret, algorithm) + + +if __name__ == '__main__': + import os + + # Generate a secure random string of 32 bytes + # jwt_secret = generate_secret() + # print(jwt_secret) + + secret = os.environ.get('secret') + print(generate_token(secret, {'name': 'test_account_1'})) + print(generate_token(secret, {'name': 'cn_node_1'})) + print(generate_token(secret, {'name': 'va_node_1'})) + + print(generate_token(secret, {'name': 'user_0'})) diff --git a/src/job_manager/core.py b/src/job_manager/core.py index 9524c83..e8ccb82 100644 --- a/src/job_manager/core.py +++ b/src/job_manager/core.py @@ -11,101 +11,71 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime, timedelta import json import logging -from typing import Dict +from typing import Dict, List import multiprocessing as mp -from sqlalchemy.orm.exc import StaleDataError from constants import Status +from job_manager.dag import DAG, LogicTask +from job_manager.task import TaskExecutor from models.job import Job from models.mission import Mission from models.task import Task -from network import request_manager +from network.request import request_manager import settings from utils.db_utils import session_commit_with_retry from utils.deep_merge import deep_merge -from task_executor.core import TaskExecutor -from .dag import DAG, LogicTask class JobManager: _dag: "DAG" = None - def __init__(self, job_id: str, session=None): - if session is None: - from extensions import get_flask_session - session = get_flask_session() - self.session = session + def __init__(self, job_id: str): + from extensions import get_session_maker + self.session_maker = get_session_maker() self.job_id = job_id - def close(self): - if self.session: - self.session.close() - logging.info("job manager session closed") - @property def dag(self): if self._dag is None: - self._dag = DAG(self.session, self.job_id) + self._dag = DAG(self.job_id) return self._dag def _update_dag(self): - self._dag = DAG(self.session, self.job_id) - - def _get_mission(self, mission_name, mission_version): - if mission_version == "latest": - mission: "Mission" = self.session.query(Mission).filter_by(name=mission_name).order_by( - Mission.version.desc()).first() - else: - mission: "Mission" = self.session.query(Mission).filter_by(name=mission_name, - version=int(mission_version)).first() - assert mission is not None, ValueError(f"mission {mission_name}@v{mission_version} not found") + self._dag = DAG(self.job_id) + + def _get_mission(self, mission_name, mission_version) -> "Mission": + with self.session_maker() as session: + if mission_version == "latest": + mission: "Mission" = session.query(Mission).filter_by(name=mission_name).order_by( + Mission.version.desc()).first() + else: + mission: "Mission" = session.query(Mission).filter_by(name=mission_name, + version=int(mission_version)).first() + if mission is None: + raise ValueError(f"mission {mission_name}@v{mission_version} not found") return mission - def submit_job(self, params: Dict): - # does not allow parallel jobs - # running_jobs = self.session.query(Job).filter_by(status=Status.RUNN).all() - # if running_jobs: - # raise Exception("parallel jobs not allowed, please wait until the last job finish") - - # decide 
mission - mission_name = params.get("mission_name", "ecdh_psi_optimized") - mission_version = params.get("mission_version", "latest") - mission = self._get_mission(mission_name, mission_version) - - # get party info - main_party = params.get("main_party", settings.PARTY) - mission_dag = json.loads(mission.dag) - join_parties = list({operator["party"] for operator in mission_dag["operators"]}) - - # inform join parties to submit a new job with the same job id, mission name, and version - if main_party == settings.PARTY: - # set params - params["main_party"] = main_party - params["mission_name"] = mission.name - params["mission_version"] = str(mission.version) - params["job_id"] = self.job_id - for party in join_parties: - if party == settings.PARTY: - continue - request_manager.submit(party, params) - - # create job - mission_params = params.get("mission_params", {}) - job_context = {"common": {"__user_input": mission_params, "job_id": self.job_id}} - for party in join_parties: - job_context[party] = {} - job = Job(job_id=self.job_id, - mission_name=mission.name, - mission_version=mission.version, - job_context=json.dumps(job_context), - main_party=main_party, - join_parties=json.dumps(join_parties), - status=Status.RUNN) - # create tasks - tasks = [ + def _create_job(self, + mission, + job_context, + main_party: str, + join_parties: List[str], + user_name: str = None) -> "Job": + return Job(job_id=self.job_id, + mission_name=mission.name, + mission_version=mission.version, + job_context=json.dumps(job_context), + main_party=main_party, + join_parties=json.dumps(join_parties), + status=Status.RUNN, + user_name=user_name or "") + + def _create_tasks(self, mission_dag) -> List["Task"]: + return [ Task(name=operator["name"], job_id=self.job_id, party=operator["party"], @@ -113,139 +83,188 @@ def submit_job(self, params: Dict): status=Status.INIT) for operator in mission_dag["operators"] ] - # commit changes to db - self.session.add(job) - self.session.add_all(tasks) - session_commit_with_retry(self.session) - logging.info(f"created new job {self.job_id}:{mission_dag}@{mission_version}, job_context: {job_context}") + def submit(self, params: Dict, user_name: str = None): + with self.session_maker() as session: + # does not allow parallel jobs + running_jobs = session.query(Job).filter_by(status=Status.RUNN).all() + if len(running_jobs) >= settings.MAX_JOB_LIMIT: + raise Exception("running jobs has reached the upper limit, please try again later") + + # decide mission + mission_name = params.get("mission_name", "ecdh_psi_optimized") + mission_version = params.get("mission_version", "latest") + mission = self._get_mission(mission_name, mission_version) + + # decide parties + main_party = params.get("main_party", settings.PARTY) + mission_dag = json.loads(mission.dag) + join_parties = list({operator["party"] for operator in mission_dag["operators"]}) + + # decide job context + mission_params = params.get("mission_params", {}) + job_context = {party: {} for party in join_parties} + job_context["common"] = {"__user_input": mission_params, "job_id": self.job_id} + + # inform join parties to submit a new job with the same job id, mission name, and version + if main_party == settings.PARTY: + # set params + params["main_party"] = main_party + params["mission_name"] = mission.name + params["mission_version"] = str(mission.version) + params["job_id"] = self.job_id + for party in join_parties: + if party == settings.PARTY: + continue + request_manager.submit(party, params) + + # create job & task + job = 
self._create_job(mission, job_context, main_party, join_parties, user_name) + tasks = self._create_tasks(mission_dag) - # start tasks that are ready to run on your side - self.trigger_job() + # commit changes to db + session.add(job) + session.add_all(tasks) + session_commit_with_retry(session) + logging.info(f"created new job {self.job_id}:{mission_dag}@{mission_version}, job_context: {job_context}") - def kill(self): - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - assert job is not None, ValueError(f"job {self.job_id} not found") - - if job.main_party == settings.PARTY: - join_parties = json.loads(job.join_parties) - for party in join_parties: - if party == settings.PARTY: - continue - request_manager.kill(party, {"job_id": self.job_id}) - - job.status = Status.FAIL - tasks = self.session.query(Task).filter_by(job_id=self.job_id).all() - assert tasks is not None, ValueError(f"tasks for job {self.job_id} not found") - for task in tasks: - if task.status == Status.RUNN: - task.status = Status.FAIL - session_commit_with_retry(self.session) + self.trigger_job() def rerun(self): - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - assert job is not None, ValueError(f"job {self.job_id} not found") - - # can only rerun failed or stopped jobs - if job.status not in [Status.FAIL, Status.STOP]: - return - - if job.main_party == settings.PARTY: - join_parties = json.loads(job.join_parties) - for party in join_parties: - if party == settings.PARTY: - continue - request_manager.rerun(party, {"job_id": self.job_id}) - - job.status = Status.RUNN - tasks = self.session.query(Task).filter_by(job_id=self.job_id).all() - assert tasks is not None, ValueError(f"tasks for job {self.job_id} not found") - for task in tasks: - if task.status == Status.FAIL: - task.status = Status.INIT - session_commit_with_retry(self.session) + with self.session_maker() as session: + job = session.query(Job).filter_by(job_id=self.job_id).first() + if job is None: + raise ValueError(f"job {self.job_id} not found") + + # can only rerun failed or canceled job + if job.status not in [Status.FAIL, Status.CANC]: + return + + if job.main_party == settings.PARTY: + join_parties = json.loads(job.join_parties) + for party in join_parties: + if party == settings.PARTY: + continue + # request_manager.rerun(party, self.job_id) + + job.status = Status.RUNN + tasks = session.query(Task).filter_by(job_id=self.job_id).all() + if not tasks: + raise ValueError(f"tasks for job {self.job_id} not found") + for task in tasks: + if task.status in [Status.FAIL, Status.CANC]: + task.reset() + session_commit_with_retry(session) + self.trigger_job() - def get_status(self) -> Dict: - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - assert job, ValueError(f"job {self.job_id} not found") - tasks = self.session.query(Task).filter_by(job_id=self.job_id).all() - assert tasks, ValueError(f"tasks for job id {self.job_id} not found") - task_status_map = {} - for task in tasks: - task_status_map[task.name] = Status.validate(task.status) + def cancel(self): + with self.session_maker() as session: + job = session.query(Job).filter_by(job_id=self.job_id).first() + if job is None: + raise ValueError(f"job {self.job_id} not found") + + if job.main_party == settings.PARTY: + join_parties = json.loads(job.join_parties) + for party in join_parties: + if party == settings.PARTY: + continue + request_manager.cancel(party, self.job_id) + + job.status = Status.CANC + tasks = 
session.query(Task).filter_by(job_id=self.job_id).all() + if not tasks: + raise ValueError(f"tasks for job {self.job_id} not found") + for task in tasks: + if task.status == Status.RUNN: + task.cancel() + session_commit_with_retry(session) + self.trigger_job() + + def get_job_details(self) -> Dict: + with self.session_maker() as session: + job = session.query(Job).filter_by(job_id=self.job_id).first() + if job is None: + raise ValueError(f"job {self.job_id} not found") + tasks = session.query(Task).filter_by(job_id=self.job_id).all() + if not tasks: + raise ValueError(f"tasks for job id {self.job_id} not found") + sorted_tasks = sorted(tasks, key=lambda task: task.start_time or datetime.utcnow()) + task_details = [task.details() for task in sorted_tasks] progress = format(len(list(filter(lambda x: x.status == Status.SUCC, tasks))) / len(tasks), ".2%") - return {"progress": progress, "job_status": job.status, "task_status": task_status_map} - - def update_task(self, task_name: str, task_status: str, external_context: Dict = None, max_retry=3): - try: - for i in range(max_retry): - task_status = Status.validate(task_status) - # only call this method on success or fail - assert task_status in [Status.SUCC, Status.FAIL], ValueError(f"unexpected task status {task_status}") - task = self.session.query(Task).filter_by(job_id=self.job_id, name=task_name).first() - assert task is not None, ValueError(f"{self.job_id}.{task_name} not found") - task.status = task_status - - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - assert job is not None, ValueError(f"{self.job_id} not found") - job_context: Dict = json.loads(job.job_context) + return {"job_id": self.job_id, "progress": progress, "job_status": job.status, "task_details": task_details} + + def get_jobs(self, user_name: str, status: str = None, hours: int = None, limit: int = 10) -> List: + with self.session_maker() as session: + query = session.query(Job).filter(Job.user_name == user_name) + if status is not None: + query = query.filter(Job.status == status) + if hours is not None: + start_time = datetime.utcnow() - timedelta(hours=hours) + query = query.filter(Job.create_time >= start_time) + jobs = query.limit(limit).all() + return [job.simple_to_dict() for job in jobs] + + def update_task(self, task_name: str, task_status: str, external_context: Dict = None, errors: str = None): + with self.session_maker() as session: + task_status = Status.validate(task_status) + task = session.query(Task).filter_by(job_id=self.job_id, name=task_name).first() + if task is None: + raise ValueError(f"{self.job_id}.{task_name} not found") + job = session.query(Job).filter_by(job_id=self.job_id).first() + if job is None: + raise ValueError(f"{self.job_id} not found") + job_context: Dict = json.loads(job.job_context) + + if task_status == Status.RUNN: + task.run() + elif task_status == Status.SUCC: + task.success() if external_context is not None: deep_merge(job_context, external_context) job.job_context = json.dumps(job_context) - try: - session_commit_with_retry(self.session) - except StaleDataError: - self.session.rollback() - if i < max_retry - 1: - continue - raise - - self.trigger_job() - - # broad task update + elif task_status == Status.FAIL: + task.fail(errors) + else: + raise ValueError(f"unexpected task status {task_status}") + session_commit_with_retry(session) + + # broadcast task update + params = {"task_status": task_status} if task.party == settings.PARTY: for party in json.loads(job.join_parties): if party == settings.PARTY: 
continue - params = { - "job_id": self.job_id, - "task_name": task_name, - "task_status": task_status, + if task_status == Status.SUCC: # sync filtered job context to partner - "job_context": { + params["job_context"] = { "common": job_context.get("common", {}), party: job_context.get(party, {}) } - } - request_manager.update_task(party, params) - except Exception as e: - logging.exception("update task status fail after max retry") - raise e + if task_status == Status.FAIL: + params["errors"] = errors + request_manager.update_task(party, self.job_id, task_name, params) + + if task_status in [Status.SUCC, Status.FAIL]: + self.trigger_job() def trigger_job(self): - self._update_dag() - status = self.dag.judge_job_status() - if status == Status.RUNN: - for task in self.dag.get_my_ready_tasks(): - self.start(task) - else: - job = self.session.query(Job).filter_by(job_id=self.job_id).first() - job.status = status - session_commit_with_retry(self.session) - if status == Status.FAIL: - for task in self.dag.get_my_running_tasks(): - self.stop_task(task) - - def start(self, task: "LogicTask"): - db_task = self.session.query(Task).filter_by(job_id=self.job_id, name=task.name).first() - db_task.status = Status.RUNN - try: - session_commit_with_retry(self.session) - except StaleDataError: - logging.warning(f"{self.job_id}.{task.name} task status already changed") - self.session.rollback() - return + # start tasks that are ready to run on your side + with self.session_maker() as session: + self._update_dag() + status = self.dag.judge_job_status() + if status == Status.RUNN: + for task in self.dag.get_my_ready_tasks(): + self.start_task(task) + else: + job = session.query(Job).filter_by(job_id=self.job_id).first() + job.status = status + session_commit_with_retry(session) + if status in [Status.FAIL, Status.CANC]: + for task in self.dag.get_my_running_tasks(): + self.stop_task(task) + + def start_task(self, task: "LogicTask"): task_executor = TaskExecutor(self.dag.mission_name, self.job_id, task) process = mp.Process(target=task_executor.start) process.start() diff --git a/src/job_manager/dag.py b/src/job_manager/dag.py index ee80962..fb782c4 100644 --- a/src/job_manager/dag.py +++ b/src/job_manager/dag.py @@ -35,19 +35,21 @@ class LogicTask: class DAG: - def __init__(self, session, job_id): + def __init__(self, job_id): self.mission_name = None self.mission_version = None self.job_id = job_id - self.session = session self._init_dag() def _init_dag(self): - job: "Job" = self.session.query(Job).filter_by(job_id=self.job_id).first() - self.mission_name, self.mission_version = job.mission_name, job.mission_version - mission: "Mission" = self.session.query(Mission).filter_by(name=self.mission_name, - version=self.mission_version).first() - tasks: List[Task] = self.session.query(Task).filter_by(job_id=self.job_id).all() + from extensions import get_session_maker + session_maker = get_session_maker() + with session_maker() as session: + job: "Job" = session.query(Job).filter_by(job_id=self.job_id).first() + self.mission_name, self.mission_version = job.mission_name, job.mission_version + mission: "Mission" = session.query(Mission).filter_by(name=self.mission_name, + version=self.mission_version).first() + tasks: List[Task] = session.query(Task).filter_by(job_id=self.job_id).all() dag: Dict = json.loads(mission.dag) self.tasks = { @@ -55,7 +57,7 @@ def _init_dag(self): LogicTask(v["name"], v["party"], v.get("args", {}), "", v.get("depends", []), v['class'], v['class_path']) for v in dag["operators"] } - 
diff_set = {k for k in self.tasks.keys() if k not in {v.name for v in tasks}} + diff_set = set(self.tasks.keys()).difference(set([v.name for v in tasks])) assert len(diff_set) == 0, ValueError(f"task missed: {diff_set}") # update task status @@ -90,10 +92,12 @@ def get_my_running_tasks(self) -> List["LogicTask"]: return running def judge_job_status(self) -> "Status": - num_init, num_running, num_success, num_failed = 0, 0, 0, 0 + num_init, num_running, num_success, num_failed, num_canceled = 0, 0, 0, 0, 0 for task in self.tasks.values(): - if task.status in [Status.STOP, Status.FAIL]: + if task.status in [Status.FAIL]: num_failed += 1 + if task.status in [Status.CANC]: + num_canceled += 1 if task.status in [Status.SUCC]: num_success += 1 if task.status in [Status.INIT]: @@ -103,6 +107,8 @@ def judge_job_status(self) -> "Status": if num_failed > 0: # some tasks failed or stopped return Status.FAIL + elif num_canceled > 0: + return Status.CANC elif num_success == len(self.tasks): # all tasks success return Status.SUCC diff --git a/src/task_executor/core.py b/src/job_manager/task.py similarity index 66% rename from src/task_executor/core.py rename to src/job_manager/task.py index dcb8723..cd03780 100644 --- a/src/task_executor/core.py +++ b/src/job_manager/task.py @@ -18,9 +18,8 @@ from typing import Dict from constants import Status -from extensions import get_session_ins from job_manager.dag import LogicTask -from network import network_config +from network.config import network_config import settings from utils.deep_merge import deep_merge from utils.path_utils import traverse_and_validate @@ -38,42 +37,40 @@ def __init__(self, mission_name, job_id: str, task: "LogicTask"): self.args = task.args self.start_time = time.time() - from config.config_manager import ConfigManager - from job_manager.core import JobManager - self.config_manager = ConfigManager(mission_name=self.mission_name, job_id=self.job_id) - self.job_manager = JobManager(job_id=self.job_id, session=get_session_ins()) - def start(self): - success = False + from job_manager.core import JobManager + from config.config_manager import ConfigManager + config_manager = ConfigManager(mission_name=self.mission_name, job_id=self.job_id) + job_manager = JobManager(job_id=self.job_id) + success, errors = False, None try: - # load operator class + job_manager.update_task(self.task_name, Status.RUNN) operator_class = self._load_class() - # prepare params - configmap = self._parse_configmap() - parsed_args: Dict = self._parse_args() - operator = operator_class(party=self.party, config_manager=self.config_manager, **parsed_args) - # run operator - logging.info(f"start {self.job_id}.{self.task_name}, configmap: {configmap}, args: {parsed_args}") + assert operator_class, RuntimeError(f"fail to load operator {self.class_name} from {self.class_path}") + configmap = self._parse_configmap(config_manager) + args_value_map: Dict = self._parse_args(config_manager=config_manager) + logging.info(f"ready to execute {self.job_id}.{self.task_name}, args: {args_value_map}") + operator = operator_class(party=self.party, config_manager=config_manager, **args_value_map) success = operator.run(configmap=configmap) except Exception as e: - logging.exception(f"{self.job_id}.{self.task_name} task start error: {e}") + logging.exception(f"execute task {self.job_id}.{self.task_name} fail") + errors = str(e) finally: exec_time = time.time() - self.start_time - logging.info(f"{self.job_id}.{self.task_name} finish, success: {success}, exec time: {exec_time}") - 
self.job_manager.update_task(self.task_name, Status.SUCC if success else Status.FAIL) - self._clear() - - def _clear(self): - self.job_manager.close() - self.config_manager.close() + logging.info( + f"{self.job_id}.{self.task_name} finish, success: {success}, exec time: {exec_time}, errors: {errors}") + try: + job_manager.update_task(self.task_name, Status.SUCC if success else Status.FAIL, errors=errors) + except Exception: + logging.exception(f"update task status fail") def _load_class(self): module = importlib.import_module(self.class_path) my_class = getattr(module, self.class_name) return my_class - def _parse_configmap(self) -> Dict: - job_context: Dict = self.config_manager.job_context.get_all() + def _parse_configmap(self, config_manager) -> Dict: + job_context: Dict = config_manager.job_context.get_all() join_parties = set(job_context.keys()) join_parties.remove("common") @@ -93,7 +90,7 @@ def _parse_configmap(self) -> Dict: def _validated_params(self, params: Dict): return traverse_and_validate(params, safe_workdir=settings.SAFE_WORK_DIR) - def _parse_args(self) -> Dict: + def _parse_args(self, config_manager) -> Dict: parsed_args = {} for args_key, args_value in self.args.items(): if isinstance(args_value, str): @@ -101,11 +98,11 @@ def _parse_args(self) -> Dict: # args_value in the form "${job_context.a.b.c}" real_key: str = re.findall(r'\${(.*?)}', args_value)[0] if real_key.startswith("job_context."): - args_value = self.config_manager.job_context.get(real_key[len("job_context."):]) + args_value = config_manager.job_context.get(real_key[len("job_context."):]) elif real_key.startswith("mission_context."): - args_value = self.config_manager.mission_context.get(real_key[len("mission_context."):]) + args_value = config_manager.mission_context.get(real_key[len("mission_context."):]) elif real_key.startswith("global_config."): - args_value = self.config_manager.global_config.get(real_key[len("global_config."):]) + args_value = config_manager.global_config.get(real_key[len("global_config."):]) else: raise Exception("no real args key context find") parsed_args[args_key] = args_value diff --git a/src/models/base.py b/src/models/base.py index 978d2d7..98e12c2 100644 --- a/src/models/base.py +++ b/src/models/base.py @@ -17,6 +17,7 @@ class BigIntOrInteger(TypeDecorator): impl = BigInteger + cache_ok = True def load_dialect_impl(self, dialect): if dialect.name == 'sqlite': diff --git a/src/models/job.py b/src/models/job.py index 6d319c2..bad1dd1 100644 --- a/src/models/job.py +++ b/src/models/job.py @@ -14,6 +14,7 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Text, DateTime + from .base import Base, BigIntOrInteger @@ -29,9 +30,13 @@ class Job(Base): join_parties = Column(Text, nullable=False) main_host = Column(String(80)) status = Column(String(80), nullable=False) + user_name = Column(String(80), nullable=False, default="") create_time = Column(DateTime, default=datetime.utcnow) # create_time field update_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) version_id = Column(Integer, nullable=False, default=0) __mapper_args__ = {'version_id_col': version_id} + + def simple_to_dict(self): + return {"job_id": self.job_id, "status": self.status} diff --git a/src/models/mission.py b/src/models/mission.py index eaaed02..473455c 100644 --- a/src/models/mission.py +++ b/src/models/mission.py @@ -14,6 +14,7 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Text, DateTime, UniqueConstraint + from .base 
import Base, BigIntOrInteger diff --git a/src/models/mission_context.py b/src/models/mission_context.py index 3df9259..ae5de19 100644 --- a/src/models/mission_context.py +++ b/src/models/mission_context.py @@ -14,6 +14,7 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Text, DateTime, UniqueConstraint + from .base import Base, BigIntOrInteger diff --git a/src/models/task.py b/src/models/task.py index 68fe8e1..c648e5a 100644 --- a/src/models/task.py +++ b/src/models/task.py @@ -14,7 +14,9 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Text, DateTime, UniqueConstraint + from .base import Base, BigIntOrInteger +from constants import Status class Task(Base): @@ -26,6 +28,9 @@ class Task(Base): party = Column(String(80), nullable=False) args = Column(Text, nullable=True) status = Column(String(80), nullable=False) + start_time = Column(DateTime, nullable=True) + end_time = Column(DateTime, nullable=True) + errors = Column(Text, nullable=True) create_time = Column(DateTime, default=datetime.utcnow) # create_time field update_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) @@ -33,3 +38,35 @@ class Task(Base): version_id = Column(Integer, nullable=False, default=0) __mapper_args__ = {'version_id_col': version_id} __table_args__ = (UniqueConstraint('name', 'job_id', name='uix_1'),) + + def details(self): + details = { + "name": self.name, + "status": self.status, + "start_time": self.start_time or "NA", + "end_time": self.end_time or "NA" + } + if self.status == Status.FAIL and self.errors: + details["errors"] = self.errors + return details + + def reset(self): + self.status = Status.INIT + self.start_time = self.end_time = self.errors = None + + def run(self): + self.status = Status.RUNN + self.start_time = datetime.utcnow() + + def cancel(self): + self.status = Status.CANC + self.end_time = datetime.utcnow() + + def success(self): + self.status = Status.SUCC + self.end_time = datetime.utcnow() + + def fail(self, errors: str = None): + self.status = Status.FAIL + self.end_time = datetime.utcnow() + self.errors = errors or "" diff --git a/src/models/user.py b/src/models/user.py new file mode 100644 index 0000000..3719944 --- /dev/null +++ b/src/models/user.py @@ -0,0 +1,55 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import datetime
+
+from sqlalchemy import Column, String, DateTime
+
+from .base import Base, BigIntOrInteger
+
+
+class Status:
+    normal = "Normal"
+    revoked = "Revoked"
+
+
+class Role:
+    operator = "Operator"
+    node = "Node"
+    admin = "Admin"
+
+    _roles = [operator, node, admin]
+
+    @classmethod
+    def validate(cls, role: str):
+        if role.capitalize() not in cls._roles:
+            raise ValueError(f"invalid user role {role}")
+        return role.capitalize()
+
+
+class User(Base):
+    __tablename__ = "privacy_platform_user"
+
+    id = Column(BigIntOrInteger, primary_key=True)
+    name = Column(String(80), unique=True, nullable=False)
+    status = Column(String(80), default=Status.normal, nullable=False)
+    role = Column(String(80), nullable=False)
+
+    create_time = Column(DateTime, default=datetime.utcnow)  # create_time field
+    update_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+    def to_dict(self):
+        return {"name": self.name, "status": self.status, "role": self.role}
+
+    def validate(self):
+        return self.status == Status.normal
diff --git a/src/network/request.py b/src/network/request.py
index cd78907..6cbcb9a 100644
--- a/src/network/request.py
+++ b/src/network/request.py
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-import time
 from typing import Dict
 
-import requests
-
-from .config import network_config
+import settings
+from network.config import network_config
+from utils.request_utils import post, patch
 
 
 class RequestManager:
@@ -25,50 +23,49 @@ class RequestManager:
     def __init__(self):
         self.party_address: Dict = network_config.party_config
 
-    def submit(self, party: str, params: Dict) -> bool:
-        return self._action(party=party, endpoint="job/submit", json=params)
+    def _get_address(self, party: str) -> str:
+        if party not in self.party_address:
+            raise ValueError(f"invalid party {party}")
+        return self.party_address[party]["address"]
 
-    def kill(self, party: str, params: Dict) -> bool:
-        return self._action(party=party, endpoint="job/kill", json=params)
+    def _get_headers(self, party: str) -> Dict:
+        if party not in self.party_address:
+            raise ValueError(f"invalid party {party}")
+        headers = {"Authorization": f"Bearer {settings.JWT_TOKEN}"}
+        headers.update(self.party_address[party].get("headers", {}))
+        return headers
 
-    def rerun(self, party: str, params: Dict) -> bool:
-        return self._action(party=party, endpoint="job/rerun", json=params)
+    def submit(self, party: str, params: Dict):
+        address = self._get_address(party)
+        headers = self._get_headers(party)
+        response = post(address, "api/v1/jobs", json=params, headers=headers)
+        if not response.get("success", False):
+            errors = response.get("error_message", "unknown errors")
+            raise Exception(f"bad request: {errors}")
 
-    def update_task(self, party, params: Dict):
-        return self._action(party, endpoint="task/update", json=params)
+    def rerun(self, party: str, job_id: str):
+        address = self._get_address(party)
+        headers = self._get_headers(party)
+        response = post(address, f"api/v1/jobs/{job_id}/rerun", headers=headers)
+        if not response.get("success", False):
+            errors = response.get("error_message", "unknown errors")
+            raise Exception(f"bad request: {errors}")
 
-    def _action(self, party: str, endpoint: str, json: Dict, max_retry: int = 3, timeout: int = 10) -> bool:
-        address = self.party_address.get(party)["petplatform"]["url"]
-        headers = 
self.party_address.get(party)["petplatform"].get("headers") - for i in range(max_retry): - try: - response = self._post(address, endpoint, json, headers=headers, timeout=timeout) - response.raise_for_status() - if response.status_code == 204: - raise ConnectionError - response_data = response.json() - if not response_data['success']: - raise Exception(response_data["error_message"]) - return True - except Exception: - if i < max_retry - 1: - sleep_time = 0.001 * (2**i) - time.sleep(sleep_time) - raise Exception("request fail") + def cancel(self, party: str, job_id: str): + address = self._get_address(party) + headers = self._get_headers(party) + response = post(address, f"api/v1/jobs/{job_id}/cancel", headers=headers) + if not response.get("success", False): + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") - def _post(self, address: str, endpoint: str, json: Dict, data=None, headers=None, timeout=10): - url = f"{address}/{endpoint}" - post_headers = {"Content-Type": "application/json"} - if headers is not None: - post_headers.update(headers) - try: - logging.info(f"send post request to {url}, json={json}, data={data}, headers={headers}") - response = requests.post(url, json=json, data=data, headers=post_headers, timeout=timeout) - logging.info(f"response: {response.json()}") - return response - except Exception as e: - logging.exception(f"post fail: {address}/{endpoint}, headers={headers}, json={json}, data={data}") - raise e + def update_task(self, party, job_id: str, task_name: str, params: Dict): + address = self._get_address(party) + headers = self._get_headers(party) + response = patch(address, f"api/v1/tasks/{job_id}/{task_name}", json=params, headers=headers) + if not response.get("success", False): + errors = response.get("error_message", "unknown errors") + raise Exception(f"bad request: {errors}") request_manager = RequestManager() diff --git a/src/settings.py b/src/settings.py index 0d84afd..5b301e0 100644 --- a/src/settings.py +++ b/src/settings.py @@ -14,7 +14,7 @@ import os import platform -# --------------------------------------- Logging --------------------------------------- +# ============================ LOGGING =================================== LOGGING_CONFIG = { 'version': 1, 'disable_existing_loggers': False, @@ -42,7 +42,10 @@ } } +# ========================= DB ====================================== PLATFORM_DB_URI = os.environ.get("PLATFORM_DB_URI", "sqlite:////app/db/petplatform.db") + +# ========================= party =================================== PARTY: str = os.environ.get("PARTY") if PARTY is None: raise EnvironmentError("env PARTY not found") @@ -52,3 +55,10 @@ NETWORK_SCHEME = os.environ.get("NETWORK_SCHEME", "agent") PORT_LOWER_BOUND = int(os.environ.get("PORT_LOWER_BOUND", "49152")) PORT_UPPER_BOUND = int(os.environ.get("PORT_UPPER_BOUND", "65535")) + +# ========================= validation ============================== +SECRET = os.environ.get("SECRET") +JWT_TOKEN = os.environ.get("JWT_TOKEN") + +# ========================= application ============================= +MAX_JOB_LIMIT = int(os.environ.get("MAX_JOB_LIMIT", "2")) diff --git a/src/utils/request_utils.py b/src/utils/request_utils.py new file mode 100644 index 0000000..eb35d4f --- /dev/null +++ b/src/utils/request_utils.py @@ -0,0 +1,106 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict
+
+import requests
+from requests.exceptions import Timeout
+
+
+def send_request(method: str,
+                 address: str,
+                 endpoint: str,
+                 params: Dict = None,
+                 headers: Dict = None,
+                 json: Dict = None,
+                 data=None,
+                 timeout=10,
+                 return_json=True):
+    url = "{address}/{endpoint}".format(address=address, endpoint=endpoint.lstrip('/'))
+    request_headers = {"Content-Type": "application/json"}
+    if headers:
+        request_headers.update(headers)
+    try:
+        logging.debug(f"send {method} request to {url}, params={params}, json={json}, data={data}, headers={headers}")
+        response = requests.request(method,
+                                    url,
+                                    params=params,
+                                    json=json,
+                                    data=data,
+                                    headers=request_headers,
+                                    timeout=timeout)
+    except Timeout:
+        logging.error(f"{method} request timed out: {address}/{endpoint}, headers={headers}, json={json}, data={data}")
+        raise
+    except Exception:
+        logging.exception(f"{method} fail: {address}/{endpoint}, headers={headers}, json={json}, data={data}")
+        raise
+
+    if response.status_code >= 400:
+        logging.error(f"{method} request error status {response.status_code}: "
+                      f"{address}/{endpoint}, "
+                      f"headers={headers}, "
+                      f"json={json}, "
+                      f"data={data}")
+        response.raise_for_status()
+
+    logging.debug(f"response: {response.json() if return_json else response.status_code}")
+    return response.json() if return_json else response
+
+
+def delete(address: str,
+           endpoint: str,
+           params: Dict = None,
+           headers: Dict = None,
+           json: Dict = None,
+           timeout=10,
+           return_json=True):
+    return send_request("DELETE", address, endpoint, params, headers, json, None, timeout, return_json)
+
+
+def get(address: str, endpoint: str, params: Dict = None, headers: Dict = None, timeout=10, return_json=False):
+    return send_request("GET", address, endpoint, params, headers, None, None, timeout, return_json=return_json)
+
+
+def patch(address: str,
+          endpoint: str,
+          params: Dict = None,
+          headers: Dict = None,
+          json: Dict = None,
+          data=None,
+          timeout=10,
+          return_json=True):
+    return send_request("PATCH", address, endpoint, params, headers, json, data, timeout, return_json)
+
+
+def post(address: str,
+         endpoint: str,
+         params: Dict = None,
+         headers: Dict = None,
+         json: Dict = None,
+         data=None,
+         timeout=10,
+         return_json=True):
+    return send_request("POST", address, endpoint, params, headers, json, data, timeout, return_json)
+
+
+def put(address: str,
+        endpoint: str,
+        params: Dict = None,
+        headers: Dict = None,
+        json: Dict = None,
+        data=None,
+        timeout=10,
+        return_json=True):
+    return send_request("PUT", address, endpoint, params, headers, json, data, timeout, return_json)
diff --git a/src/views/__init__.py b/src/views/__init__.py
new file mode 100644
index 0000000..4d6166b
--- /dev/null
+++ b/src/views/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/views.py b/src/views/default_views.py similarity index 86% rename from src/views.py rename to src/views/default_views.py index bebdda9..0747782 100644 --- a/src/views.py +++ b/src/views/default_views.py @@ -19,23 +19,23 @@ import settings from utils.id_utils import generate_job_id -views = Blueprint('default_views', __name__) +default_views = Blueprint('default_views', __name__) -@views.route("/") +@default_views.route("/") def index_view(): return jsonify({ "message": f"{settings.PARTY} app server is running!", }), 200 -@views.route("/job/submit", methods=["POST"]) +@default_views.route("/job/submit", methods=["POST"]) def submit_job(): try: params = request.json job_id = params.get("job_id", generate_job_id()) job_manager = JobManager(job_id) - job_manager.submit_job(params) + job_manager.submit(params) return jsonify({"success": True, "job_id": job_id}), 200 except Exception as e: @@ -43,13 +43,13 @@ def submit_job(): return jsonify({"success": False, "error_message": str(e)}), 500 -@views.route("/job/kill", methods=["POST"]) -def kill_job(): +@default_views.route("/job/rerun", methods=["POST"]) +def rerun_job(): try: params = request.json job_id = params["job_id"] job_manager = JobManager(job_id) - job_manager.kill() + job_manager.rerun() return jsonify({"success": True}), 200 except Exception as e: @@ -57,13 +57,13 @@ def kill_job(): return jsonify({"success": False, "error_message": str(e)}), 500 -@views.route("/job/rerun", methods=["POST"]) -def rerun_job(): +@default_views.route("/job/kill", methods=["POST"]) +def kill_job(): try: params = request.json job_id = params["job_id"] job_manager = JobManager(job_id) - job_manager.rerun() + job_manager.cancel() return jsonify({"success": True}), 200 except Exception as e: @@ -71,20 +71,20 @@ def rerun_job(): return jsonify({"success": False, "error_message": str(e)}), 500 -@views.route("/job/status", methods=["GET"]) +@default_views.route("/job/status", methods=["GET"]) def get_status(): try: params = request.args job_id = params["job_id"] job_manager = JobManager(job_id) - status = job_manager.get_status() + status = job_manager.get_job_details() return jsonify({"success": True, "status": status}), 200 except Exception as e: logging.exception(f"get job status error: {e}") return jsonify({"success": False, "error_message": str(e)}), 500 -@views.route("/task/update", methods=["POST"]) +@default_views.route("/task/update", methods=["POST"]) def update_task(): try: params = request.json diff --git a/src/views/v1.py b/src/views/v1.py new file mode 100644 index 0000000..47778c3 --- /dev/null +++ b/src/views/v1.py @@ -0,0 +1,102 @@ +# Copyright 2024 TikTok Pte. Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flask import g, request, jsonify, Blueprint
+
+from constants import Status
+from decorators.decorators import jwt_required, is_node, check_job_permission, log_and_handle_exceptions
+from job_manager.core import JobManager
+from utils.id_utils import generate_job_id
+
+v1 = Blueprint('v1_views', __name__)
+
+
+@v1.route("/api/v1/jobs", methods=["POST"])
+@log_and_handle_exceptions
+@jwt_required
+def submit():
+    user_name = g.validated_user["name"]
+    params = request.json
+    job_id = params.get("job_id", generate_job_id())
+    job_manager = JobManager(job_id)
+    job_manager.submit(params, user_name)
+
+    return jsonify({"success": True, "job_id": job_id}), 200
+
+
+@v1.route("/api/v1/jobs/<job_id>/rerun", methods=["POST"])
+@log_and_handle_exceptions
+@jwt_required
+@check_job_permission
+def rerun(job_id):
+    job_manager = JobManager(job_id)
+    job_manager.rerun()
+    return jsonify({"success": True}), 200
+
+
+@v1.route("/api/v1/jobs/<job_id>/cancel", methods=["POST"])
+@log_and_handle_exceptions
+@jwt_required
+@check_job_permission
+def cancel(job_id):
+    job_manager = JobManager(job_id)
+    job_manager.cancel()
+    return jsonify({"success": True}), 200
+
+
+@v1.route("/api/v1/jobs/<job_id>", methods=["GET"])
+@log_and_handle_exceptions
+@jwt_required
+@check_job_permission
+def get(job_id):
+    job_manager = JobManager(job_id)
+    job_details = job_manager.get_job_details()
+    return jsonify({"success": True, "job": job_details}), 200
+
+
+@v1.route("/api/v1/jobs", methods=["GET"])
+@log_and_handle_exceptions
+@jwt_required
+def get_all():
+    user_name = g.validated_user["name"]
+    args = request.args
+    status = args.get("status")
+    if status is not None:
+        status = Status.validate(status)
+    hours = args.get("hours")
+    if hours is not None:
+        hours = int(hours)
+        if hours < 1:
+            raise ValueError("hours must be a positive integer")
+    limit = args.get("limit", "10")
+    if limit is not None:
+        limit = int(limit)
+        if limit < 1:
+            raise ValueError("limit must be a positive integer")
+    job_manager = JobManager("")
+    jobs = job_manager.get_jobs(user_name=user_name, status=status, hours=hours, limit=limit)
+    return jsonify({"success": True, "jobs": jobs}), 200
+
+
+@v1.route("/api/v1/tasks/<job_id>/<task_name>", methods=["PATCH"])
+@log_and_handle_exceptions
+@jwt_required
+@is_node
+def update_task(job_id, task_name):
+    params = request.json
+    task_status = params["task_status"]
+    job_context = params.get("job_context")
+    errors = params.get("errors")
+    job_manager = JobManager(job_id)
+    job_manager.update_task(task_name=task_name, task_status=task_status, external_context=job_context, errors=errors)
+    return jsonify({"success": True}), 200
diff --git a/test/operators/test_executor.py b/test/operators/test_executor.py
index 2f08fc2..506e65e 100644
--- a/test/operators/test_executor.py
+++ b/test/operators/test_executor.py
@@ -14,7 +14,7 @@
 import unittest
 
 from job_manager.dag import LogicTask
-from task_executor.core import TaskExecutor
+from job_manager.task import TaskExecutor
 from constants import Status
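A behavioral note on the networking rewrite: the removed `RequestManager._action` retried failed cross-party calls up to `max_retry` times with exponential backoff, whereas the new `utils/request_utils.py` helpers raise on the first failure. If the old behavior is still wanted, a thin wrapper along these lines could restore it (a sketch under that assumption; `post_with_retry` is a hypothetical helper, not part of the patch):

```python
import time

from utils.request_utils import post  # helper introduced in this patch


def post_with_retry(address, endpoint, json=None, headers=None,
                    max_retry=3, timeout=10):
    """Retry POSTs with the same backoff schedule the removed _action used."""
    for attempt in range(max_retry):
        try:
            return post(address, endpoint, json=json, headers=headers,
                        timeout=timeout)
        except Exception:
            if attempt == max_retry - 1:
                raise
            time.sleep(0.001 * (2 ** attempt))  # 1ms, 2ms, 4ms, ...
```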