Skip to content

Commit 0419997

Browse files
authored
Add vis_mode to generate_patch_set (#111)
1 parent 4d0cdff commit 0419997

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

ngclearn/utils/patch_utils.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create_patches(self, add_frame=False, center=True):
118118

119119

120120

121-
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234): ## scikit
121+
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False): ## scikit
122122
"""
123123
Generates a set of patches from an array/list of image arrays (via
124124
random sampling with replacement). This uses scikit-learn's patch creation
@@ -151,10 +151,16 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True,
151151
p_batch = np.concatenate((p_batch,patches),axis=0)
152152
else:
153153
p_batch = patches
154+
155+
mu = 0
154156
if center: ## center patches by subtracting out their means
155157
mu = np.mean(p_batch, axis=1, keepdims=True)
156158
p_batch = p_batch - mu
157-
return jnp.array(p_batch)
159+
if vis_mode:
160+
return jnp.array(p_batch), mu
161+
else:
162+
return jnp.array(p_batch)
163+
158164

159165
def generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True): ## patchify
160166
## this is a patchify-specific function (only use if you have patchify installed...)

0 commit comments

Comments
 (0)