|
35 | 35 | _odla_computation* _odla_computation::instance_ = nullptr;
|
36 | 36 | std::mutex _odla_computation::comp_mutex_;
|
37 | 37 |
|
38 |
| -void compile_and_export_cache(std::string catch_file_name, |
39 |
| - std::string config_file_name) { |
40 |
| - std::fstream catch_fs(catch_file_name, |
41 |
| - std::ios_base::out | std::ifstream::binary); |
42 |
| - std::fstream config_fs; |
43 |
| - std::string config_string; |
44 |
| - if (config_file_name.size() > 0) { |
45 |
| - config_fs.open(config_file_name, std::ios_base::in | std::ifstream::binary); |
46 |
| - if (!config_fs.is_open()) { |
47 |
| - popart::logging::warn( |
48 |
| - "invalid config file name:[ {} ] will use default config", |
49 |
| - config_file_name); |
50 |
| - config_string = PopartConfig::instance()->get_default_config_string(); |
51 |
| - } |
52 |
| - std::ostringstream config_ss; |
53 |
| - config_ss << config_fs.rdbuf(); |
54 |
| - config_string = config_ss.str(); |
55 |
| - } else { |
56 |
| - config_string = PopartConfig::instance()->get_default_config_string(); |
57 |
| - } |
58 |
| - |
59 |
| - int config_size = config_string.size(); |
60 |
| - catch_fs.write((char*)&config_size, sizeof(config_string.size())); |
61 |
| - catch_fs.write(config_string.c_str(), config_string.size()); |
62 |
| - |
63 |
| - _odla_computation::instance()->session->compileAndExport(catch_fs.flush()); |
64 |
| - catch_fs.flush(); |
65 |
| - catch_fs.close(); |
66 |
| -} |
67 |
| - |
68 | 38 | void compute_loop(odla_computation comp) {
|
69 | 39 | // setup the stepio with allbacks
|
70 | 40 | popart::StepIOCallback stepio(input_callback, input_complete_callback,
|
@@ -119,7 +89,59 @@ void compute_loop(odla_computation comp) {
|
119 | 89 | comp->thread_done();
|
120 | 90 | }
|
121 | 91 |
|
122 |
| -void _odla_computation::init() { |
| 92 | +odla_status _odla_computation::compile_and_export() { |
| 93 | + popart::logging::warn("Start compile and export"); |
| 94 | + const std::string& cache_file_name = |
| 95 | + PopartConfig::instance()->get_cache_path(); |
| 96 | + std::string file_suffix(".popart"); |
| 97 | + int file_prefix = cache_file_name.rfind(file_suffix); |
| 98 | + if (file_prefix == std::string::npos || |
| 99 | + file_prefix + file_suffix.size() < cache_file_name.size()) { |
| 100 | + popart::logging::err("Bad cache file name"); |
| 101 | + return ODLA_FAILURE; |
| 102 | + } |
| 103 | + if (file_prefix == std::string::npos) { |
| 104 | + file_prefix = cache_file_name.size() - 1; |
| 105 | + } |
| 106 | + std::string config_file_name(cache_file_name.substr(0, file_prefix) + |
| 107 | + ".json"); |
| 108 | + std::fstream cache_fs(cache_file_name, |
| 109 | + std::ios_base::out | std::ifstream::binary); |
| 110 | + if (!cache_fs.is_open()) { |
| 111 | + popart::logging::err("Open or create cache file falied"); |
| 112 | + return ODLA_FAILURE; |
| 113 | + } |
| 114 | + std::fstream config_fs; |
| 115 | + std::string config_string; |
| 116 | + if (config_file_name.size() > 0) { |
| 117 | + config_fs.open(config_file_name, std::ios_base::in | std::ifstream::binary); |
| 118 | + if (!config_fs.is_open()) { |
| 119 | + popart::logging::warn( |
| 120 | + "invalid config file name:[ {} ] will use default config", |
| 121 | + config_file_name); |
| 122 | + PopartConfig::instance()->use_default(); |
| 123 | + config_string = PopartConfig::instance()->get_default_config_string(); |
| 124 | + } else { |
| 125 | + std::ostringstream config_ss; |
| 126 | + config_ss << config_fs.rdbuf(); |
| 127 | + config_string = config_ss.str(); |
| 128 | + } |
| 129 | + } else { |
| 130 | + config_string = PopartConfig::instance()->get_default_config_string(); |
| 131 | + } |
| 132 | + |
| 133 | + int config_size = config_string.size(); |
| 134 | + cache_fs.write((char*)&config_size, sizeof(config_size)); |
| 135 | + cache_fs.write(config_string.c_str(), config_string.size()); |
| 136 | + |
| 137 | + _odla_computation::instance()->session->compileAndExport(cache_fs.flush()); |
| 138 | + |
| 139 | + cache_fs.flush(); |
| 140 | + cache_fs.close(); |
| 141 | + config_fs.close(); |
| 142 | +} |
| 143 | + |
| 144 | +void _odla_computation::init(bool is_compile) { |
123 | 145 | if (!session) {
|
124 | 146 | std::lock_guard<std::mutex> guard(init_mutex_);
|
125 | 147 | if (!session) {
|
@@ -167,21 +189,31 @@ void _odla_computation::init() {
|
167 | 189 | auto new_session = popart::InferenceSession::createFromOnnxModel(
|
168 | 190 | proto, data_flow, device, popart::InputShapeInfo(), session_opts_);
|
169 | 191 |
|
170 |
| - if (PopartConfig::instance()->load_cache()) { |
171 |
| - popart::logging::info("Load cachefile from existing stream"); |
172 |
| - auto cache_fs = PopartConfig::instance()->get_cache_fs(); |
173 |
| - new_session->loadExecutableFromStream(*(cache_fs.get())); |
174 |
| - } |
175 |
| - new_session->prepareDevice(); |
176 |
| - new_session->setRandomSeed(0); // Init seed |
177 |
| - new_session->weightsFromHost(); // Copy weights from host to IPU |
178 |
| - // If in parallel mode, start the thread |
179 |
| - ExecutionMode mode = PopartConfig::instance()->execution_mode(); |
180 |
| - if (PIPELINE == mode || PARALLEL == mode) { |
181 |
| - std::thread parallel_thread(compute_loop, this); |
182 |
| - thread_state_ = RUNNING; |
183 |
| - popart::logging::warn("Parallel loop has been started"); |
184 |
| - parallel_thread.detach(); |
| 192 | + if (!is_compile) { |
| 193 | + if (PopartConfig::instance()->load_cache()) { |
| 194 | + popart::logging::info("Load cachefile from existing stream"); |
| 195 | + auto cache_fs = PopartConfig::instance()->get_cache_fs(); |
| 196 | + if (cache_fs->is_open()) { |
| 197 | + try { |
| 198 | + new_session->loadExecutableFromStream(*(cache_fs.get())); |
| 199 | + } catch (std::exception& e) { |
| 200 | + popart::logging::err("bad cache file, will compile the graph"); |
| 201 | + } |
| 202 | + } |
| 203 | + } |
| 204 | + |
| 205 | + new_session->prepareDevice(); |
| 206 | + new_session->setRandomSeed(0); // Init seed |
| 207 | + new_session->weightsFromHost(); // Copy weights from host to IPU |
| 208 | + |
| 209 | + // If in parallel mode, start the thread |
| 210 | + ExecutionMode mode = PopartConfig::instance()->execution_mode(); |
| 211 | + if (PIPELINE == mode || PARALLEL == mode) { |
| 212 | + std::thread parallel_thread(compute_loop, this); |
| 213 | + thread_state_ = RUNNING; |
| 214 | + popart::logging::warn("Parallel loop has been started"); |
| 215 | + parallel_thread.detach(); |
| 216 | + } |
185 | 217 | }
|
186 | 218 | session =
|
187 | 219 | std::move(new_session); // set session after all initialization done.
|
|
0 commit comments