Skip to content

Commit 1454e1f

Browse files
authored
[dataset] add ms_logger_context (#4428)
1 parent 09255ce commit 1454e1f

File tree

4 files changed

+19
-18
lines changed

4 files changed

+19
-18
lines changed

swift/hub/hub.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import logging
23
import os
34
import tempfile
45
from contextlib import contextmanager
@@ -11,11 +12,11 @@
1112
from huggingface_hub.hf_api import api, future_compatible
1213
from requests.exceptions import HTTPError
1314
from transformers import trainer
14-
from transformers.utils import logging, strtobool
15+
from transformers.utils import strtobool
1516

16-
from swift.utils.env import use_hf_hub
17+
from swift.utils import get_logger, ms_logger_context, use_hf_hub
1718

18-
logger = logging.get_logger(__name__)
19+
logger = get_logger()
1920

2021

2122
class HubOperation:
@@ -287,15 +288,15 @@ def load_dataset(cls,
287288
cls.try_login(token)
288289
if revision is None or revision == 'main':
289290
revision = 'master'
290-
291-
return MsDataset.load(
292-
dataset_id,
293-
subset_name=subset_name,
294-
split=split,
295-
version=revision,
296-
download_mode=download_mode,
297-
use_streaming=streaming,
298-
)
291+
with ms_logger_context(logging.ERROR):
292+
return MsDataset.load(
293+
dataset_id,
294+
subset_name=subset_name,
295+
split=split,
296+
version=revision,
297+
download_mode=download_mode,
298+
use_streaming=streaming,
299+
)
299300

300301
@classmethod
301302
def download_model(cls,

swift/trainers/mixin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
# Part of the implementation is borrowed from huggingface/transformers.
33
import inspect
4+
import logging
45
import os
56
import shutil
67
import time
@@ -33,7 +34,7 @@
3334
from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
3435
from swift.plugin import MeanMetric, compute_acc, extra_tuners
3536
from swift.tuners import SwiftModel
36-
from swift.utils import get_logger, is_mp_ddp, seed_worker, use_torchacc
37+
from swift.utils import get_logger, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
3738
from swift.utils.torchacc_utils import ta_trim_graph
3839
from ..utils.torch_utils import get_device_count
3940
from .arguments import TrainingArguments
@@ -68,8 +69,7 @@ def __init__(self,
6869
logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.')
6970

7071
if args.check_model and hasattr(model, 'model_dir'):
71-
from swift.utils.logger import ms_logger_ignore_error
72-
with ms_logger_ignore_error():
72+
with ms_logger_context(logging.CRITICAL):
7373
check_local_model_is_latest(
7474
model.model_dir, user_agent={
7575
'invoked_by': 'local_trainer',

swift/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available,
88
is_xtuner_available)
99
from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
10-
from .logger import get_logger
10+
from .logger import get_logger, ms_logger_context
1111
from .np_utils import get_seed, stat_array, transform_jsonl_to_df
1212
from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
1313
from .torch_utils import (Serializer, activate_parameters, check_shared_disk, find_all_linears, find_embedding,

swift/utils/logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None,
111111

112112

113113
@contextmanager
114-
def ms_logger_ignore_error():
114+
def ms_logger_context(log_leval):
115115
ms_logger = get_ms_logger()
116116
origin_log_level = ms_logger.level
117-
ms_logger.setLevel(logging.CRITICAL)
117+
ms_logger.setLevel(log_leval)
118118
try:
119119
yield
120120
finally:

0 commit comments

Comments
 (0)