Skip to content

Commit

Permalink
set_device
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhangliu committed Feb 8, 2025
1 parent e859f7d commit 094b0e1
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions xuance/common/common_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def get_arguments(method, env, env_id, config_path=None, parser_args=None, is_te
if parser_args is not None:
configs = [recursive_dict_update(config_i, parser_args.__dict__) for config_i in configs]
args = [SN(**config_i) for config_i in configs]
for arg in args:
arg.device = set_device(arg.dl_toolbox, arg.device)
if is_test: # for test mode
for i_args in range(len(args)):
args[i_args].test_mode = int(is_test)
Expand Down Expand Up @@ -129,6 +131,7 @@ def get_arguments(method, env, env_id, config_path=None, parser_args=None, is_te
if not ('env_id' in configs.keys()):
configs['env_id'] = env_id
args = SN(**configs)
args.device = set_device(args.dl_toolbox, args.device)
if is_test:
args.test_mode = int(is_test)
args.parallels = 1
Expand Down Expand Up @@ -330,6 +333,52 @@ def space2shape(observation_space):
return observation_space.shape


def set_device(dl_toolbox: str, expected_device: str):
"""
Set the computing device for a given deep learning framework.
Args:
dl_toolbox (str): The deep learning framework to use.
Options: "torch", "tensorflow", "mindspore".
expected_device (str): The desired computing device.
Options: "cuda", "GPU", "gpu", "Ascend", "cpu", "CPU.
Returns:
str: The assigned computing device, which may differ from `expected_device`
if the requested device is unavailable.
"""
device = expected_device
if dl_toolbox == "torch":
if "cuda" in expected_device:
import torch
if not torch.cuda.is_available():
print("WARNING: CUDA for PyTorch is not available, set the device as 'cpu'.")
device = "cpu"
return device
if dl_toolbox == 'tensorflow':
if expected_device == "GPU" or expected_device == "gpu":
import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')) == 0:
print("WARNING: GPU for Tensorflow2 is not available, set the device as 'cpu'.")
device = "CPU"
return device
if dl_toolbox == 'mindspore':
import mindspore.context as context
if expected_device == "GPU":
context.set_context(device_target="GPU")
device_num = context.get_auto_parallel_context("device_num")
if device_num == 0:
print("WARNING: GPU for MindSpore is not available, set the device as 'CPU'.")
device = "CPU"
elif expected_device == "Ascend":
context.set_context(device_target="Ascend")
device_num = context.get_auto_parallel_context("device_num")
if device_num == 0:
print("WARNING: Ascend for MindSpore is not available, set the device as 'CPU'.")
device = "CPU"
return device


def discount_cumsum(x, discount=0.99):
"""Get a discounted cumulated summation.
Args:
Expand Down

0 comments on commit 094b0e1

Please sign in to comment.