Skip to content

Commit a015628

Browse files
committed
fix: shaders.
1 parent e2368dc commit a015628

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/nn/vulkan/cast-forward-f32-f32.comp

-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@ layout(binding = 0) readonly buffer inputBuffer { float x[]; };
1515
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1616
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1717

18-
shared uint sharedPosition;
1918
shared BatchInfo sharedInfo;
2019

2120
void main() {
2221
const uint threadIndex = gl_LocalInvocationID.x;
2322
const uint batchIndex = gl_GlobalInvocationID.y;
2423

2524
if (threadIndex == 0) {
26-
sharedPosition = uint(x[batchIndex]);
2725
sharedInfo = infos[batchIndex];
2826
}
2927

src/nn/vulkan/mul-forward-f32-f32.comp

+5-4
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ void main() {
3131
memoryBarrierShared();
3232
barrier();
3333

34-
const uint start = sharedInfo.inputOffset + threadIndex;
35-
const uint end = sharedInfo.inputOffset + sharedInfo.inputSizeX;
34+
const uint inputSizeX = sharedInfo.inputSizeX;
35+
const uint xyOffset = sharedInfo.inputOffset;
36+
const uint mOffset = inputSizeX * batchIndex;
3637

37-
for (uint i = threadIndex; i < end; i += N_THREADS) {
38-
y[i] = x[i] * m[i];
38+
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
39+
y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i];
3940
}
4041
}

0 commit comments

Comments
 (0)