diff --git a/README.md b/README.md index ac17f367..241bd4f9 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,11 @@ on the GPU or CPU at different points in the pipeline.

Shapes3D Dataset Factor Traversals

+ +
+ 🏹 Sprites (custom) +

Sprites (Custom) Dataset Factor Traversals

+
+ +
🧵 dSpritesImagenet: diff --git a/disent/dataset/data/_groundtruth__sprites.py b/disent/dataset/data/_groundtruth__sprites.py index 49e2c5ec..d3f6eb54 100644 --- a/disent/dataset/data/_groundtruth__sprites.py +++ b/disent/dataset/data/_groundtruth__sprites.py @@ -46,16 +46,22 @@ # ========================================================================= # +SPRITES_REPO = 'https://github.com/YingzhenLi/Sprites' +SPRITES_REPO_COMMIT_SHA = '3ce4048c5227802bd8f1888e293fd3afdba91c0c' + + def fetch_sprite_components() -> Tuple[np.array, np.array]: try: import git except ImportError: - logging.error('GitPython not found! Please install it: `pip install GitPython`') + log.error('GitPython not found! Please install it: `pip install GitPython`') exit(1) # store files in a temporary directory with TemporaryDirectory(suffix='sprites') as temp_dir: # clone the files into the temp dir - git.Repo.clone_from('https://github.com/YingzhenLi/Sprites', temp_dir) + log.info(f'Generating sprites data, temporarily cloning: {SPRITES_REPO} to {temp_dir}`') + repo = git.Repo.clone_from(SPRITES_REPO, temp_dir, no_checkout=True) + repo.git.checkout(SPRITES_REPO_COMMIT_SHA) # get all the components! component_sheets: List[np.ndarray] = [] component_names = ['bottomwear', 'topwear', 'hair', 'eyes', 'shoes', 'body'] @@ -111,6 +117,11 @@ def _prepare(self, out_dir: str, out_file: str) -> NoReturn: class SpritesAllData(DiskGroundTruthData): + """ + Custom version of sprites, with the data obtained from: + https://github.com/YingzhenLi/Sprites + """ + name = 'sprites' factor_names = ('bottomwear', 'topwear', 'hair', 'eyes', 'shoes', 'body', 'action', 'rotation', 'frame') factor_sizes = (7, 7, 10, 5, 3, 7, 5, 4, 6) # 6_174_000 diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py index 99f3c1d6..781c95b5 100644 --- a/disent/dataset/util/stats.py +++ b/disent/dataset/util/stats.py @@ -210,6 +210,7 @@ def main(progress=True, num_workers=32, batch_size=2048): # try changing worker # mean: [0.046146392822265625] # std: [0.2096506119375896] +# TODO -- REGENERATE: # Mpi3dData - mpi3d_toy - {'subset': 'toy', 'in_memory': True}: # mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702] # std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426] diff --git a/docs/img/traversals/traversal-transpose__sprites.jpg b/docs/img/traversals/traversal-transpose__sprites.jpg new file mode 100644 index 00000000..c3863f33 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__sprites.jpg differ