Skip to content

Commit e95beeb

Browse files
authored
imatrix : handle partial entries (ggml-org#7833)
1 parent 57bf62c commit e95beeb

File tree

1 file changed

+51
-7
lines changed

1 file changed

+51
-7
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
218218
fname += std::to_string(ncall);
219219
}
220220

221+
// avoid writing imatrix entries that do not have full data
222+
// this can happen with MoE models where some of the experts end up not being exercised by the provided training data
223+
224+
int n_entries = 0;
225+
std::vector<std::string> to_store;
226+
227+
bool is_first = true; // for printing
228+
for (const auto & kv : m_stats) {
229+
const int n_all = kv.second.counts.size();
230+
231+
if (n_all == 0) {
232+
continue;
233+
}
234+
235+
int n_zeros = 0;
236+
for (const int c : kv.second.counts) {
237+
if (c == 0) {
238+
n_zeros++;
239+
}
240+
}
241+
242+
if (n_zeros != 0 && is_first) {
243+
fprintf(stderr, "\n");
244+
is_first = false;
245+
}
246+
247+
if (n_zeros == n_all) {
248+
fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
249+
continue;
250+
}
251+
252+
if (n_zeros > 0) {
253+
fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
254+
continue;
255+
}
256+
257+
n_entries++;
258+
to_store.push_back(kv.first);
259+
}
260+
261+
if (to_store.size() < m_stats.size()) {
262+
fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
263+
}
264+
221265
std::ofstream out(fname, std::ios::binary);
222-
int n_entries = m_stats.size();
223266
out.write((const char *) &n_entries, sizeof(n_entries));
224-
for (const auto & p : m_stats) {
225-
int len = p.first.size();
267+
for (const auto & name : to_store) {
268+
const auto & stat = m_stats.at(name);
269+
int len = name.size();
226270
out.write((const char *) &len, sizeof(len));
227-
out.write(p.first.c_str(), len);
228-
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
229-
int nval = p.second.values.size();
271+
out.write(name.c_str(), len);
272+
out.write((const char *) &stat.ncall, sizeof(stat.ncall));
273+
int nval = stat.values.size();
230274
out.write((const char *) &nval, sizeof(nval));
231275
if (nval > 0) {
232276
std::vector<float> tmp(nval);
233277
for (int i = 0; i < nval; i++) {
234-
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
278+
tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
235279
}
236280
out.write((const char*)tmp.data(), nval*sizeof(float));
237281
}

0 commit comments

Comments
 (0)