Skip to content

Commit 3a893df

Browse files
committed
Avoid promition to Int32 in work_items functions
1 parent c0760cb commit 3a893df

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

lib/intrinsics/src/SPIRVIntrinsics.jl

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ import ExprTools
77

88
import SpecialFunctions
99

10+
# helper type for writing UInt32/Int32 literals
11+
# TODO: upstream this
12+
struct Literal{T} end
13+
Base.:(*)(x::Number, ::Type{Literal{T}}) where {T} = T(x)
14+
const i32 = Literal{Int32}
15+
const u32 = Literal{UInt32}
16+
1017
include("pointer.jl")
1118
include("utils.jl")
1219

lib/intrinsics/src/work_item.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ export get_work_dim,
1212

1313
@device_function get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int
1414

15-
@device_function get_global_size(dimindx::Integer=1) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx-1) % Int
16-
@device_function get_global_id(dimindx::Integer=1) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx-1) % Int + 1
15+
@device_function get_global_size(dimindx::Integer=1u32) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx-1u32) % Int
16+
@device_function get_global_id(dimindx::Integer=1u32) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx-1u32) % Int + 1
1717

18-
@device_function get_local_size(dimindx::Integer=1) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx-1) % Int
19-
@device_function get_enqueued_local_size(dimindx::Integer=1) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx-1) % Int
20-
@device_function get_local_id(dimindx::Integer=1) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx-1) % Int + 1
18+
@device_function get_local_size(dimindx::Integer=1u32) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx-1u32) % Int
19+
@device_function get_enqueued_local_size(dimindx::Integer=1u32) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx-1) % Int
20+
@device_function get_local_id(dimindx::Integer=1u32) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx-1u32) % Int + 1
2121

22-
@device_function get_num_groups(dimindx::Integer=1) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx-1) % Int
23-
@device_function get_group_id(dimindx::Integer=1) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx-1) % Int + 1
22+
@device_function get_num_groups(dimindx::Integer=1u32) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx-1u32) % Int
23+
@device_function get_group_id(dimindx::Integer=1u32) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx-1u32) % Int + 1
2424

25-
@device_function get_global_offset(dimindx::Integer=1) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx-1) % Int + 1
25+
@device_function get_global_offset(dimindx::Integer=1u32) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx-1u32) % Int + 1
2626

2727
@device_function get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1
2828
@device_function get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1

test/REQUIRE

Whitespace-only changes.

0 commit comments

Comments
 (0)