Skip to content

Commit

Permalink
bugfix for the type hints checking when args were specified in yaml f…
Browse files Browse the repository at this point in the history
…ile (#302)

* bugfix for the type hints checking when args specified in yaml file

* bugfix for the type hints checking when args specified in yaml file

* bugfix for the type hints checking when args specified in yaml file

* bugfix for the type hints checking when args specified in yaml file
  • Loading branch information
yxdyc authored May 6, 2024
1 parent 449475b commit a714f64
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
45 changes: 40 additions & 5 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import copy
import os
import shutil
import time
from argparse import ArgumentError
from argparse import ArgumentError, Namespace
from typing import Dict, List, Tuple, Union

from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace,
namespace_to_dict)
from jsonargparse.typehints import ActionTypeHint
from jsonargparse.typing import ClosedUnitInterval, NonNegativeInt, PositiveInt
from loguru import logger

Expand Down Expand Up @@ -370,8 +372,8 @@ def init_setup_from_cfg(cfg):
2. update cache directory
3. update checkpoint and `temp_dir` of tempfile
:param cfg: a original cfg
:param cfg: a updated cfg
:param cfg: an original cfg
:param cfg: an updated cfg
"""

cfg.export_path = os.path.abspath(cfg.export_path)
Expand Down Expand Up @@ -552,16 +554,16 @@ def update_op_process(cfg, parser):
# e.g.
# `python demo.py --config demo.yaml
# --language_id_score_filter.lang en`
temp_cfg = cfg
for i, op_in_process in enumerate(cfg.process):
op_in_process_name = list(op_in_process.keys())[0]

temp_cfg = cfg
if op_in_process_name not in option_in_commands:

# update op params to temp cfg if set
if op_in_process[op_in_process_name]:
temp_cfg = parser.merge_config(
dict_to_namespace(op_in_process), cfg)
dict_to_namespace(op_in_process), temp_cfg)
else:

# args in the command line override the ones in `cfg.process`
Expand All @@ -584,9 +586,42 @@ def update_op_process(cfg, parser):
None if internal_op_para is None else
namespace_to_dict(internal_op_para)
}

# check the op params via type hint
temp_parser = copy.deepcopy(parser)
recognized_args = set([
action.dest for action in parser._actions
if hasattr(action, 'dest') and isinstance(action, ActionTypeHint)
])

temp_args = namespace_to_arg_list(temp_cfg,
includes=recognized_args,
excludes=['config'])
temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args
temp_parser.parse_args(temp_args)
return cfg


def namespace_to_arg_list(namespace, prefix='', includes=None, excludes=None):
arg_list = []

for key, value in vars(namespace).items():

if issubclass(type(value), Namespace):
nested_args = namespace_to_arg_list(value, f'{prefix}{key}.')
arg_list.extend(nested_args)
elif value is not None:
concat_key = f'{prefix}{key}'
if includes is not None and concat_key not in includes:
continue
if excludes is not None and concat_key in excludes:
continue
arg_list.append(f'--{concat_key}')
arg_list.append(f'{value}')

return arg_list


def config_backup(cfg):
cfg_path = cfg.config[0].absolute
work_dir = cfg.work_dir
Expand Down
14 changes: 11 additions & 3 deletions data_juicer/utils/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import os
import sys
from io import StringIO

from loguru import logger
from loguru._file_sink import FileSink
Expand Down Expand Up @@ -52,12 +53,14 @@ def __init__(self, level='INFO', caller_names=('datasets', 'logging')):
Default value: (apex, pycocotools).
"""
self.level = level
self.linebuf = ''
self.caller_names = caller_names
self.buffer = StringIO()
self.BUFFER_SIZE = 1024 * 1024

def write(self, buf):
full_name = get_caller_name(depth=1)
module_name = full_name.rsplit('.', maxsplit=-1)[0]
self.buffer.write(buf)
if module_name in self.caller_names:
for line in buf.rstrip().splitlines():
# use caller level log
Expand All @@ -66,8 +69,13 @@ def write(self, buf):
# sys.__stdout__.write(buf)
logger.opt(raw=True).info(buf)

self.buffer.truncate(self.BUFFER_SIZE)

def getvalue(self):
return self.buffer.getvalue()

def flush(self):
pass
self.buffer.flush()


def redirect_sys_output(log_level='INFO'):
Expand All @@ -76,7 +84,7 @@ def redirect_sys_output(log_level='INFO'):
:param log_level: log level string of loguru. Default value: "INFO".
"""
redirect_logger = StreamToLoguru(log_level)
redirect_logger = StreamToLoguru(level=log_level)
sys.stderr = redirect_logger
sys.stdout = redirect_logger

Expand Down
19 changes: 19 additions & 0 deletions tests/config/demo_4_test_bad_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Process config example for Arxiv dataset

# global parameters
project_name: 'test_demo'
dataset_path: './demos/data/demo-dataset.jsonl' # path to your dataset directory or file
np: 4 # number of subprocess to process your dataset

export_path: './outputs/demo/demo-processed.parquet'

# process schedule
# a list of several process operators with their arguments
process:
- whitespace_normalization_mapper:
- language_id_score_filter:
lang: 'zh'
min_score: 1.1 # !! a bad value !!
- document_deduplicator: # deduplicate text samples using md5 hashing exact matching method
lowercase: false # whether to convert text to lower case
ignore_non_character: false
33 changes: 32 additions & 1 deletion tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import unittest
from contextlib import redirect_stdout
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

from jsonargparse import Namespace
Expand All @@ -12,6 +12,9 @@
test_yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'demo_4_test.yaml')

test_bad_yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'demo_4_test_bad_val.yaml')


class ConfigTest(DataJuicerTestCaseBase):

Expand Down Expand Up @@ -70,6 +73,34 @@ def test_yaml_cfg_file(self):
_, op_from_cfg = load_ops(cfg.process)
self.assertTrue(len(op_from_cfg) == 3)

def test_val_range_check_cmd(self):
out = StringIO()
err_msg_head = ("language_id_score_filter.min_score")
err_msg = ("Not of type ClosedUnitInterval: 1.1 does not conform to "
"restriction v>=0 and v<=1")
with redirect_stdout(out), redirect_stderr(out):
with self.assertRaises(SystemExit) as cm:
init_configs(
args=f'--config {test_yaml_path} '
'--language_id_score_filter.min_score 1.1'.split())
self.assertEqual(cm.exception.code, 2)
out_str = out.getvalue()
self.assertIn(err_msg_head, out_str)
self.assertIn(err_msg, out_str)

def test_val_range_check_yaml(self):
out = StringIO()
err_msg_head = ("language_id_score_filter.min_score")
err_msg = ("Not of type ClosedUnitInterval: 1.1 does not conform to "
"restriction v>=0 and v<=1")
with redirect_stdout(out), redirect_stderr(out):
with self.assertRaises(SystemExit) as cm:
init_configs(args=f'--config {test_bad_yaml_path}'.split())
self.assertEqual(cm.exception.code, 2)
out_str = out.getvalue()
self.assertIn(err_msg_head, out_str)
self.assertIn(err_msg, out_str)

def test_mixture_cfg(self):
out = StringIO()
with redirect_stdout(out):
Expand Down

0 comments on commit a714f64

Please sign in to comment.