|
| 1 | +"""epymorph's file caching utilities.""" |
| 2 | +from hashlib import sha256 |
| 3 | +from io import BytesIO |
| 4 | +from os import PathLike |
| 5 | +from pathlib import Path |
| 6 | +from tarfile import TarInfo, is_tarfile |
| 7 | +from tarfile import open as open_tarfile |
| 8 | + |
| 9 | +from platformdirs import user_cache_path |
| 10 | + |
# Root directory for everything epymorph caches on disk, as chosen by platformdirs
# for the 'epymorph' app name. NOTE: `ensure_exists=True` asks platformdirs to create
# the directory, so importing this module can touch the file system.
CACHE_PATH = user_cache_path(appname='epymorph', ensure_exists=True)
| 12 | + |
| 13 | + |
class FileError(Exception):
    """Common base class for the file-operation errors defined in this module."""


class FileMissingError(FileError):
    """The file to be loaded does not exist."""


class FileWriteError(FileError):
    """A file could not be written."""


class FileReadError(FileError):
    """A file could not be loaded."""


class FileVersionError(FileError):
    """A file could not be loaded because it does not meet the version requirements."""


class CacheMiss(FileError):
    """A load-from-cache operation resulted in a cache-miss (for any reason)."""
| 36 | + |
| 37 | + |
| 38 | +def save_bundle(to_path: str | PathLike[str], version: int, files: dict[str, BytesIO]) -> None: |
| 39 | + """ |
| 40 | + Save a bundle of files in our tar format with an associated version number. |
| 41 | + `to_path` can be absolute or relative; relative paths will be resolved |
| 42 | + against the current working directory. Folders in the path which do not exist |
| 43 | + will be created automatically. |
| 44 | + """ |
| 45 | + |
| 46 | + if version <= 0: |
| 47 | + raise ValueError("version should be greater than zero.") |
| 48 | + |
| 49 | + try: |
| 50 | + # Compute checksums |
| 51 | + sha_entries = [] |
| 52 | + for name, contents in files.items(): |
| 53 | + contents.seek(0) |
| 54 | + sha = sha256() |
| 55 | + sha.update(contents.read()) |
| 56 | + sha_entries.append(f"{sha.hexdigest()} {name}") |
| 57 | + |
| 58 | + # Create checksums.sha256 file |
| 59 | + sha_file = BytesIO() |
| 60 | + sha_text = "\n".join(sha_entries) |
| 61 | + sha_file.write(bytes(sha_text, encoding='utf-8')) |
| 62 | + |
| 63 | + # Create cache version file |
| 64 | + ver_file = BytesIO() |
| 65 | + ver_file.write(bytes(str(version), encoding="utf-8")) |
| 66 | + |
| 67 | + tarred_files = { |
| 68 | + **files, |
| 69 | + "checksums.sha256": sha_file, |
| 70 | + "version": ver_file, |
| 71 | + } |
| 72 | + |
| 73 | + # Write the tar to disk |
| 74 | + tar_path = Path(to_path).resolve() |
| 75 | + tar_path.parent.mkdir(parents=True, exist_ok=True) |
| 76 | + mode = 'w:gz' if tar_path.suffix == '.tgz' else 'w' |
| 77 | + with open_tarfile(name=tar_path, mode=mode) as tar: |
| 78 | + for name, contents in tarred_files.items(): |
| 79 | + info = TarInfo(name) |
| 80 | + info.size = contents.tell() |
| 81 | + contents.seek(0) |
| 82 | + tar.addfile(info, contents) |
| 83 | + |
| 84 | + except Exception as e: |
| 85 | + msg = f"Unable to write archive at path: {to_path}" |
| 86 | + raise FileWriteError(msg) from e |
| 87 | + |
| 88 | + |
| 89 | +def load_bundle(from_path: str | PathLike[str], version_at_least: int = -1) -> dict[str, BytesIO]: |
| 90 | + """ |
| 91 | + Load a bundle of files in our tar format, optionally enforcing a minimum version. |
| 92 | + An Exception is raised if the file cannot be loaded for any reason, or if its version |
| 93 | + is incorrect. On success, returns a dictionary of the contained files, mapping the file |
| 94 | + name to the bytes of the file. |
| 95 | + """ |
| 96 | + try: |
| 97 | + tar_path = Path(from_path).resolve() |
| 98 | + if not tar_path.is_file(): |
| 99 | + raise FileMissingError(f"No file at: {tar_path}") |
| 100 | + |
| 101 | + # Read the tar file into memory |
| 102 | + tar_buffer = BytesIO() |
| 103 | + with open(tar_path, 'rb') as f: |
| 104 | + tar_buffer.write(f.read()) |
| 105 | + tar_buffer.seek(0) |
| 106 | + |
| 107 | + if not is_tarfile(tar_buffer): |
| 108 | + raise FileReadError(f"Not a tar file at: {tar_path}") |
| 109 | + |
| 110 | + mode = 'r:gz' if tar_path.suffix == '.tgz' else 'r' |
| 111 | + tarred_files: dict[str, BytesIO] = {} |
| 112 | + with open_tarfile(fileobj=tar_buffer, mode=mode) as tar: |
| 113 | + for info in tar.getmembers(): |
| 114 | + name = info.name |
| 115 | + contents = tar.extractfile(info) |
| 116 | + if contents is not None: |
| 117 | + tarred_files[name] = BytesIO(contents.read()) |
| 118 | + |
| 119 | + # Check version |
| 120 | + if "version" in tarred_files: |
| 121 | + ver_file = tarred_files["version"] |
| 122 | + version = int(str(ver_file.readline(), encoding="utf-8")) |
| 123 | + else: |
| 124 | + version = -1 |
| 125 | + if version < version_at_least: |
| 126 | + raise FileVersionError("Archive is an unacceptable version.") |
| 127 | + |
| 128 | + # Verify the checksums |
| 129 | + if "checksums.sha256" not in tarred_files: |
| 130 | + raise FileReadError("Archive appears to be invalid.") |
| 131 | + sha_file = tarred_files["checksums.sha256"] |
| 132 | + for line_bytes in sha_file.readlines(): |
| 133 | + line = str(line_bytes, encoding='utf-8') |
| 134 | + [checksum, filename] = line.strip().split(' ') |
| 135 | + |
| 136 | + if filename not in tarred_files: |
| 137 | + raise FileReadError("Archive appears to be invalid.") |
| 138 | + |
| 139 | + contents = tarred_files[filename] |
| 140 | + contents.seek(0) |
| 141 | + sha = sha256() |
| 142 | + sha.update(contents.read()) |
| 143 | + contents.seek(0) |
| 144 | + if checksum != sha.hexdigest(): |
| 145 | + msg = f"Archive checksum did not match (for file {filename}). "\ |
| 146 | + "It is possible the file is corrupt." |
| 147 | + raise FileReadError(msg) |
| 148 | + |
| 149 | + return { |
| 150 | + name: contents |
| 151 | + for name, contents in tarred_files.items() |
| 152 | + if name not in ("checksums.sha256", "version") |
| 153 | + } |
| 154 | + |
| 155 | + except FileError: |
| 156 | + raise |
| 157 | + except Exception as e: |
| 158 | + raise FileReadError(f"Unable to load archive at: {from_path}") from e |
| 159 | + |
| 160 | + |
| 161 | +def _resolve_cache_path(path: str | PathLike[str]) -> Path: |
| 162 | + cache_path = Path(path) |
| 163 | + if cache_path.is_absolute(): |
| 164 | + msg = "When saving to or loading from the cache, please supply a relative path." |
| 165 | + raise ValueError(msg) |
| 166 | + return CACHE_PATH.joinpath(cache_path).resolve() |
| 167 | + |
| 168 | + |
def save_bundle_to_cache(to_path: str | PathLike[str], version: int, files: dict[str, BytesIO]) -> None:
    """
    Write a tar bundle of files into the cache, replacing any file already
    stored at that location. `to_path` is interpreted relative to the cache.
    Alongside the content files, the tar carries their sha256 checksums and
    a version file recording which application version wrote the bundle, so
    that the application can judge on load whether a cached file is still valid.
    """
    destination = _resolve_cache_path(to_path)
    save_bundle(destination, version, files)
| 178 | + |
| 179 | + |
def load_bundle_from_cache(from_path: str | PathLike[str], version_at_least: int = -1) -> dict[str, BytesIO]:
    """
    Read a tar bundle of files back out of the cache. `from_path` must be a relative path.
    `version_at_least`, when given, is the lowest version number the cached file may
    have been written with and still count as valid; an older file is treated as a
    cache miss. Any FileError raised while loading is re-raised as CacheMiss, with
    the original error attached as its cause.
    """
    # Path resolution signals bad input with ValueError, which is not a cache miss,
    # so it sits outside the try block.
    cached_file = _resolve_cache_path(from_path)
    try:
        bundle = load_bundle(cached_file, version_at_least)
    except FileError as cause:
        raise CacheMiss() from cause
    return bundle
0 commit comments