|
| 1 | +#!/usr/bin/python |
| 2 | + |
| 3 | +# |
| 4 | +# Copyright 2024 Hopsworks AB |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | +# |
| 18 | + |
| 19 | +"""Scripts for automatic management of aliases.""" |
| 20 | + |
| 21 | +import importlib |
| 22 | +import sys |
| 23 | +from pathlib import Path |
| 24 | + |
| 25 | + |
| 26 | +SOURCES = [ |
| 27 | + "hopsworks/__init__.py", |
| 28 | + "hopsworks/connection.py", |
| 29 | + "hopsworks/internal", |
| 30 | + "hopsworks/platform", |
| 31 | + "hopsworks/fs", |
| 32 | + "hopsworks/ml", |
| 33 | +] |
| 34 | +IGNORED = ["tests", "hsfs", "hopsworks", "hsml", "hopsworks_common"] |
| 35 | +# Everything that is not a top-level file, a part of sources, or a part of ignored is considered to be autmoatically managed. |
| 36 | + |
| 37 | + |
| 38 | +def collect_imports(root): |
| 39 | + imports = [] |
| 40 | + |
| 41 | + def imports_add(file): |
| 42 | + pkg = str(file.parent.relative_to(root)).replace("/", ".") |
| 43 | + if file.name == "__init__.py": |
| 44 | + imports.append(pkg) |
| 45 | + elif file.name.endswith(".py"): |
| 46 | + imports.append(pkg + "." + file.name[:-3]) |
| 47 | + |
| 48 | + for source in SOURCES: |
| 49 | + if (root / source).is_file(): |
| 50 | + imports_add(root / source) |
| 51 | + continue |
| 52 | + for dirpath, _, filenames in (root / source).walk(): |
| 53 | + for filename in filenames: |
| 54 | + imports_add(dirpath / filename) |
| 55 | + return imports |
| 56 | + |
| 57 | + |
| 58 | +def collect_aliases(root): |
| 59 | + for import_str in collect_imports(root): |
| 60 | + importlib.import_module(import_str, package=".") |
| 61 | + aliases = importlib.import_module("hopsworks.internal.aliases", package=".") |
| 62 | + return aliases._aliases |
| 63 | + |
| 64 | + |
| 65 | +def collect_managed(root): |
| 66 | + managed = {} |
| 67 | + for pkg, from_imports in collect_aliases(root).items(): |
| 68 | + pkg = root / pkg.replace(".", "/") / "__init__.py" |
| 69 | + managed[pkg] = ( |
| 70 | + "# ruff: noqa\n" |
| 71 | + "# This file is generated by aliases.py. Do not edit it manually!\n" |
| 72 | + ) |
| 73 | + from_imports.sort() # this is needed for determinism |
| 74 | + for f, i in from_imports: |
| 75 | + managed[pkg] += f"from {f} import {i}\n" |
| 76 | + return managed |
| 77 | + |
| 78 | + |
| 79 | +def fix(root): |
| 80 | + managed = collect_managed(root) |
| 81 | + for filepath, content in managed.items(): |
| 82 | + filepath.parent.mkdir(parents=True, exist_ok=True) |
| 83 | + filepath.touch() |
| 84 | + filepath.write_text(content) |
| 85 | + ignored = [root / path for path in SOURCES + IGNORED] |
| 86 | + for dirpath, _, filenames in root.walk(): |
| 87 | + if dirpath == root: |
| 88 | + continue |
| 89 | + for filename in filenames: |
| 90 | + filepath = dirpath / filename |
| 91 | + if any(filepath.is_relative_to(p) for p in ignored): |
| 92 | + continue |
| 93 | + if filepath not in managed: |
| 94 | + filepath.unlink() |
| 95 | + |
| 96 | + |
| 97 | +def check(root): |
| 98 | + ok = True |
| 99 | + managed = collect_managed(root) |
| 100 | + ignored = [root / path for path in SOURCES + IGNORED] |
| 101 | + for dirpath, _, filenames in root.walk(): |
| 102 | + if dirpath == root: |
| 103 | + continue |
| 104 | + for filename in filenames: |
| 105 | + filepath = dirpath / filename |
| 106 | + if any(filepath.is_relative_to(p) for p in ignored): |
| 107 | + continue |
| 108 | + if filepath not in managed: |
| 109 | + print(f"Error: {filepath} shouldn't exist.") |
| 110 | + ok = False |
| 111 | + continue |
| 112 | + if filepath.read_text() != managed[filepath]: |
| 113 | + print(f"Error: {filepath} has wrong content.") |
| 114 | + ok = False |
| 115 | + if ok: |
| 116 | + print("The aliases are correct!") |
| 117 | + else: |
| 118 | + print("To fix the errors, run `aliases.py fix`.") |
| 119 | + exit(1) |
| 120 | + |
| 121 | + |
| 122 | +def help(msg=None): |
| 123 | + if msg: |
| 124 | + print(msg + "\n") |
| 125 | + print("Use `aliases.py fix [path]` or `aliases.py check [path]`.") |
| 126 | + print( |
| 127 | + "`path` is optional, current directory (or its `python` subdirectory) is used by default; it should be the directory containing the hopsworks package, e.g., `./python/`." |
| 128 | + ) |
| 129 | + exit(1) |
| 130 | + |
| 131 | + |
| 132 | +def main(): |
| 133 | + if len(sys.argv) == 3: |
| 134 | + root = Path(sys.argv[2]) |
| 135 | + elif len(sys.argv) == 2: |
| 136 | + root = Path() |
| 137 | + if not (root / "hopsworks").exists(): |
| 138 | + root = root / "python" |
| 139 | + else: |
| 140 | + help("Wrong number of arguments.") |
| 141 | + |
| 142 | + root = root.resolve() |
| 143 | + if not (root / "hopsworks").exists(): |
| 144 | + help("The used path doesn't contain the hopsworks package.") |
| 145 | + |
| 146 | + cmd = sys.argv[1] |
| 147 | + if cmd in ["f", "fix"]: |
| 148 | + cmd = fix |
| 149 | + elif cmd in ["c", "check"]: |
| 150 | + cmd = check |
| 151 | + else: |
| 152 | + help("Unknown command.") |
| 153 | + |
| 154 | + cmd(root) |
| 155 | + |
| 156 | + |
| 157 | +if __name__ == "__main__": |
| 158 | + main() |
0 commit comments