Skip to content

Commit 8b9eabe

Browse files
committed
cr-nimble refactoring related changes
1 parent 9b96a53 commit 8b9eabe

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

examples/pursuit/cosamp_step_by_step.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
import cr.sparse as crs
5353
import cr.sparse.dict as crdict
5454
import cr.sparse.data as crdata
55+
from cr.nimble.dsp import (
56+
nonzero_indices,
57+
nonzero_values,
58+
largest_indices
59+
)
5560

5661
# %%
5762
# Problem Setup
@@ -147,7 +152,7 @@
147152

148153
# %%
149154
# Pick the indices of 3K atoms with largest matches with the residual
150-
I_sub = crs.largest_indices(h, K3)
155+
I_sub = largest_indices(h, K3)
151156
# Update the flags array
152157
flags = flags.at[I_sub].set(True)
153158
# Sort the ``I_sub`` array with the help of flags array
@@ -165,7 +170,7 @@
165170
# Compute the least squares solution of ``y`` over this subdictionary
166171
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
167172
# Pick the indices of K largest entries in in ``x_sub``
168-
Ia = crs.largest_indices(x_sub, K)
173+
Ia = largest_indices(x_sub, K)
169174
print(f"{Ia=}")
170175
# %%
171176
# We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``
@@ -215,7 +220,7 @@
215220
h = Phi.T @ r
216221
# %%
217222
# Pick the indices of 2K atoms with largest matches with the residual
218-
I_2k = crs.largest_indices(h, K2 if iterations else K3)
223+
I_2k = largest_indices(h, K2 if iterations else K3)
219224
# We can check if these include the atoms missed out in first iteration.
220225
print(jnp.intersect1d(omega, I_2k))
221226
# %%
@@ -237,7 +242,7 @@
237242
# Compute the least squares solution of ``y`` over this subdictionary
238243
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
239244
# Pick the indices of K largest entries in in ``x_sub``
240-
Ia = crs.largest_indices(x_sub, K)
245+
Ia = largest_indices(x_sub, K)
241246
print(Ia)
242247
# %%
243248
# We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``

examples/pursuit/cs1bit_biht.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
import cr.sparse.data as crdata
2323
import cr.sparse.cs.cs1bit as cs1bit
2424

25+
from cr.nimble.dsp import (
26+
build_signal_from_indices_and_values
27+
)
2528

2629
# %%
2730
# Setup
@@ -76,7 +79,7 @@
7679
sol = cs1bit.biht_jit(Phi, y, K, tau)
7780
# %%
7881
# reconstructed signal
79-
x_rec = crs.build_signal_from_indices_and_values(N, sol.I, sol.x_I)
82+
x_rec = build_signal_from_indices_and_values(N, sol.I, sol.x_I)
8083

8184
# %%
8285
# Verification

examples/rec_l1/spikes_l1ls.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
import cr.sparse.data as crdata
2828
import cr.sparse.lop as lop
2929
import cr.sparse.cvx.l1ls as l1ls
30+
from cr.nimble.dsp import (
31+
hard_threshold_by,
32+
support,
33+
largest_indices_by
34+
)
3035

3136
# %%
3237
# Setup
@@ -114,7 +119,7 @@
114119
# %%
115120
# Thresholding for large values
116121
# '''''''''''''''''''''''''''''''''''''
117-
x = crs.hard_threshold_by(sol.x, 0.5)
122+
x = hard_threshold_by(sol.x, 0.5)
118123
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
119124
plt.subplot(211)
120125
plt.plot(xs)
@@ -124,8 +129,8 @@
124129
# %%
125130
# Verifying the support recovery
126131
# '''''''''''''''''''''''''''''''''''''
127-
support_xs = crs.support(xs)
128-
support_x = crs.support(x)
132+
support_xs = support(xs)
133+
support_x = support(x)
129134
jnp.all(jnp.equal(support_xs, support_x))
130135

131136

@@ -134,7 +139,7 @@
134139
# ------------------------------------------------
135140

136141
# Identify the sub-matrix of columns for the support of recovered solution's large entries
137-
support_x = crs.largest_indices_by(sol.x, 0.5)
142+
support_x = largest_indices_by(sol.x, 0.5)
138143
AI = A.columns(support_x)
139144
print(AI.shape)
140145

examples/sparse_vector_normals.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import jax.numpy as jnp
1515
import cr.sparse as crs
1616
import cr.sparse.data as crdata
17+
from cr.nimble.dsp import (
18+
nonzero_indices,
19+
nonzero_values
20+
)
1721

1822
# %%
1923
# Let's define the size of model and number of sparse entries
@@ -31,11 +35,11 @@
3135

3236
# %%
3337
# We can easily find the locations of non-zero entries
34-
print(crs.nonzero_indices(x))
38+
print(nonzero_indices(x))
3539

3640
# %%
3741
# We can extract corresponding non-zero values in a compact vector
38-
print(crs.nonzero_values(x))
42+
print(nonzero_values(x))
3943

4044
# %%
4145
# Let's plot the vector to see where the non-zero entries are

0 commit comments

Comments
 (0)