Skip to content

Commit 389c131

Browse files
yanwei-gryanweijackzipu
authored
keep load_or_save_cache when destroyComputation (#698)
* add potential exception handling & open session_option to optimize
* update odla_computation::init() to return odla_status value
* fix popart_config load error
* improve load config logic
* Check the sdk version of the cache
* Format the code to pass lint checking
* reset other necessary popart config values
* add return value in computation::init()
* clear cache_fs before writing new data
* clear fstream after close
* fix lint error
* fix error message issue
* keep load_or_save_cache value
* use popart::core::packageHash for version info
* update some log types to err & get sdk packagehash in use_default()
* use assign instead of = & remove useless logic
* close log

Co-authored-by: yanwei <yw01041751@alibaba-inc.com>
Co-authored-by: gcuser <jackz@graphcore.ai>
1 parent 9d51f07 commit 389c131

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

ODLA/platforms/odla_popart/odla_compute.cc

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <ODLA/odla.h>
2020
#include <dlfcn.h>
21+
#include <stdlib.h>
2122

2223
#include <cstdlib>
2324
#include <fstream>
@@ -167,6 +168,7 @@ odla_status odla_DestroyContext(odla_context ctx) {
167168
}
168169

169170
odla_status odla_DestroyComputation(odla_computation comp) {
171+
popart::logging::info("call odla_destroyComputation");
170172
if (comp != nullptr) {
171173
if (!comp->is_compile_only()) {
172174
comp->mark_done();

ODLA/platforms/odla_popart/odla_popart.cc

+9-7
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ void compute_loop(odla_computation comp) {
7979
popart::logging::err("Poplar unrecoverable_runtime_error exception caught");
8080
QManager::instance()->set_status(ODLA_UNRECOVERABLE_ERR);
8181
} catch (poplar::unknown_runtime_error& e) {
82-
popart::logging::info("Poplar unknown runtime exception caught");
82+
popart::logging::err("Poplar unknown runtime exception caught");
8383
QManager::instance()->set_status(ODLA_UNRECOVERABLE_ERR);
8484
} catch (...) {
85-
popart::logging::info("Poplar unknown exception caught");
85+
popart::logging::err("Poplar unknown exception caught");
8686
QManager::instance()->set_status(ODLA_UNRECOVERABLE_ERR);
8787
}
8888

89-
popart::logging::warn("The pipeline loop finished");
89+
popart::logging::info("The pipeline loop finished");
9090
comp->thread_done();
9191
}
9292

@@ -134,10 +134,12 @@ odla_status _odla_computation::compile_and_export() {
134134
config_string = PopartConfig::instance()->get_default_config_string();
135135
}
136136
// add sdk_version in the file content
137-
std::string version_string(popart::core::versionString());
137+
std::string version_string(popart::core::packageHash());
138138
popart::logging::info("the popart version is: {}", version_string);
139-
version_string = "\n\"sdk_version\":\"" + version_string + "\",";
140-
config_string.insert(1, version_string);
139+
if (config_string.find("sdk_version") == std::string::npos) {
140+
std::string item_string = "\n\"sdk_version\":\"" + version_string + "\",";
141+
config_string.insert(1, item_string);
142+
}
141143
popart::logging::info("the config_string with sdk_version is: {}",
142144
config_string);
143145
// added the sdk_version information to the file content
@@ -236,7 +238,7 @@ odla_status _odla_computation::init(bool is_compile) {
236238
if (!is_compile) {
237239
if (PopartConfig::instance()->load_or_save_cache()) {
238240
popart::logging::info("Load cachefile from existing stream");
239-
std::string version_string(popart::core::versionString());
241+
std::string version_string(popart::core::packageHash());
240242
if (!PopartConfig::instance()->sdk_version_match(version_string)) {
241243
popart::logging::err("The sdk version of cache does not match {}",
242244
version_string);

ODLA/platforms/odla_popart/popart_config.cc

+26-13
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
#include "json.hpp"
2828

2929
PopartConfig* PopartConfig::instance_ = new PopartConfig();
30+
std::vector<std::string> PopartConfig::mode = {"unknown", "pipeline",
31+
"parallel", "sequence"};
32+
33+
// Converts a boolean into the lowercase literal text "true" or "false".
// The returned pointer refers to a string literal, so it is never null and
// must not be freed or modified by the caller.
const char* bool_to_str(const bool& value) {
  if (value) {
    return "true";
  }
  return "false";
}
3034

3135
const std::string& get_config_path_from_cache_file(
3236
const std::string& cache_path) {
@@ -43,6 +47,7 @@ const std::string& get_config_path_from_cache_file(
4347

4448
void PopartConfig::use_default() {
4549
amp_ = 0.6;
50+
sdk_version_ = popart::core::packageHash();
4651
version_ = "1.0.0";
4752
batches_per_step_ = 1;
4853
ipu_num_ = 1;
@@ -54,19 +59,27 @@ void PopartConfig::use_default() {
5459
queue_type_ = "LockFreeQueue";
5560
queue_capacity_ = 1024 * 1024;
5661
debug_ = false;
57-
default_config_string_ =
62+
char config_base[] =
5863
"{\n\
59-
\"version\":\"1.0.0\",\n\
60-
\"amp\":0.6,\n\
61-
\"batches_per_step\":1,\n\
62-
\"execution_mode\":\"sequence\",\n\
63-
\"ipu_num\":1,\n\
64-
\"load_onnx\":false, \n\
65-
\"load_onnx_path\":\"test-load-time.onnx\",\n\
66-
\"queue_type\":\"LockFreeQueue\",\n\
67-
\"queue_capacity\":1048576,\n\
68-
\"debug\": false\n\
64+
\"sdk_version\":\"%s\",\n\
65+
\"version\":\"%s\",\n\
66+
\"amp\":%f,\n\
67+
\"batches_per_step\":%d,\n\
68+
\"execution_mode\":\"%s\",\n\
69+
\"ipu_num\":%d,\n\
70+
\"load_onnx\":%s, \n\
71+
\"load_onnx_path\":\"%s\",\n\
72+
\"queue_type\":\"%s\",\n\
73+
\"queue_capacity\":%d,\n\
74+
\"debug\":%s\n\
6975
}\n";
76+
char raw_default_config[1024] = {0};
77+
snprintf(raw_default_config, 1024, config_base, sdk_version_.c_str(),
78+
version_.c_str(), amp_, batches_per_step_,
79+
PopartConfig::mode[(int)execution_mode_].c_str(), ipu_num_,
80+
bool_to_str(load_onnx_), load_onnx_path_.c_str(),
81+
queue_type_.c_str(), queue_capacity_, bool_to_str(debug_));
82+
default_config_string_.assign(raw_default_config);
7083
}
7184

7285
odla_status PopartConfig::load_config(const char* env_file_path) {
@@ -196,12 +209,12 @@ odla_status PopartConfig::load_from_file(const std::string& file_path) {
196209
void PopartConfig::print() {
197210
std::string line(80, '=');
198211
popart::logging::info(line);
212+
popart::logging::info("sdk_version: {}", sdk_version_);
199213
popart::logging::info("version: {}", version_);
200214
popart::logging::info("amp: {}", amp_);
201215
popart::logging::info("batch_per_step: {}", batches_per_step_);
202-
std::string mode[] = {"UNKNOWN", "PIPELINE", "PARALLEL", "SEQUENCE"};
203216
popart::logging::info("execution_mode: {}",
204-
mode[(long unsigned int)execution_mode_]);
217+
PopartConfig::mode[(long unsigned int)execution_mode_]);
205218
popart::logging::info("ipu_num: {}", ipu_num_);
206219
std::string bool_value[] = {"false", "true"};
207220
popart::logging::info("load_onnx: {}",

ODLA/platforms/odla_popart/popart_config.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <iostream>
2424
#include <map>
2525
#include <mutex>
26+
#include <popart/version.hpp>
2627
#include <regex>
2728
#include <string>
2829
#include <vector>
@@ -61,6 +62,7 @@ class PopartConfig {
6162
std::string sdk_version_; // version of the sdk
6263
int batches_per_step_; // Batch per step for PIPELINE & PARALLEL execution
6364
// mode
65+
static std::vector<std::string> mode; // string value of execution mode
6466
ExecutionMode
6567
execution_mode_; // The execution mode {PIPELINE, PARALLEL, SEQUENCE}
6668
bool load_onnx_; // Whether load onnx model to run instead of the model
@@ -118,7 +120,6 @@ class PopartConfig {
118120
cache_fs->clear();
119121
}
120122
pipeline_setting_.clear();
121-
load_or_save_cache_ = false;
122123
sdk_version_ = "NA";
123124
}
124125
}

0 commit comments

Comments
 (0)