@@ -2040,6 +2040,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2040
2040
std::cerr << " Done!" << std::endl;
2041
2041
}
2042
2042
2043
+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2044
+
2043
2045
static vk_device ggml_vk_get_device (size_t idx) {
2044
2046
VK_LOG_DEBUG (" ggml_vk_get_device(" << idx << " )" );
2045
2047
@@ -2175,9 +2177,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2175
2177
2176
2178
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2177
2179
2178
- if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2179
- // Intel drivers don't support coopmat properly yet
2180
- // Only RADV supports coopmat properly on AMD
2180
+ if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
2181
2181
device->coopmat_support = false ;
2182
2182
}
2183
2183
@@ -2515,7 +2515,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
2515
2515
return vk_instance.devices [idx];
2516
2516
}
2517
2517
2518
-
2519
2518
static void ggml_vk_print_gpu_info (size_t idx) {
2520
2519
GGML_ASSERT (idx < vk_instance.device_indices .size ());
2521
2520
size_t dev_num = vk_instance.device_indices [idx];
@@ -2565,9 +2564,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2565
2564
}
2566
2565
}
2567
2566
2568
- if (props2.properties .vendorID == VK_VENDOR_ID_INTEL || (props2.properties .vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2569
- // Intel drivers don't support coopmat properly yet
2570
- // Only RADV supports coopmat properly on AMD
2567
+ if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
2571
2568
coopmat_support = false ;
2572
2569
}
2573
2570
@@ -8088,6 +8085,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
8088
8085
UNUSED (instance_extensions);
8089
8086
}
8090
8087
8088
+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8089
+ switch (props.vendorID ) {
8090
+ case VK_VENDOR_ID_INTEL:
8091
+ // Intel drivers don't support coopmat properly yet
8092
+ return false ;
8093
+ case VK_VENDOR_ID_AMD:
8094
+ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
8095
+ // Workaround for AMD proprietary driver reporting support on all GPUs
8096
+ const std::string name = props.deviceName ;
8097
+ return name.rfind (" AMD Radeon RX 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) RX 7" , 0 ) == 0 || // RDNA 3 consumer GPUs
8098
+ name.rfind (" AMD Radeon PRO W7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) PRO W7" , 0 ) == 0 || // RDNA 3 workstation GPUs
8099
+ name.rfind (" AMD Radeon 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) 7" , 0 ) == 0 ; // RDNA 3 APUs
8100
+ }
8101
+ return true ;
8102
+ default :
8103
+ return true ;
8104
+ }
8105
+ }
8106
+
8091
8107
// checks
8092
8108
8093
8109
#ifdef GGML_VULKAN_CHECK_RESULTS
0 commit comments