diff --git a/ctf/__init__.py b/ctf/__init__.py index be3d74b..e40fa6d 100644 --- a/ctf/__init__.py +++ b/ctf/__init__.py @@ -23,19 +23,21 @@ coloredlogs.install(level="DEBUG", logger=LOG) -def check_tool_version(): +def check_tool_version() -> None: with urllib.request.urlopen( url="https://api.github.com/repos/nsec/ctf-script/releases/latest" ) as r: if r.getcode() != 200: LOG.debug(r.read().decode()) LOG.error("Could not verify the latest release.") + return else: try: latest_version = json.loads(s=r.read().decode())["tag_name"] except Exception as e: LOG.debug(e) LOG.error("Could not verify the latest release.") + return compare = 0 for current_part, latest_part in zip( diff --git a/ctf/utils.py b/ctf/utils.py index 1f7069b..40eb098 100644 --- a/ctf/utils.py +++ b/ctf/utils.py @@ -142,7 +142,7 @@ def remove_tracks_from_terraform_modules( ) -def get_all_file_paths_recursively(path: str) -> Generator[None, None, str]: +def get_all_file_paths_recursively(path: str) -> Generator[str, None, None]: if os.path.isfile(path=path): yield remove_ctf_script_root_directory_from_path(path=path) else: diff --git a/ctf/validators.py b/ctf/validators.py index cd7ed38..5b47771 100644 --- a/ctf/validators.py +++ b/ctf/validators.py @@ -166,18 +166,23 @@ def finalize(self) -> list[ValidationError]: class DiscoursePostsAskGodTagValidator(Validator): - """Validate that the triggers used in discourse posts are correctly defined in the discourse tag of each flag in track.yaml. Also validate that each discourse tag is unique. Also validates that the topic matches an existing file name in the posts directory.""" + """ + Validate that the triggers used in discourse posts are correctly defined in the discourse tag of each flag in track.yaml. It checks for triggers in ALL tracks to allow a flow like: "When a flag from track A is triggered, show a post in track B". + Also validate that each discourse tag is unique. + Also validates that the topic matches an existing file name in the posts directory. + """ def __init__(self): self.discourse_tags_mapping = {} + self.discourse_triggers = [] + self.discourse_posts = [] def validate(self, track_name: str) -> list[ValidationError]: track_yaml = parse_track_yaml(track_name=track_name) - discourse_triggers = [] for flag in track_yaml["flags"]: discourse_trigger = flag.get("tags", {}).get("discourse") if discourse_trigger: - discourse_triggers.append(discourse_trigger) + self.discourse_triggers.append(discourse_trigger) if discourse_trigger not in self.discourse_tags_mapping: self.discourse_tags_mapping[discourse_trigger] = [] self.discourse_tags_mapping[discourse_trigger].append(track_yaml) @@ -186,6 +191,7 @@ def validate(self, track_name: str) -> list[ValidationError]: discourse_posts = parse_post_yamls(track_name=track_name) for discourse_post in discourse_posts: if discourse_post.get("trigger", {}).get("type", "") == "flag": + self.discourse_posts.append((track_name, discourse_post)) if not os.path.exists( os.path.join( CTF_ROOT_DIRECTORY, @@ -211,18 +217,6 @@ def validate(self, track_name: str) -> list[ValidationError]: }, ) ) - if discourse_post["trigger"]["tag"] not in discourse_triggers: - errors.append( - ValidationError( - error_name="Invalid trigger in discourse post", - error_description="A discourse post has a flag trigger that references a discourse tag not defined in track.yaml.", - track_name=track_name, - details={ - "Invalid tag": discourse_post["trigger"]["tag"], - "Discourse tags in track.yaml": str(discourse_triggers), - }, - ) - ) return errors @@ -238,6 +232,19 @@ def finalize(self) -> list[ValidationError]: details={'"discourse" tag': discourse_tag}, ) ) + + for track_name, discourse_post in self.discourse_posts: + if discourse_post["trigger"]["tag"] not in self.discourse_triggers: + errors.append( + ValidationError( + error_name="Invalid trigger in discourse post", + error_description="A discourse post has a flag trigger that references a discourse tag not defined in track.yaml.", + track_name=track_name, + details={ + "Invalid tag": discourse_post["trigger"]["tag"], + }, + ) + ) return errors @@ -378,8 +385,8 @@ def validate(self, track_name: str) -> list[ValidationError]: errors: list[ValidationError] = [] services = set() for service in track_yaml["services"]: - service_name = service["name"] - instance_name = service["instance"] + service_name = service.get("name") + instance_name = service.get("instance") service = f"{instance_name}/{service_name}" if service in services: diff --git a/pyproject.toml b/pyproject.toml index c79db9c..4da565c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,14 @@ dependencies = [ "black", "tabulate==0.9.0", ] -version = "1.1.0" +version = "1.1.1" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ] [project.optional-dependencies] -coderunner = ["pybadges", "matplotlib", "standard-imghdr ; python_version >= \"3.13\""] +workflow = ["pybadges", "matplotlib", "standard-imghdr ; python_version >= \"3.13\""] [project.scripts] ctf = "ctf.__main__:main"