// Number of threads per threadgroup along the axis being blurred; used both
// for the shader's numthreads attribute and for sizing the dispatch.
const threadsPerThreadgroup = 32;

// Binding slots shared by the bind group layout and the shader's registers.
const sourceBufferBindingNum = 0;
const outputBufferBindingNum = 1;
const uniformsBufferBindingNum = 2;

// Largest blur radius the uniforms buffer is sized for.
const maxBlurRadius = 32;

// Enough space to store 1 radius and (maxBlurRadius + 1) = 33 weights.
const maxUniformsSize = (maxBlurRadius + 2) * Float32Array.BYTES_PER_ELEMENT;

// Shared page state, assigned once in init().
let image;
let context2d;
let device;
/**
 * Page entry point: acquires the 2D canvas context and a WebGPU device,
 * loads the source image, prepares the compute resources, and wires the
 * slider so that every change triggers a blur with the selected radius.
 *
 * Marks the document body with the "error" class and returns early when
 * WebGPU (or an adapter) is unavailable.
 */
async function init() {
    // Only flag the page as unsupported when WebGPU is actually missing.
    if (!navigator.gpu) {
        document.body.className = "error";
        return;
    }

    const slider = document.querySelector("input");
    const canvas = document.querySelector("canvas");
    context2d = canvas.getContext("2d");

    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
        document.body.className = "error";
        return;
    }
    device = await adapter.requestDevice();
    image = await loadImage(canvas);

    // Buffers, bind groups and pipelines must exist before the first blur.
    setUpCompute();

    // Slider values wait here while a blur is in flight. Only one drain loop
    // may run at a time, otherwise overlapping computeBlur() calls would try
    // to map the same GPU buffers concurrently.
    const inputQueue = [];
    let isDraining = false;

    slider.oninput = async () => {
        inputQueue.push(slider.value);

        if (isDraining)
            return;

        isDraining = true;
        try {
            while (inputQueue.length !== 0)
                await computeBlur(inputQueue.shift());
        } finally {
            // Release the guard even if a blur fails, so later input still works.
            isDraining = false;
        }
    };
}
/**
 * Loads the demo image, sizes the canvas to `width` x `width`, and draws the
 * image scaled to fill it.
 *
 * @param {HTMLCanvasElement} canvas - Canvas whose 2D context is `context2d`.
 * @returns {Promise<HTMLImageElement>} The loaded image.
 * @throws {Error} If the image fails to load.
 */
async function loadImage(canvas) {
    // Named differently from the module-level `image` to avoid shadowing it.
    const loadedImage = new Image();

    await new Promise((resolve, reject) => {
        loadedImage.onload = () => resolve();
        // Without an error handler a failed load would leave init() pending forever.
        loadedImage.onerror = () => reject(new Error("Failed to load resources/safari-alpha.png"));
        loadedImage.src = "resources/safari-alpha.png";
    });

    // Resizing clears the canvas, so set both dimensions before drawing.
    canvas.width = width;
    canvas.height = width;

    context2d.drawImage(loadedImage, 0, 0, width, width);

    return loadedImage;
}
// State shared between setUpCompute() and computeBlur(); every variable
// declared here is assigned inside setUpCompute().
62 let originalData, imageSize;
63 let originalBuffer, storageBuffer, resultsBuffer, uniformsBuffer;
64 let horizontalBindGroup, verticalBindGroup, horizontalPipeline, verticalPipeline;
// Creates the GPU buffers, bind groups and compute pipelines used by the
// two-pass blur. Reads the module-level `context2d`, `image` and `device`.
// NOTE(review): several structural lines of the object literals below
// (wrappers around the binding entries and pipeline stages) are not visible
// in this excerpt — confirm against the full file before editing.
66 function setUpCompute() {
// Read back the pixels currently on the canvas over an image.width x image.height region.
67 originalData = context2d.getImageData(0, 0, image.width, image.height);
// Length of the pixel array; used as the size of every image-sized buffer.
68 imageSize = originalData.data.length;
71 let originalArrayBuffer;
// Create the source buffer already mapped, copy the pixels in, then unmap it.
72 [originalBuffer, originalArrayBuffer] = device.createBufferMapped({ size: imageSize, usage: GPUBufferUsage.STORAGE });
73 const imageWriteArray = new Uint8ClampedArray(originalArrayBuffer);
74 imageWriteArray.set(originalData.data);
75 originalBuffer.unmap();
// storageBuffer: output of the horizontal pass and source of the vertical pass.
77 storageBuffer = device.createBuffer({ size: imageSize, usage: GPUBufferUsage.STORAGE });
// resultsBuffer: output of the vertical pass; MAP_READ so computeBlur() can read it back.
78 resultsBuffer = device.createBuffer({ size: imageSize, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.MAP_READ });
// uniformsBuffer: radius and weights; MAP_WRITE so computeBlur() can fill it per call.
79 uniformsBuffer = device.createBuffer({ size: maxUniformsSize, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.MAP_WRITE });
81 // Bind buffers to kernel
// One layout shared by both passes: source storage, output storage, uniforms.
82 const bindGroupLayout = device.createBindGroupLayout({
84 binding: sourceBufferBindingNum,
85 visibility: GPUShaderStageBit.COMPUTE,
86 type: "storage-buffer"
88 binding: outputBufferBindingNum,
89 visibility: GPUShaderStageBit.COMPUTE,
90 type: "storage-buffer"
92 binding: uniformsBufferBindingNum,
93 visibility: GPUShaderStageBit.COMPUTE,
94 type: "uniform-buffer"
// Horizontal pass: originalBuffer -> storageBuffer.
98 horizontalBindGroup = device.createBindGroup({
99 layout: bindGroupLayout,
101 binding: sourceBufferBindingNum,
103 buffer: originalBuffer,
107 binding: outputBufferBindingNum,
109 buffer: storageBuffer,
113 binding: uniformsBufferBindingNum,
115 buffer: uniformsBuffer,
116 size: maxUniformsSize
// Vertical pass: storageBuffer -> resultsBuffer, with the same uniforms buffer.
121 verticalBindGroup = device.createBindGroup({
122 layout: bindGroupLayout,
124 binding: sourceBufferBindingNum,
126 buffer: storageBuffer,
130 binding: outputBufferBindingNum,
132 buffer: resultsBuffer,
136 binding: uniformsBufferBindingNum,
138 buffer: uniformsBuffer,
139 size: maxUniformsSize
145 const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });
// A single shader module, generated for this image, provides both entry points.
147 const shaderModule = device.createShaderModule({ code: createShaderCode(image), isWHLSL: true });
149 horizontalPipeline = device.createComputePipeline({
150 layout: pipelineLayout,
152 module: shaderModule,
153 entryPoint: "horizontal"
157 verticalPipeline = device.createComputePipeline({
158 layout: pipelineLayout,
160 module: shaderModule,
161 entryPoint: "vertical"
/**
 * Blurs the source image with the given radius on the GPU (a horizontal pass
 * followed by a vertical pass) and draws the result into the canvas.
 *
 * @param {number|string} radius - Blur radius; the slider's string value is accepted.
 * @returns {Promise<void>} Resolves once the canvas has been updated.
 */
async function computeBlur(radius) {
    // A zero radius means "no blur": just show the original image. This also
    // keeps the degenerate sigma = 0 case (0 / 0 weights) out of the shader.
    // Non-zero radii skip the redraw so the unblurred image cannot flash on
    // screen while the GPU work is still pending.
    if (Number(radius) === 0) {
        context2d.drawImage(image, 0, 0, width, width);
        return;
    }

    // Compute the weights and map the uniforms buffer concurrently.
    const [uniforms, uniformsArrayBuffer] = await Promise.all([
        setUniforms(radius),
        uniformsBuffer.mapWriteAsync(),
    ]);

    try {
        new Float32Array(uniformsArrayBuffer).set(uniforms);
    } finally {
        // Always unmap, otherwise the next blur could not map the buffer again.
        uniformsBuffer.unmap();
    }

    const commandEncoder = device.createCommandEncoder();

    // Run horizontal pass first: one threadgroup row per image row.
    const horizontalPassEncoder = commandEncoder.beginComputePass();
    horizontalPassEncoder.setBindGroup(0, horizontalBindGroup);
    horizontalPassEncoder.setPipeline(horizontalPipeline);
    const numXGroups = Math.ceil(image.width / threadsPerThreadgroup);
    horizontalPassEncoder.dispatch(numXGroups, image.height, 1);
    horizontalPassEncoder.endPass();

    // Then the vertical pass, which consumes the horizontal pass's output.
    const verticalPassEncoder = commandEncoder.beginComputePass();
    verticalPassEncoder.setBindGroup(0, verticalBindGroup);
    verticalPassEncoder.setPipeline(verticalPipeline);
    const numYGroups = Math.ceil(image.height / threadsPerThreadgroup);
    verticalPassEncoder.dispatch(image.width, numYGroups, 1);
    verticalPassEncoder.endPass();

    device.getQueue().submit([commandEncoder.finish()]);

    // Draw resultsBuffer as imageData back into context2d
    const resultArrayBuffer = await resultsBuffer.mapReadAsync();
    try {
        const resultArray = new Uint8ClampedArray(resultArrayBuffer);
        context2d.putImageData(new ImageData(resultArray, image.width, image.height), 0, 0);
    } finally {
        resultsBuffer.unmap();
    }
}
// Run init() when the window's "load" event fires.
206 window.addEventListener("load", init);
// Previously computed uniform arrays, keyed by numeric blur radius.
let uniformsCache = new Map();

/**
 * Builds (or fetches from the cache) the uniforms for a Gaussian blur:
 * `[radius, w0, w1, ..., w_radius]`, where `w_i` is the normalised weight of
 * a sample `i` pixels away from the centre.
 *
 * @param {number|string} radius - Blur radius; numeric strings (e.g. a
 *     slider's `value`) are coerced to numbers.
 * @returns {Promise<number[]>} The cached uniforms array (shared; do not mutate).
 */
async function setUniforms(radius)
{
    // Coerce once so "5" and 5 share a cache entry and element 0 is numeric.
    const numericRadius = Number(radius);

    let uniforms = uniformsCache.get(numericRadius);
    if (uniforms !== undefined)
        return uniforms;

    const sigma = numericRadius / 2.0;
    const twoSigma2 = 2.0 * sigma * sigma;

    uniforms = [numericRadius];
    let weightSum = 0;

    for (let i = 0; i <= numericRadius; ++i) {
        // The centre tap is always exp(0) = 1. Using the constant avoids
        // evaluating 0 / 0 = NaN when the radius (and therefore sigma) is 0.
        const weight = (i === 0) ? 1 : Math.exp(-i * i / twoSigma2);
        uniforms.push(weight);
        // Every non-centre weight is applied on both sides of the centre.
        weightSum += (i === 0) ? weight : weight * 2;
    }

    // Compensate for loss in brightness
    const brightnessScale = 1 - (0.1 / 32.0) * numericRadius;
    weightSum *= brightnessScale;
    // Index 0 holds the radius, so normalisation starts at index 1.
    for (let i = 1; i < uniforms.length; ++i)
        uniforms[i] /= weightSum;

    uniformsCache.set(numericRadius, uniforms);

    return uniforms;
}
// Mask selecting the low 8 bits (255); interpolated into the generated shader
// source to extract one channel from a packed 32-bit value.
241 const byteMask = (1 << 8) - 1;
243 function createShaderCode(image) {
247 return rgba & ${byteMask};
252 return (rgba >> 8) & ${byteMask};
257 return (rgba >> 16) & ${byteMask};
262 return (rgba >> 24) & ${byteMask};
265 uint makeRGBA(uint r, uint g, uint b, uint a)
267 return r + (g << 8) + (b << 16) + (a << 24);
270 void accumulateChannels(thread uint[] channels, uint startColor, float weight)
272 channels[0] += uint(float(getR(startColor)) * weight);
273 channels[1] += uint(float(getG(startColor)) * weight);
274 channels[2] += uint(float(getB(startColor)) * weight);
275 channels[3] += uint(float(getA(startColor)) * weight);
277 // Compensate for brightness-adjusted weights.
278 if (channels[0] > 255)
281 if (channels[1] > 255)
284 if (channels[2] > 255)
287 if (channels[3] > 255)
291 uint horizontallyOffsetIndex(uint index, int offset, int rowStart, int rowEnd)
293 int offsetIndex = int(index) + offset;
295 if (offsetIndex < rowStart || offsetIndex >= rowEnd)
298 return uint(offsetIndex);
301 uint verticallyOffsetIndex(uint index, int offset, uint length)
303 int realOffset = offset * ${image.width};
304 int offsetIndex = int(index) + realOffset;
306 if (offsetIndex < 0 || offsetIndex >= int(length))
309 return uint(offsetIndex);
312 [numthreads(${threadsPerThreadgroup}, 1, 1)]
313 compute void horizontal(constant uint[] source : register(u${sourceBufferBindingNum}),
314 device uint[] output : register(u${outputBufferBindingNum}),
315 constant float[] uniforms : register(b${uniformsBufferBindingNum}),
316 float3 dispatchThreadID : SV_DispatchThreadID)
318 int radius = int(uniforms[0]);
319 int rowStart = ${image.width} * int(dispatchThreadID.y);
320 int rowEnd = ${image.width} * (1 + int(dispatchThreadID.y));
321 uint globalIndex = uint(rowStart) + uint(dispatchThreadID.x);
325 for (int i = -radius; i <= radius; ++i) {
326 uint startColor = source[horizontallyOffsetIndex(globalIndex, i, rowStart, rowEnd)];
327 float weight = uniforms[uint(abs(i) + 1)];
328 accumulateChannels(@channels, startColor, weight);
331 output[globalIndex] = makeRGBA(channels[0], channels[1], channels[2], channels[3]);
334 [numthreads(1, ${threadsPerThreadgroup}, 1)]
335 compute void vertical(constant uint[] source : register(u${sourceBufferBindingNum}),
336 device uint[] output : register(u${outputBufferBindingNum}),
337 constant float[] uniforms : register(b${uniformsBufferBindingNum}),
338 float3 dispatchThreadID : SV_DispatchThreadID)
340 int radius = int(uniforms[0]);
341 uint globalIndex = uint(dispatchThreadID.x) * ${image.height} + uint(dispatchThreadID.y);
345 for (int i = -radius; i <= radius; ++i) {
346 uint startColor = source[verticallyOffsetIndex(globalIndex, i, source.length)];
347 float weight = uniforms[uint(abs(i) + 1)];
348 accumulateChannels(@channels, startColor, weight);
351 output[globalIndex] = makeRGBA(channels[0], channels[1], channels[2], channels[3]);