Skip to content

Commit 967a111

Browse files
committed
gpu index.
1 parent 6f5e458 commit 967a111

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/nn/nn-vulkan.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,10 @@ NnVulkanDevice::NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeCo
301301
context.instance = vk::createInstance(instanceCreateInfo);
302302

303303
auto physicalDevices = context.instance.enumeratePhysicalDevices();
304-
context.physicalDevice = physicalDevices.front();
304+
const NnSize nDevices = physicalDevices.size();
305+
if (gpuIndex >= nDevices)
306+
throw std::runtime_error("Invalid GPU index, found " + std::to_string(nDevices) + " GPUs");
307+
context.physicalDevice = physicalDevices[gpuIndex];
305308

306309
vk::PhysicalDeviceProperties deviceProps = context.physicalDevice.getProperties();
307310
printf("🌋 Device: %s\n", (char*)deviceProps.deviceName);

0 commit comments

Comments
 (0)