commit 5d2cbdb6ea9e07fd27ef50bb40c2dec93ff5c74c Author: kaj dijkstra Date: Thu Dec 25 10:34:12 2025 +0100 First Commit diff --git a/demo.js b/demo.js new file mode 100644 index 0000000..0e65143 --- /dev/null +++ b/demo.js @@ -0,0 +1,170 @@ + +import Shader from "../framework/WebGpu.js" + +import Measure from "../framework/Measure.js" + + +const adapter = await navigator.gpu.requestAdapter(); + +if (!adapter ) { + + throw new Error("Failed to get GPU adapter"); + +} + +const particleCount = 131072; + +const workgroupSize = 256; + +const numWorkgroups = Math.ceil( particleCount / workgroupSize ); + +const device = await adapter.requestDevice(); + + +var passIndex = 0; + +var jArray = new Array(); + +var kArray = new Array(); + + +for ( let k = 512; k <= particleCount; k <<= 1 ) { + + for ( let j = k >> 1; j > 0; j >>= 1 ) { + + jArray.push( j ); + + kArray.push( k ); + + } + +} + +const indices = new Uint32Array( particleCount ); + +const threadPassIndices = new Uint32Array( particleCount ); + +for (var i = 0; i < particleCount; i++) { + + indices[particleCount-i-1] = i; + + threadPassIndices[0] = 0; + +} + +const localSortShader = new Shader( device, "./shaders/localSort.wgsl"); + +await localSortShader.addStage("main", GPUShaderStage.COMPUTE ); + +localSortShader.setVariable( "compare", indices ); + +localSortShader.setVariable( "totalCount", particleCount ); + + +const bitonicSortGridHashShader = new Shader( device, "./shaders/bitonicSortUIntMultiPass.wgsl"); + +await bitonicSortGridHashShader.addStage("main", GPUShaderStage.COMPUTE ); + +bitonicSortGridHashShader.setVariable( "totalCount", particleCount ); + +bitonicSortGridHashShader.setBuffer( "compare", localSortShader.getBuffer("compare") ); + +bitonicSortGridHashShader.setVariable( "jArray", jArray ); + +bitonicSortGridHashShader.setVariable( "kArray", kArray ); + + +const measure = new Measure(); + +measure.writeToPage = true; + +measure.element = document.querySelector(".result"); + +for (var i = 3; i < 59; i++) { + + bitonicSortGridHashShader.setVariable("threadPassIndices", threadPassIndices); + + measure.start( "sort"+i ); + + //await localSortShader.execute( numWorkgroups ); + +const commandEncoder = device.createCommandEncoder(); + + +// === Local sort === +{ + const passEncoder = commandEncoder.beginComputePass(); + + passEncoder.setPipeline( localSortShader.pipeline ); + + for ( const [ index, bindGroup ] of localSortShader.bindGroups.entries() ) { + + passEncoder.setBindGroup( index, bindGroup ); + + } + + passEncoder.dispatchWorkgroups( numWorkgroups ); + + passEncoder.end(); +} + +// === Bitonic global merge === +{ + const passEncoder = commandEncoder.beginComputePass(); + + passEncoder.setPipeline( bitonicSortGridHashShader.pipeline ); + + for ( const [ index, bindGroup ] of bitonicSortGridHashShader.bindGroups.entries() ) { + + passEncoder.setBindGroup( index, bindGroup ); + + } + + var numberOfPasses = 0; + + for ( let k = 512; k <= particleCount; k <<= 1 ) { + + for ( let j = k >> 1; j > 0; j >>= 1 ) { + + numberOfPasses++; + + passEncoder.dispatchWorkgroups( numWorkgroups ); + + } + + } + + passEncoder.end(); + +} + + +const commandBuffer = commandEncoder.finish(); + +device.queue.submit( [ commandBuffer ] ); + + await device.queue.onSubmittedWorkDone(); + + measure.end( "sort"+i ); + +} + + +var result = await bitonicSortGridHashShader.debugBuffer("compare"); + +await bitonicSortGridHashShader.debugBuffer("threadPassIndices"); + +var error = false; + +for (var i = 0; i < particleCount; i++ ) { + + if( result[i] != i ) { + + error = true; + + } + +} + +console.log("There is error?", error, particleCount); + diff --git a/index.html b/index.html new file mode 100644 index 0000000..686ded0 --- /dev/null +++ b/index.html @@ -0,0 +1,232 @@ + + + + + + +WebGPU Attention Example + + + +
+ +
+ +

LocalSort -> Bitonic Sort Multi step

+

131 072 items under 10 ms in WebGpu

+

Open Your console to see the results.

+ +
+ + + + \ No newline at end of file diff --git a/shaders/bitonicSort.wgsl b/shaders/bitonicSort.wgsl new file mode 100644 index 0000000..ce4319b --- /dev/null +++ b/shaders/bitonicSort.wgsl @@ -0,0 +1,64 @@ +@group(0) @binding(0) +var compare: array; + +@group(0) @binding(1) +var indices: array; + +@group(0) @binding(2) +var k: u32; // current stage size (power of two) + +@group(0) @binding(3) +var j: u32; // current subsequence size + +@group(0) @binding(4) +var totalCount: u32; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + + + let ixj = idx ^ j; + + + if (idx >= totalCount || ixj >= totalCount) { + return; + } + + + + if (ixj > idx && ixj < totalCount) { + let ascending = (idx & k) == 0u; + + let dist_idx = compare[idx]; + let dist_ixj = compare[ixj]; + + var swap = false; + + if (ascending) { + if (dist_idx < dist_ixj) { + swap = true; + } + } else { + if (dist_idx > dist_ixj) { + swap = true; + } + } + + if (swap) { + + let tempDist = compare[idx]; + let tempDist2 = compare[ixj]; + + let tempIndex = indices[idx]; + let tempIndex2 = indices[ixj]; + + compare[idx] = tempDist2; + compare[ixj] = tempDist; + + indices[idx] = tempIndex2; + indices[ixj] = tempIndex; + + } + } +} diff --git a/shaders/bitonicSortUInt.wgsl b/shaders/bitonicSortUInt.wgsl new file mode 100644 index 0000000..55c8ffe --- /dev/null +++ b/shaders/bitonicSortUInt.wgsl @@ -0,0 +1,53 @@ +@group(0) @binding(0) +var compare: array; + +@group(0) @binding(1) +var indices: array; + +@group(0) @binding(2) +var k: u32; + +@group(0) @binding(3) +var j: u32; + +@group(0) @binding(4) +var totalCount: u32; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + + let ixj = idx ^ j; + + if (idx >= totalCount || ixj <= idx || ixj >= totalCount) { + return; + } + + if (ixj > idx) { + let ascending = (idx & k) == 0u; + + let dist_idx = compare[idx]; + let dist_ixj = compare[ixj]; + + var swap = false; + + if (ascending) { + if (dist_idx > dist_ixj) { + swap = true; + } + } else { + if (dist_idx < dist_ixj) { + swap = true; + } + } + + if (swap) { + + let tempDist = compare[idx]; + + compare[idx] = compare[ixj]; + compare[ixj] = tempDist; + + } + } +} diff --git a/shaders/bitonicSortUIntMultiPass.wgsl b/shaders/bitonicSortUIntMultiPass.wgsl new file mode 100644 index 0000000..9608516 --- /dev/null +++ b/shaders/bitonicSortUIntMultiPass.wgsl @@ -0,0 +1,68 @@ +@group(0) @binding(0) +var compare: array; + +@group(0) @binding(1) +var threadPassIndices: array; + +@group(0) @binding(2) +var kArray: array; + +@group(0) @binding(3) +var jArray: array; + +@group(0) @binding(4) +var totalCount: u32; + + + +@compute @workgroup_size( 256 ) +fn main( @builtin(global_invocation_id) global_id : vec3 ) { + + let idx = global_id.x; + + let threadPassIndex = threadPassIndices[ idx ]; + + threadPassIndices[idx] = threadPassIndices[idx] + 1; + + let j = jArray[ threadPassIndex ]; + + let k = kArray[ threadPassIndex ]; + + let ixj = idx ^ j; + + + + if (idx >= totalCount || ixj <= idx || ixj >= totalCount) { + return; + } + + if (ixj > idx) { + let ascending = (idx & k) == 0u; + + let dist_idx = compare[idx]; + let dist_ixj = compare[ixj]; + + var swap = false; + + if (ascending) { + if (dist_idx > dist_ixj) { + swap = true; + } + } else { + if (dist_idx < dist_ixj) { + swap = true; + } + } + + if (swap) { + + let tempDist = compare[idx]; + + compare[idx] = compare[ixj]; + compare[ixj] = tempDist; + + } + } + + +} diff --git a/shaders/localSort.wgsl b/shaders/localSort.wgsl new file mode 100644 index 0000000..c3cde39 --- /dev/null +++ b/shaders/localSort.wgsl @@ -0,0 +1,68 @@ +@group(0) @binding(0) +var compare: array; + +@group(0) @binding(1) +var totalCount: u32; + +var sharedData: array; + +@compute @workgroup_size(256) +fn main(@builtin(local_invocation_id) local_id : vec3, + @builtin(global_invocation_id) global_id : vec3) { + + let localIndex = local_id.x; + let globalIndex = global_id.x; + + // Load element from global memory into shared memory if in range + if (globalIndex < totalCount) { + sharedData[localIndex] = compare[globalIndex]; + } else { + sharedData[localIndex] = 0xffffffffu; // Max uint to push invalid values to the end + } + + workgroupBarrier(); + + // Bitonic sort in shared memory on 256 elements + var size = 2u; + while (size <= 256u) { + var stride = size >> 1u; + + var j = stride; + while (j > 0u) { + let ixj = localIndex ^ j; + + if (ixj > localIndex) { + let ascending = ((localIndex & size) == 0u); + + let valLocal = sharedData[localIndex]; + let valIxj = sharedData[ixj]; + + var swap = false; + if (ascending) { + if (valLocal > valIxj) { + swap = true; + } + } else { + if (valLocal < valIxj) { + swap = true; + } + } + + if (swap) { + sharedData[localIndex] = valIxj; + sharedData[ixj] = valLocal; + } + } + + workgroupBarrier(); + j = j >> 1u; + } + + size = size << 1u; + } + + // Write sorted results back to global memory + if (globalIndex < totalCount) { + compare[globalIndex] = sharedData[localIndex]; + } +} \ No newline at end of file