[WebGPU] Fix up demos and add compute demo to webkit.org/demos
[WebKit-https.git] / Websites / webkit.org / demos / webgpu / scripts / compute-blur.js
// Threads per threadgroup along the axis being blurred. Used for the
// numthreads attribute of the generated shader and to size each dispatch.
const threadsPerThreadgroup = 32;

// Binding slots shared by the bind group layout, the bind groups, and the
// register() annotations in the generated shader source.
const sourceBufferBindingNum = 0;
const outputBufferBindingNum = 1;
const uniformsBufferBindingNum = 2;

// Enough space to store 1 radius and 33 weights.
const maxUniformsSize = (32 + 2) * Float32Array.BYTES_PER_ELEMENT;

// Assigned once in init(): the loaded source image, the canvas 2D context,
// and the GPU device.
let image, context2d, device;

// Side length, in pixels, of the square canvas the image is drawn into.
const width = 600;
13
/**
 * Page entry point: acquires the GPU device, loads the image, builds the
 * compute pipelines, and wires the slider to the blur.
 *
 * Marks the body with the "error" class and bails out when WebGPU (or an
 * adapter) is unavailable.
 */
async function init() {
    if (!navigator.gpu) {
        document.body.className = "error";
        return;
    }

    const slider = document.querySelector("input");
    const canvas = document.querySelector("canvas");
    context2d = canvas.getContext("2d");

    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
        // No usable adapter: report it the same way as missing WebGPU support.
        document.body.className = "error";
        return;
    }
    device = await adapter.requestDevice();
    image = await loadImage(canvas);

    setUpCompute();

    // Slider events can arrive faster than a blur completes. Queue every value
    // and let a single drain loop process them in order, one at a time.
    let busy = false;
    const inputQueue = [];
    slider.oninput = async () => {
        // slider.value is a string; hand the blur a real number.
        inputQueue.push(Number(slider.value));

        if (busy)
            return;

        busy = true;
        try {
            while (inputQueue.length !== 0)
                await computeBlur(inputQueue.shift());
        } finally {
            // Always release the flag; otherwise one failed blur would leave
            // the slider permanently unresponsive.
            busy = false;
        }
    };
}
44
/**
 * Loads the demo image, sizes the canvas to width x width, and draws the
 * image scaled to fill it.
 *
 * @param {HTMLCanvasElement} canvas - Canvas to resize; drawing goes through
 *     the module-level context2d.
 * @returns {Promise<HTMLImageElement>} The loaded image element.
 * @throws {Error} If the image cannot be loaded.
 */
async function loadImage(canvas) {
    /* Image */
    const image = new Image();
    await new Promise((resolve, reject) => {
        image.onload = () => resolve();
        // Without an error handler a failed load would never settle and
        // init() would hang forever.
        image.onerror = () => reject(new Error("Failed to load resources/safari-alpha.png"));
        image.src = "resources/safari-alpha.png";
    });

    canvas.height = width;
    canvas.width = width;

    context2d.drawImage(image, 0, 0, width, width);

    return image;
}
61
// Pixels captured from the canvas and the length of their data array;
// both are assigned in setUpCompute().
let originalData, imageSize;
// GPU buffers: the uploaded source pixels, the intermediate result of the
// horizontal pass, the final readable result, and the blur uniforms.
let originalBuffer, storageBuffer, resultsBuffer, uniformsBuffer;
// Bind group and compute pipeline for each of the two blur passes.
let horizontalBindGroup, verticalBindGroup, horizontalPipeline, verticalPipeline;
65
/**
 * Creates every GPU resource the blur needs and stores it in module state:
 * the source/intermediate/result/uniform buffers, one bind group per pass,
 * and one compute pipeline per shader entry point.
 *
 * Must run after init() has assigned image, context2d and device.
 */
function setUpCompute() {
    // NOTE(review): this reads an image.width x image.height region, while
    // loadImage() draws the image scaled to width x width. The two only agree
    // when the image is exactly width pixels square; confirm for the asset.
    originalData = context2d.getImageData(0, 0, image.width, image.height);
    imageSize = originalData.data.length;

    // Buffer creation
    // The source buffer is created already mapped so the pixels can be copied
    // in before it is handed to the GPU.
    let originalArrayBuffer;
    [originalBuffer, originalArrayBuffer] = device.createBufferMapped({ size: imageSize, usage: GPUBufferUsage.STORAGE });
    new Uint8ClampedArray(originalArrayBuffer).set(originalData.data);
    originalBuffer.unmap();

    storageBuffer = device.createBuffer({ size: imageSize, usage: GPUBufferUsage.STORAGE });
    resultsBuffer = device.createBuffer({ size: imageSize, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.MAP_READ });
    uniformsBuffer = device.createBuffer({ size: maxUniformsSize, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.MAP_WRITE });

    // Bind buffers to kernel
    const bindGroupLayout = device.createBindGroupLayout({
        bindings: [
            { binding: sourceBufferBindingNum, visibility: GPUShaderStageBit.COMPUTE, type: "storage-buffer" },
            { binding: outputBufferBindingNum, visibility: GPUShaderStageBit.COMPUTE, type: "storage-buffer" },
            { binding: uniformsBufferBindingNum, visibility: GPUShaderStageBit.COMPUTE, type: "uniform-buffer" }
        ]
    });

    // Both passes share the layout and the uniforms; they differ only in
    // which buffer is read and which is written.
    const createPassBindGroup = (sourceBuffer, outputBuffer) => device.createBindGroup({
        layout: bindGroupLayout,
        bindings: [
            { binding: sourceBufferBindingNum, resource: { buffer: sourceBuffer, size: imageSize } },
            { binding: outputBufferBindingNum, resource: { buffer: outputBuffer, size: imageSize } },
            { binding: uniformsBufferBindingNum, resource: { buffer: uniformsBuffer, size: maxUniformsSize } }
        ]
    });

    // Horizontal pass: original pixels -> intermediate storage.
    horizontalBindGroup = createPassBindGroup(originalBuffer, storageBuffer);
    // Vertical pass: intermediate storage -> readable results.
    verticalBindGroup = createPassBindGroup(storageBuffer, resultsBuffer);

    // Set up pipelines
    const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });

    const shaderModule = device.createShaderModule({ code: createShaderCode(image), isWHLSL: true });

    // One pipeline per entry point of the shared shader module.
    const createPassPipeline = (entryPoint) => device.createComputePipeline({
        layout: pipelineLayout,
        computeStage: {
            module: shaderModule,
            entryPoint
        }
    });

    horizontalPipeline = createPassPipeline("horizontal");
    verticalPipeline = createPassPipeline("vertical");
}
165
/**
 * Blurs the image with the given radius on the GPU and paints the result
 * into the canvas. A radius of 0 simply redraws the unblurred image.
 *
 * @param {number|string} radius - Blur radius; numeric strings (as produced
 *     by a slider's value) are accepted.
 * @returns {Promise<void>} Settles once the canvas has been updated.
 */
async function computeBlur(radius) {
    const blurRadius = Number(radius);
    if (blurRadius === 0) {
        context2d.drawImage(image, 0, 0, width, width);
        return;
    }

    // Compute the weights while the uniforms buffer is being mapped.
    const [uniforms, uniformsArrayBuffer] = await Promise.all([
        setUniforms(blurRadius),
        uniformsBuffer.mapWriteAsync()
    ]);

    try {
        new Float32Array(uniformsArrayBuffer).set(uniforms);
    } finally {
        // Unmap even if the copy throws, so the buffer can be mapped again
        // by the next blur.
        uniformsBuffer.unmap();
    }

    const commandEncoder = device.createCommandEncoder();

    // Records one compute pass with the given bindings and dispatch size.
    const encodePass = (bindGroup, pipeline, xGroups, yGroups) => {
        const passEncoder = commandEncoder.beginComputePass();
        passEncoder.setBindGroup(0, bindGroup);
        passEncoder.setPipeline(pipeline);
        passEncoder.dispatch(xGroups, yGroups, 1);
        passEncoder.endPass();
    };

    // Run horizontal pass first: threadgroups span each row.
    encodePass(horizontalBindGroup, horizontalPipeline,
        Math.ceil(image.width / threadsPerThreadgroup), image.height);

    // Run vertical pass: threadgroups span each column.
    encodePass(verticalBindGroup, verticalPipeline,
        image.width, Math.ceil(image.height / threadsPerThreadgroup));

    device.getQueue().submit([commandEncoder.finish()]);

    // Draw resultsBuffer as imageData back into context2d
    const resultArrayBuffer = await resultsBuffer.mapReadAsync();
    try {
        const resultArray = new Uint8ClampedArray(resultArrayBuffer);
        context2d.putImageData(new ImageData(resultArray, image.width, image.height), 0, 0);
    } finally {
        // Release the mapping even if drawing fails.
        resultsBuffer.unmap();
    }
}
205
// Start the demo once the page has finished loading. If this script happens
// to be evaluated after the load event already fired, the listener would never
// run, so start immediately in that case.
if (document.readyState === "complete")
    init();
else
    window.addEventListener("load", init);
207
208 /* Helpers */
209
// Memoized uniform arrays, keyed by numeric radius.
const uniformsCache = new Map();

/**
 * Builds (or fetches from the cache) the uniform values for a blur radius:
 * the radius itself followed by radius + 1 normalized Gaussian weights, one
 * per distance 0..radius from the center tap.
 *
 * @param {number|string} radius - Blur radius; numeric strings are accepted
 *     and share a cache entry with the equivalent number.
 * @returns {Promise<number[]>} [radius, weight(0), weight(1), ..., weight(radius)].
 *     The array is cached and shared between calls; treat it as read-only.
 */
async function setUniforms(radius)
{
    const blurRadius = Number(radius);

    const cached = uniformsCache.get(blurRadius);
    if (cached !== undefined)
        return cached;

    const sigma = blurRadius / 2.0;
    const twoSigma2 = 2.0 * sigma * sigma;

    const uniforms = [blurRadius];
    let weightSum = 0;

    for (let i = 0; i <= blurRadius; ++i) {
        // The center tap is exp(0) = 1. Spelling it out keeps radius 0
        // (sigma 0) from evaluating 0 / 0 and producing a NaN weight.
        const weight = (i === 0) ? 1 : Math.exp(-(i * i) / twoSigma2);
        uniforms.push(weight);
        // Every off-center weight is applied on both sides of the center.
        weightSum += (i === 0) ? weight : weight * 2;
    }

    // Compensate for loss in brightness
    const brightnessScale = 1 - (0.1 / 32.0) * blurRadius;
    weightSum *= brightnessScale;
    for (let i = 1; i < uniforms.length; ++i)
        uniforms[i] /= weightSum;

    uniformsCache.set(blurRadius, uniforms);

    return uniforms;
}
240
// Mask selecting the low 8 bits: one color channel of a packed RGBA value.
const byteMask = (1 << 8) - 1;

/**
 * Generates the WHLSL source for the two blur kernels ("horizontal" and
 * "vertical"), specialized for the given image dimensions and for the
 * module's threadgroup size and binding slots.
 *
 * Each pixel is one uint packed as R | G << 8 | B << 16 | A << 24, stored
 * row by row with a stride of image.width.
 *
 * @param {{width: number, height: number}} image - Source of the dimensions
 *     baked into the shader.
 * @returns {string} Shader source code.
 */
function createShaderCode(image) {
    const imageWidth = image.width;
    const imageHeight = image.height;

    return `
uint getR(uint rgba)
{
    return rgba & ${byteMask};
}

uint getG(uint rgba)
{
    return (rgba >> 8) & ${byteMask};
}

uint getB(uint rgba)
{
    return (rgba >> 16) & ${byteMask};
}

uint getA(uint rgba)
{
    return (rgba >> 24) & ${byteMask};
}

uint makeRGBA(uint r, uint g, uint b, uint a)
{
    return r + (g << 8) + (b << 16) + (a << 24);
}

void accumulateChannels(thread uint[] channels, uint startColor, float weight)
{
    channels[0] += uint(float(getR(startColor)) * weight);
    channels[1] += uint(float(getG(startColor)) * weight);
    channels[2] += uint(float(getB(startColor)) * weight);
    channels[3] += uint(float(getA(startColor)) * weight);

    // Compensate for brightness-adjusted weights.
    if (channels[0] > 255)
        channels[0] = 255;

    if (channels[1] > 255)
        channels[1] = 255;

    if (channels[2] > 255)
        channels[2] = 255;

    if (channels[3] > 255)
        channels[3] = 255;
}

uint horizontallyOffsetIndex(uint index, int offset, int rowStart, int rowEnd)
{
    int offsetIndex = int(index) + offset;

    if (offsetIndex < rowStart || offsetIndex >= rowEnd)
        return index;
    
    return uint(offsetIndex);
}

uint verticallyOffsetIndex(uint index, int offset, uint length)
{
    int realOffset = offset * ${imageWidth};
    int offsetIndex = int(index) + realOffset;

    if (offsetIndex < 0 || offsetIndex >= int(length))
        return index;
    
    return uint(offsetIndex);
}

[numthreads(${threadsPerThreadgroup}, 1, 1)]
compute void horizontal(constant uint[] source : register(u${sourceBufferBindingNum}),
                        device uint[] output : register(u${outputBufferBindingNum}),
                        constant float[] uniforms : register(b${uniformsBufferBindingNum}),
                        float3 dispatchThreadID : SV_DispatchThreadID)
{
    // The dispatch is rounded up to whole threadgroups; threads past the end
    // of the row must not spill into the next row's pixels.
    if (int(dispatchThreadID.x) >= ${imageWidth})
        return;

    int radius = int(uniforms[0]);
    int rowStart = ${imageWidth} * int(dispatchThreadID.y);
    int rowEnd = ${imageWidth} * (1 + int(dispatchThreadID.y));
    uint globalIndex = uint(rowStart) + uint(dispatchThreadID.x);

    uint[4] channels;

    for (int i = -radius; i <= radius; ++i) {
        uint startColor = source[horizontallyOffsetIndex(globalIndex, i, rowStart, rowEnd)];
        float weight = uniforms[uint(abs(i) + 1)];
        accumulateChannels(@channels, startColor, weight);
    }

    output[globalIndex] = makeRGBA(channels[0], channels[1], channels[2], channels[3]);
}

[numthreads(1, ${threadsPerThreadgroup}, 1)]
compute void vertical(constant uint[] source : register(u${sourceBufferBindingNum}),
                        device uint[] output : register(u${outputBufferBindingNum}),
                        constant float[] uniforms : register(b${uniformsBufferBindingNum}),
                        float3 dispatchThreadID : SV_DispatchThreadID)
{
    // The dispatch is rounded up to whole threadgroups; skip threads below
    // the last row.
    if (int(dispatchThreadID.y) >= ${imageHeight})
        return;

    int radius = int(uniforms[0]);
    // x is the column and y the row, so the row-major index is y * width + x.
    uint globalIndex = uint(dispatchThreadID.y) * ${imageWidth} + uint(dispatchThreadID.x);

    uint[4] channels;

    for (int i = -radius; i <= radius; ++i) {
        uint startColor = source[verticallyOffsetIndex(globalIndex, i, source.length)];
        float weight = uniforms[uint(abs(i) + 1)];
        accumulateChannels(@channels, startColor, weight);
    }

    output[globalIndex] = makeRGBA(channels[0], channels[1], channels[2], channels[3]);
}
`;
}