Skip to content

Commit

Permalink
refactor (#6)
Browse files Browse the repository at this point in the history
* Button component

* update layout

* Button: forward ref and generic props

* Canvas component

* responsiveness

* WebGPU support status

* Button: disabled style

* frame loop, FPS meter

* Stat components

* Renderer class

* vertices code alignment

* rename shader file to Renderer.wgsl

* deploy to cloudflare pages
  • Loading branch information
satelllte authored Mar 13, 2024
1 parent dbf007a commit a67312c
Show file tree
Hide file tree
Showing 13 changed files with 356 additions and 160 deletions.
1 change: 1 addition & 0 deletions .eslintrc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ extends:

rules:
tailwindcss/classnames-order: 'off' # Handled by "prettier-plugin-tailwindcss"
'@typescript-eslint/naming-convention': 'off'
11 changes: 10 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:

env:
BUILD_DIRECTORY: out
CLOUDFLARE_PROJECT_NAME: webgpu-raytracing

jobs:
ci:
Expand All @@ -35,8 +36,16 @@ jobs:
run: bun run test:format
- name: Build
run: bun run build
- name: Publish to Cloudflare Pages
uses: cloudflare/pages-action@v1
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
projectName: ${{ env.CLOUDFLARE_PROJECT_NAME }}
directory: ./${{ env.BUILD_DIRECTORY }}
gitHubToken: ${{ secrets.GITHUB_TOKEN }}
- name: Upload static output artifact
uses: actions/upload-artifact@v4
with:
name: build
path: ./${{env.BUILD_DIRECTORY}}
path: ./${{ env.BUILD_DIRECTORY }}
Binary file modified bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"format": "prettier . --write"
},
"dependencies": {
"clsx": "2.1.0",
"next": "14.0.4",
"react": "18.2.0",
"react-dom": "18.2.0"
Expand Down
26 changes: 26 additions & 0 deletions src/components/Raytracer/Button.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import clsx from 'clsx';
import {forwardRef} from 'react';

type NativeButtonProps = React.ComponentProps<'button'>;
type NativeButtonPropsToExtend = Omit<
NativeButtonProps,
'type' | 'className' | 'children'
>;
type ButtonProps = NativeButtonPropsToExtend & {
readonly children: string;
};

export const Button = forwardRef<HTMLButtonElement, ButtonProps>(
({disabled, ...rest}, forwardedRef) => (
<button
ref={forwardedRef}
type='button'
className={clsx(
'border px-4 py-1 hover:bg-zinc-900 active:bg-zinc-800',
disabled && 'opacity-50',
)}
disabled={disabled}
{...rest}
/>
),
);
48 changes: 48 additions & 0 deletions src/components/Raytracer/Canvas.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import {forwardRef, useEffect, useImperativeHandle, useRef} from 'react';

type NativeCanvasProps = React.ComponentProps<'canvas'>;
type NativeCanvasPropsToExtend = Omit<
NativeCanvasProps,
'className' | 'children'
>;
type CanvasProps = NativeCanvasPropsToExtend;

export const Canvas = forwardRef<HTMLCanvasElement, CanvasProps>(
(props, forwardedRef) => {
const innerRef = useRef<React.ElementRef<'canvas'>>(null);

useImperativeHandle(forwardedRef, () => {
if (!innerRef.current) throw new Error('innerRef is not set');
return innerRef.current;
});

useEffect(() => {
const canvas = innerRef.current;
if (!canvas) return;

const resizeCanvas = () => {
const {width, height} = canvas.getBoundingClientRect();
const scale = Math.max(window.devicePixelRatio, 1);
canvas.width = Math.floor(width * scale);
canvas.height = Math.floor(height * scale);
};

resizeCanvas();

// Using some hook like `useElementSize` based on ResizeObserver API would be actually better
// instead of listening for window resize.
// But it's still just fine for our particular use case.
window.addEventListener('resize', resizeCanvas);

return () => {
window.removeEventListener('resize', resizeCanvas);
};
}, []);

return (
<canvas ref={innerRef} className='absolute h-full w-full' {...props}>
HTML canvas is not supported in this browser
</canvas>
);
},
);
220 changes: 61 additions & 159 deletions src/components/Raytracer/Raytracer.tsx
Original file line number Diff line number Diff line change
@@ -1,189 +1,91 @@
'use client';
import {useEffect, useRef} from 'react';
import shaderWgsl from './shader.wgsl';
import {useEffect, useRef, useState} from 'react';
import {useWebGPUSupport} from './useWebGPUSupport';
import {Button} from './Button';
import {Canvas} from './Canvas';
import {Renderer} from './Renderer';
import {StatFPS} from './StatFPS';
import {StatWebGPUSupport} from './StatWebGPUSupport';

export function Raytracer() {
const webGPUSupported = useWebGPUSupport();

const canvasRef = useRef<HTMLCanvasElement>(null);
const rendererRef = useRef<Renderer>();

useEffect(() => {
const canvas = canvasRef.current;
if (!canvas) return;
if (!canvas) throw new Error('Canvas ref is not set');

rendererRef.current = new Renderer();
void rendererRef.current.init(canvas);

const resizeCanvas = () => {
const {width, height} = canvas.getBoundingClientRect();
const scale = Math.max(window.devicePixelRatio, 1);
canvas.width = Math.floor(width * scale);
canvas.height = Math.floor(height * scale);
return () => {
rendererRef.current?.dispose();
rendererRef.current = undefined;
};
}, []);

const [running, setRunning] = useState<boolean>(false);
const animationFrameIdRef = useRef<number | undefined>();

resizeCanvas();
const lastFrameTimeMsRef = useRef<number | undefined>();
const [lastFrameTimeMs, setLastFrameTimeMs] = useState<number | undefined>();

window.addEventListener('resize', resizeCanvas);
useEffect(() => {
const intervalId = setInterval(() => {
setLastFrameTimeMs(lastFrameTimeMsRef.current);
}, 200);

return () => {
window.removeEventListener('resize', resizeCanvas);
clearInterval(intervalId);
};
}, []);

const render = async () => {
const canvas = canvasRef.current;
if (!canvas) return;
const run = () => {
if (running) return;
setRunning(true);

const {gpu} = navigator;
if (!gpu) {
showAlert('WebGPU is not supported in this browser');
return;
}
const renderer = rendererRef.current;
if (!renderer) throw new Error('Renderer is not set');

const context = canvas.getContext('webgpu');
if (!context) {
showAlert('Failed to get WebGPU context');
return;
}
let prevTime = performance.now();
const drawLoop: FrameRequestCallback = (time) => {
lastFrameTimeMsRef.current = time - prevTime;
prevTime = time;

const adapter = await gpu.requestAdapter();
if (!adapter) {
showAlert('Failed to request WebGPU adapter');
return;
}
animationFrameIdRef.current = requestAnimationFrame(drawLoop);
renderer.draw();
};

const device = await adapter.requestDevice();
animationFrameIdRef.current = requestAnimationFrame(drawLoop);
};

const start = performance.now();
draw({gpu, context, device});
const stop = () => {
if (!running) return;
setRunning(false);

const diff = performance.now() - start;
console.info('render time (ms): ', diff);
if (!animationFrameIdRef.current) return;
cancelAnimationFrame(animationFrameIdRef.current);
animationFrameIdRef.current = undefined;
lastFrameTimeMsRef.current = undefined;
};

return (
<div className='absolute inset-0 flex flex-col gap-4 p-6'>
<div className='flex flex-col gap-2 sm:flex-row sm:justify-between'>
<h1 className='text-2xl sm:text-3xl'>WebGPU raytracer</h1>
<button
type='button'
className='border px-4 py-1 hover:bg-zinc-900 active:bg-zinc-800'
onClick={render}
>
Render
</button>
<div className='absolute inset-0 flex flex-col sm:flex-row'>
<div className='flex flex-1 flex-col gap-2 p-4 sm:max-w-xs'>
<div className='flex flex-1 flex-col gap-2'>
<h1 className='text-2xl underline'>WebGPU raytracer</h1>
<StatWebGPUSupport supported={webGPUSupported} />
<StatFPS frameTimeMs={lastFrameTimeMs} />
</div>
<Button disabled={!webGPUSupported} onClick={running ? stop : run}>
{running ? 'Stop' : 'Run'}
</Button>
</div>
<div className='relative flex-1 border border-white'>
<canvas ref={canvasRef} className='absolute h-full w-full' />
<div className='relative flex-1 border-zinc-500 max-sm:border-t sm:border-l'>
<Canvas ref={canvasRef} />
</div>
</div>
);
}

const draw = ({
gpu,
context,
device,
}: {
gpu: GPU;
context: GPUCanvasContext;
device: GPUDevice;
}): void => {
const preferredCanvasFormat = gpu.getPreferredCanvasFormat();

context.configure({
device,
format: preferredCanvasFormat,
alphaMode: 'premultiplied',
});

// prettier-ignore
const vertices = new Float32Array([
/// position<vec4f> (xyzw)
-1.0, 1.0, 0.0, 1.0,
-1.0, -1.0, 0.0, 1.0,
1.0, 1.0, 0.0, 1.0,
1.0, 1.0, 0.0, 1.0,
-1.0, -1.0, 0.0, 1.0,
1.0, -1.0, 0.0, 1.0,
]);

const verticesBuffer = device.createBuffer({
label: 'vertices buffer',
size: vertices.byteLength,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, // eslint-disable-line no-bitwise
});

const uniforms = new Int32Array([
context.canvas.width, /// width: u32
context.canvas.height, /// height: u32
]);

const uniformsBuffer = device.createBuffer({
label: 'uniforms buffer',
size: uniforms.byteLength,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, // eslint-disable-line no-bitwise
});

const shaderModule = device.createShaderModule({
label: 'shader module',
code: shaderWgsl,
});

const renderPipeline = device.createRenderPipeline({
label: 'render pipeline',
layout: 'auto',
primitive: {topology: 'triangle-list'},
vertex: {
module: shaderModule,
entryPoint: 'vertex_main',
buffers: [
{
arrayStride: 16,
stepMode: 'vertex',
attributes: [
{shaderLocation: 0, offset: 0, format: 'float32x4'}, // Position
],
},
] as const satisfies Iterable<GPUVertexBufferLayout>,
},
fragment: {
module: shaderModule,
entryPoint: 'fragment_main',
targets: [
{format: preferredCanvasFormat},
] as const satisfies Iterable<GPUColorTargetState>,
},
});

const uniformsBindGroup = device.createBindGroup({
label: 'uniforms bind group',
layout: renderPipeline.getBindGroupLayout(0),
entries: [{binding: 0, resource: {buffer: uniformsBuffer}}],
});

const commandEncoder = device.createCommandEncoder({
label: 'command encoder',
});
const passEncoder = commandEncoder.beginRenderPass({
label: 'pass encoder',
colorAttachments: [
{
clearValue: [0.0, 0.0, 0.0, 1.0],
view: context.getCurrentTexture().createView(),
loadOp: 'clear',
storeOp: 'store',
},
] as const satisfies Iterable<GPURenderPassColorAttachment>,
});

passEncoder.setPipeline(renderPipeline);
passEncoder.setBindGroup(0, uniformsBindGroup);
passEncoder.setVertexBuffer(0, verticesBuffer);
passEncoder.draw(vertices.length / 4);
passEncoder.end();

device.queue.writeBuffer(verticesBuffer, 0, vertices);
device.queue.writeBuffer(uniformsBuffer, 0, uniforms);

device.queue.submit([commandEncoder.finish()]);
};

const showAlert = (message: string): void => {
alert(message); // eslint-disable-line no-alert
};
Loading

0 comments on commit a67312c

Please sign in to comment.