Skip to content

Commit

Permalink
fixes #217. Also adds a repeat of progress to address #215 (comment)
Browse files Browse the repository at this point in the history
  • Loading branch information
o-smirnov committed Feb 2, 2024
1 parent 3b584a7 commit e5e8557
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
3 changes: 2 additions & 1 deletion stimela/backends/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def get_executable(self):
def run_command_wrapper(self, args: List[str], fqname: Optional[str]=None, log: Optional[logging.Logger]=None) -> List[str]:
output_args = [self.get_executable()]

# reverse fqname to make job name (more informative that way)
if fqname is not None:
output_args += ["-J", fqname]
output_args += ["-J", '.'.join(fqname.split('.')[::-1])]

# add all base options that have been specified
for name, value in self.srun_opts.items():
Expand Down
30 changes: 21 additions & 9 deletions stimela/kitchen/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def finalize(self, config=None, log=None, name=None, fqname=None, backend=None,
self.logopts = config.opts.log.copy()

# update file logger
logsubst = SubstitutionNS(config=config, info=dict(fqname=fqname))
logsubst = SubstitutionNS(config=config, info=dict(fqname=fqname, taskname=fqname))
stimelogging.update_file_logger(log, self.logopts, nesting=nesting, subst=logsubst, location=[self.fqname])

# call Cargo's finalize method
Expand Down Expand Up @@ -744,7 +744,7 @@ def prevalidate(self, params: Dict[str, Any], subst: Optional[SubstitutionNS]=No
subst_outer = subst # outer dictionary is used to prevalidate our parameters

subst = SubstitutionNS()
info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='')
info = SubstitutionNS(fqname=self.fqname, taskname=self.fqname, label='', label_parts=[], suffix='')
# mutable=False means these sub-namespaces are not subject to {}-substitutions
subst._add_('info', info.copy(), nosubst=True)
subst._add_('config', self.config, nosubst=True)
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def _update_aliases(self, name: str, value: Any):
alias.step.update_parameter(alias.param, value)


def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, subprocess=False, raise_exc=True):
def _iterate_loop_worker(self, params, subst, backend, count, iter_var, subprocess=False, raise_exc=True):
""""
Needed for concurrency
"""
Expand All @@ -1046,6 +1046,7 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
task_stats.add_subprocess_id(count)
task_stats.destroy_progress_bar()
subst.info.subprocess = task_stats.get_subprocess_id()
taskname = subst.info.taskname
outputs = {}
exception = tb = None
task_attrs, task_kwattrs = (), {}
Expand Down Expand Up @@ -1075,6 +1076,8 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
if status is None:
status = "{index1}/{total}".format(**status_dict)
task_stats.declare_subtask_status(status)
taskname = f"{taskname}.{count}"
subst.info.taskname = taskname
# task_stats.declare_subtask_attributes(count)
# task_attrs = (count,)
context = task_stats.declare_subtask(f"({count})")
Expand All @@ -1085,16 +1088,20 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
for label, step in self.steps.items():
# update step info
self._prep_step(label, step, subst)
subst.info.taskname = f"{taskname}.{label}"
# reevaluate recipe level assignments (info.fqname etc. have changed)
self.update_assignments(subst, params=params)
# evaluate step-level assignments
self.update_assignments(subst, whose=step, params=params)
# step logger may have changed
stimelogging.update_file_logger(step.log, step.logopts, nesting=step.nesting, subst=subst, location=[step.fqname])
# set our info back temporarily to update log assignments
info_step = subst.info
subst.info = info.copy()
subst.info = info_step

## OMS: note to self, I had this here but not sure why. Seems like a no-op. Something with logname fiddling.
## Leave as a puzzle to future self for a bit. Remove info from args.
# info_step = subst.info
# subst.info = info.copy()
# subst.info = info_step

if step.skip is True:
self.log.debug(f"step '{label}' will be explicitly skipped")
Expand Down Expand Up @@ -1137,7 +1144,7 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
# else will be returned
exception = exc
tb = FormattedTraceback(sys.exc_info()[2])

return task_attrs, task_kwattrs, task_stats.collect_stats(), outputs, exception, tb

def build(self, backend={}, rebuild=False, build_skips=False, log: Optional[logging.Logger] = None):
Expand Down Expand Up @@ -1172,8 +1179,11 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
subst_outer = subst
if subst is None:
subst = SubstitutionNS()
taskname = self.name
else:
taskname = subst.info.taskname

info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='')
info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='', taskname=taskname)
# nosubst=True means these sub-namespaces are not subject to {}-substitutions
subst._add_('info', info.copy(), nosubst=True)
subst._add_('config', self.config, nosubst=True)
Expand Down Expand Up @@ -1224,7 +1234,7 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
# form list of arguments for each invocation of the loop worker
loop_worker_args = []
for count, iter_var in enumerate(self._for_loop_values):
loop_worker_args.append((params, info, subst, backend, count, iter_var))
loop_worker_args.append((params, subst, backend, count, iter_var))

# if scatter is enabled, use a process pool
if self._for_loop_scatter:
Expand Down Expand Up @@ -1264,6 +1274,8 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
if errors:
pool.shutdown()
raise StimelaRuntimeError(f"{nfail}/{nloop} jobs have failed", errors)
# drop a rendering of the progress bar onto the console, to overwrite previous garbage if it's there
task_stats.restate_progress()
# else just iterate directly
else:
for args in loop_worker_args:
Expand Down
2 changes: 1 addition & 1 deletion stimela/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def cli(config_files=[], config_dotlist=[], include=[], backend=None,
stimela.CONFIG.opts.log.level = "DEBUG"
# setup file logging
subst = OmegaConf.create(dict(
info=OmegaConf.create(dict(fqname='stimela')),
info=OmegaConf.create(dict(fqname='stimela', taskname='stimela')),
config=stimela.CONFIG))
stimelogging.update_file_logger(log, stimela.CONFIG.opts.log, nesting=-1, subst=subst)

Expand Down
5 changes: 5 additions & 0 deletions stimela/task_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def destroy_progress_bar():
progress_bar.__exit__(None, None, None)
progress_bar = None

def restate_progress():
"""Renders a snapshot of the progress bar onto the console"""
if progress_bar is not None:
progress_console.print(progress_bar.get_renderable())

@contextlib.contextmanager
def declare_subtask(subtask_name, status_reporter=None, hide_local_metrics=False):
task_names = []
Expand Down

0 comments on commit e5e8557

Please sign in to comment.