Commit 70f07c9

extend mergekit to work on xpu
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
1 parent 378c355 commit 70f07c9

11 files changed (+60 -33 lines)

mergekit/evo/actors.py

Lines changed: 7 additions & 7 deletions
@@ -90,7 +90,7 @@ def evaluate_genotype(
         genotype: torch.Tensor,
     ) -> dict:
         gc.collect()
-        torch.cuda.empty_cache()
+        empty_cache()
         LOG.info("Merging model")
         merged_path = merge_model(
             genotype, self.genome, self.model_storage_path, self.merge_options
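Note: `empty_cache()` above replaces the CUDA-only `torch.cuda.empty_cache()`; its definition is imported from elsewhere in the codebase and is not shown in these hunks. A minimal sketch of what such a device-agnostic helper might look like (names assumed, not taken from this diff):

```python
import torch

def empty_cache() -> None:
    # Release cached allocator memory on whichever accelerator is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()  # Intel GPU (XPU) backend
    # On CPU-only machines there is nothing to release.
```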
@@ -190,7 +190,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                 **model_kwargs,
             )
             .bfloat16()
-            .cuda()
+            .to(self.merge_options.device)
             .eval()
             .requires_grad_(False)
         )
@@ -227,15 +227,15 @@ def _maybe_init_model(self, config: MergeConfiguration):
             LOG.warning(f"Clipping sequence length to {max_model_len}")
 
             mem_util = (
-                0.7 if self.merge_options.cuda else 0.9
-            )  # reduce memory usage if we're also using cuda for the merge
+                0.7 if self.merge_options.device in ["cuda", "xpu"] else 0.9
+            )  # reduce memory usage if we're also using the accelerator for the merge
             self.model = lm_eval.models.vllm_causallms.VLLM(
                 pretrained=tempdir,
                 batch_size=self.batch_size or "auto",
                 max_model_len=max_model_len,
                 gpu_memory_utilization=mem_util,
                 dtype="bfloat16",
-                device="cuda",
+                device=self.merge_options.device,
                 trust_remote_code=self.merge_options.trust_remote_code,
             )
         else:
@@ -294,8 +294,8 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
 
         executor = Executor(
             tasks,
-            math_device="cuda" if self.merge_options.cuda else "cpu",
-            storage_device="cuda" if self.merge_options.cuda else "cpu",
+            math_device=self.merge_options.device if self.merge_options.device in ["cuda", "xpu"] else "cpu",
+            storage_device=self.merge_options.device if self.merge_options.device in ["cuda", "xpu"] else "cpu",
         )
         for tensor_task, value in executor.run(quiet=True):
             assert isinstance(tensor_task, ReturnTensor)

mergekit/evo/strategy.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def __init__(
         self.config = config
         self.genome = genome
         self.merge_options = merge_options
-        self.num_gpus = num_gpus or torch.cuda.device_count()
+        self.num_gpus = num_gpus or getattr(torch, self.merge_options.device).device_count()
         self.batch_size = batch_size
         self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path)
         self.model_storage_path = model_storage_path
@@ -118,7 +118,7 @@ def __init__(
         self.genome = genome
         self.merge_options = merge_options
         self.vllm = vllm
-        self.num_gpus = num_gpus or torch.cuda.device_count()
+        self.num_gpus = num_gpus or getattr(torch, self.merge_options.device).device_count()
         self.input_queue = []
         self.batch_size = batch_size
         self.task_manager = task_manager
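Note: the `getattr(torch, ...)` pattern above resolves the backend module matching the configured device string. A small illustration (assuming `merge_options.device` is a bare backend name such as "cuda" or "xpu"; a string with an index suffix like "cuda:0" would not resolve to a module):

```python
import torch

device = "cuda"                   # e.g. MergeOptions.device after auto-detection
backend = getattr(torch, device)  # resolves to the torch.cuda (or torch.xpu) module
print(backend.device_count())     # number of devices available to that backend
```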

mergekit/graph.py

Lines changed: 1 addition & 1 deletion
@@ -529,7 +529,7 @@ def _move_tensors(
         self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
     ) -> Any:
         if non_blocking is None:
-            non_blocking = device.type == "cuda"
+            non_blocking = device.type in ["cuda", "xpu"]
         if isinstance(value, torch.Tensor):
             if value.device == device:
                 return value
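Note: asynchronous host-to-device copies only pay off on accelerator backends, which is why the `non_blocking` default now covers XPU as well as CUDA. An illustrative sketch of the resulting behavior (not code from this diff):

```python
import torch

# Pick whichever accelerator is present, falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

x = torch.randn(4, 4)
# Mirrors the _move_tensors default: async copies only for cuda/xpu.
y = x.to(device, non_blocking=device.type in ["cuda", "xpu"])
```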

mergekit/merge.py

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ def run_merge(
     else:
         exec = Executor(
             targets=targets,
-            math_device="cuda" if options.cuda else "cpu",
-            storage_device="cuda" if options.low_cpu_memory else "cpu",
+            math_device=options.device,
+            storage_device=options.device if options.low_cpu_memory else "cpu",
         )
 
     tokenizer = None
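Note: the policy here is that arithmetic always runs on the configured device, while intermediate results stay on CPU unless `low_cpu_memory` is set. A tiny illustration of that mapping (`pick_devices` is a hypothetical helper, not part of mergekit):

```python
def pick_devices(device: str, low_cpu_memory: bool) -> tuple:
    # Return (math_device, storage_device) following run_merge's wiring.
    return device, device if low_cpu_memory else "cpu"

assert pick_devices("xpu", False) == ("xpu", "cpu")
assert pick_devices("xpu", True) == ("xpu", "xpu")  # keep results in device memory
```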

mergekit/multigpu_executor.py

Lines changed: 11 additions & 8 deletions
@@ -75,9 +75,11 @@ def __init__(
         self.results: Dict[TaskHandle, Any] = {}
         self.storage_device = storage_device
 
+        self.accelerator_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+        torch_accelerator_module = getattr(torch, self.accelerator_type)
         if num_gpus is None:
-            num_gpus = torch.cuda.device_count()
-        LOG.info(f"Using {num_gpus} GPUs for parallel execution")
+            num_gpus = torch_accelerator_module.device_count()
+        LOG.info(f"Using {num_gpus} {self.accelerator_type} devices for parallel execution")
 
         self.universe = TaskUniverse(targets)
         self.targets = set([self.universe.get_handle(t) for t in targets])
@@ -309,12 +311,12 @@ def _assign_islands_to_gpus(
                 continue
             # don't need to sort, inner executor will handle
             island_tasks = [TaskHandle(self.universe, idx) for idx in island]
-            # assign to GPU with fewest tasks (load balancing)
+            # assign to accelerator with fewest tasks (load balancing)
             device_idx = min(
                 range(num_gpus),
-                key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])),
+                key=lambda i: len(assignments.get(torch.device(f"{self.accelerator_type}:{i}"), [])),
             )
-            device = torch.device(f"cuda:{device_idx}")
+            device = torch.device(f"{self.accelerator_type}:{device_idx}")
             assignments[device] = assignments.get(device, []) + island_tasks
         return assignments

@@ -339,9 +341,10 @@ def _device_worker(
             quiet: Whether to suppress progress bar output
         """
         LOG.debug(f"Device {device} starting")
+        torch_accelerator_module = getattr(torch, self.accelerator_type)
         with torch.device(device):
-            stream = torch.cuda.Stream(device=device)
-            with torch.cuda.stream(stream):
+            stream = torch.Stream(device=device) if self.accelerator_type == "xpu" else torch.cuda.Stream(device=device)
+            with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
                 exec = Executor(
                     targets=task_list,
                     math_device=device,
@@ -358,5 +361,5 @@ def _device_worker(
                 ):
                     result = None
                 self.task_completion_queue.put((task_handle._index, result))
-            torch.cuda.synchronize(device=device)
+            torch_accelerator_module.synchronize(device=device)
         LOG.debug(f"Device {device} done")

mergekit/options.py

Lines changed: 30 additions & 5 deletions
@@ -21,6 +21,7 @@ class MergeOptions(BaseModel, frozen=True):
     lora_merge_cache: Optional[str] = None
     lora_merge_dtype: Optional[str] = None
     cuda: bool = False
+    device: Optional[str] = "auto"
     low_cpu_memory: bool = False
     out_shard_size: int = parse_kmb("5B")
     copy_tokenizer: bool = True
@@ -62,14 +63,33 @@ def handle_gpu_rich(cls, value):
             value["multi_gpu"] = True
         return value
 
+    @model_validator(mode="before")
+    def handle_device_setting(cls, value):
+        if not isinstance(value, dict):
+            return value
+
+        # The legacy `cuda` flag forces the device to "cuda"
+        if value.get("cuda"):
+            value["device"] = "cuda"
+
+        # Detect the device automatically when `device` is unset or "auto"
+        if value.get("device") is None or value.get("device") == "auto":
+            if torch.cuda.is_available():
+                value["device"] = "cuda"
+            elif hasattr(torch, "xpu") and torch.xpu.is_available():
+                value["device"] = "xpu"
+            else:
+                value["device"] = "cpu"
+        return value
 
 OPTION_HELP = {
     "allow_crimes": "Allow mixing architectures",
     "transformers_cache": "Override storage path for downloaded models",
     "lora_merge_cache": "Path to store merged LORA models",
     "lora_merge_dtype": "Override dtype when applying LoRAs",
     "cuda": "Perform matrix arithmetic on GPU",
-    "low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
+    "device": "Perform matrix arithmetic on the specified device",
+    "low_cpu_memory": "Store results and intermediate values on the accelerator. Useful if VRAM > RAM",
     "out_shard_size": "Number of parameters per output shard [default: 5B]",
     "copy_tokenizer": "Copy a tokenizer to the output",
     "clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer",
@@ -79,11 +99,11 @@ def handle_gpu_rich(cls, value):
     "write_model_card": "Output README.md containing details of the merge",
     "safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
     "quiet": "Suppress progress bars and other non-essential output",
-    "read_to_gpu": "Read model weights directly to GPU",
+    "read_to_gpu": "Read model weights directly to the accelerator",
     "multi_gpu": "Use multi-gpu parallel graph execution engine",
     "num_threads": "Number of threads to use for parallel CPU operations",
     "verbosity": "Verbose logging (repeat for more verbosity)",
-    "gpu_rich": "Alias for --cuda --low-cpu-memory --read-to-gpu --multi-gpu",
+    "gpu_rich": "Alias for --accelerator --cuda --low-cpu-memory --read-to-gpu --multi-gpu",
 }
 
 OPTION_CATEGORIES = {
@@ -96,6 +116,7 @@ def handle_gpu_rich(cls, value):
     "safe_serialization": "Output Settings",
     "lazy_unpickle": "Performance",
     "cuda": "Performance",
+    "device": "Performance",
     "low_cpu_memory": "Performance",
     "read_to_gpu": "Performance",
     "multi_gpu": "Performance",
@@ -127,8 +148,12 @@ def wrapper(*args, **kwargs):
         if field_name in kwargs:
             arg_dict[field_name] = kwargs.pop(field_name)
 
-        kwargs["merge_options"] = MergeOptions(**arg_dict)
-        f(*args, **kwargs)
+        try:
+            kwargs["merge_options"] = MergeOptions(**arg_dict)
+        except Exception:
+            print(f"Error creating MergeOptions with args: {arg_dict}")
+            raise
+        return f(*args, **kwargs)
 
     for field_name, info in reversed(MergeOptions.model_fields.items()):
         origin = typing.get_origin(info.annotation)

mergekit/plan.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def plan_tensor(
         gather_tensors = GatherTensors(
             weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
             dtype=self.config.dtype,
-            device="cuda" if self.options.read_to_gpu else None,
+            device=self.options.device if self.options.read_to_gpu else None,
         )
 
         tensor_input_task = gather_tensors

mergekit/scripts/extract_lora.py

Lines changed: 2 additions & 2 deletions
@@ -151,8 +151,8 @@ def main(
     else:
         executor = Executor(
             tasks,
-            math_device="cuda" if merge_options.cuda else "cpu",
-            storage_device="cuda" if merge_options.low_cpu_memory else "cpu",
+            math_device=merge_options.device,
+            storage_device=merge_options.device if merge_options.low_cpu_memory else "cpu",
         )
 
     module_real_ranks = {}

mergekit/scripts/merge_raw_pytorch.py

Lines changed: 2 additions & 3 deletions
@@ -248,9 +248,8 @@ def main(
 
     executor = Executor(
         tasks,
-        math_device="cuda" if merge_options.cuda else "cpu",
-        storage_device=(
-            "cuda" if (merge_options.cuda and merge_options.low_cpu_memory) else "cpu"
+        math_device=merge_options.device,
+        storage_device=(merge_options.device if merge_options.low_cpu_memory else "cpu"
         ),
     )
     executor.execute()

mergekit/scripts/multimerge.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def main(
 
     executor = Executor(
         tasks, math_device="cpu", storage_device="cpu"
-    )  # inner executors will handle cuda
+    )  # inner executors will handle the accelerator
     executor.execute(desc="Merging models")

mergekit/scripts/tokensurgeon.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def main(
     cache = LoaderCache()
     cache.setup(options=merge_options)
 
-    device = "cuda" if merge_options.cuda else "cpu"
+    device = merge_options.device
 
     arch_info, donor_cfg = validate_architecture(model, donor, merge_options)
     embed_info, lm_head_info = get_embedding_info(model, merge_options)
