@@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
218
218
fname += std::to_string (ncall);
219
219
}
220
220
221
+ // avoid writing imatrix entries that do not have full data
222
+ // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
223
+
224
+ int n_entries = 0 ;
225
+ std::vector<std::string> to_store;
226
+
227
+ bool is_first = true ; // for printing
228
+ for (const auto & kv : m_stats) {
229
+ const int n_all = kv.second .counts .size ();
230
+
231
+ if (n_all == 0 ) {
232
+ continue ;
233
+ }
234
+
235
+ int n_zeros = 0 ;
236
+ for (const int c : kv.second .counts ) {
237
+ if (c == 0 ) {
238
+ n_zeros++;
239
+ }
240
+ }
241
+
242
+ if (n_zeros != 0 && is_first) {
243
+ fprintf (stderr, " \n " );
244
+ is_first = false ;
245
+ }
246
+
247
+ if (n_zeros == n_all) {
248
+ fprintf (stderr, " %s: entry '%40s' has no data - skipping\n " , __func__, kv.first .c_str ());
249
+ continue ;
250
+ }
251
+
252
+ if (n_zeros > 0 ) {
253
+ fprintf (stderr, " %s: entry '%40s' has partial data (%.2f%%) - skipping\n " , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
254
+ continue ;
255
+ }
256
+
257
+ n_entries++;
258
+ to_store.push_back (kv.first );
259
+ }
260
+
261
+ if (to_store.size () < m_stats.size ()) {
262
+ fprintf (stderr, " %s: warning: storing only %zu out of %zu entries\n " , __func__, to_store.size (), m_stats.size ());
263
+ }
264
+
221
265
std::ofstream out (fname, std::ios::binary);
222
- int n_entries = m_stats.size ();
223
266
out.write ((const char *) &n_entries, sizeof (n_entries));
224
- for (const auto & p : m_stats) {
225
- int len = p.first .size ();
267
+ for (const auto & name : to_store) {
268
+ const auto & stat = m_stats.at (name);
269
+ int len = name.size ();
226
270
out.write ((const char *) &len, sizeof (len));
227
- out.write (p. first .c_str (), len);
228
- out.write ((const char *) &p. second . ncall , sizeof (p. second .ncall ));
229
- int nval = p. second .values .size ();
271
+ out.write (name .c_str (), len);
272
+ out.write ((const char *) &stat. ncall , sizeof (stat .ncall ));
273
+ int nval = stat .values .size ();
230
274
out.write ((const char *) &nval, sizeof (nval));
231
275
if (nval > 0 ) {
232
276
std::vector<float > tmp (nval);
233
277
for (int i = 0 ; i < nval; i++) {
234
- tmp[i] = (p. second . values [i] / static_cast <float >(p. second . counts [i])) * static_cast <float >(p. second .ncall );
278
+ tmp[i] = (stat. values [i] / static_cast <float >(stat. counts [i])) * static_cast <float >(stat .ncall );
235
279
}
236
280
out.write ((const char *)tmp.data (), nval*sizeof (float ));
237
281
}
0 commit comments