Initial commit

This commit is contained in:
2025-11-17 15:06:39 +01:00
parent 2015516eca
commit 4d2432a1c2
91 changed files with 5074894 additions and 2 deletions

170
Demos/sort/demo.js Normal file
View File

@@ -0,0 +1,170 @@
import Shader from "../../framework/WebGpu.js"
import Measure from "../../framework/Measure.js"
const adapter = await navigator.gpu.requestAdapter();
if (!adapter ) {
throw new Error("Failed to get GPU adapter");
}
const particleCount = 131072;
const workgroupSize = 256;
const numWorkgroups = Math.ceil( particleCount / workgroupSize );
const device = await adapter.requestDevice();
var passIndex = 0;
var jArray = new Array();
var kArray = new Array();
for ( let k = 512; k <= particleCount; k <<= 1 ) {
for ( let j = k >> 1; j > 0; j >>= 1 ) {
jArray.push( j );
kArray.push( k );
}
}
const indices = new Uint32Array( particleCount );
const threadPassIndices = new Uint32Array( particleCount );
for (var i = 0; i < particleCount; i++) {
indices[particleCount-i-1] = i;
threadPassIndices[0] = 0;
}
const localSortShader = new Shader( device, "./shaders/localSort.wgsl");
await localSortShader.addStage("main", GPUShaderStage.COMPUTE );
localSortShader.setVariable( "compare", indices );
localSortShader.setVariable( "totalCount", particleCount );
const bitonicSortGridHashShader = new Shader( device, "./shaders/bitonicSortUIntMultiPass.wgsl");
await bitonicSortGridHashShader.addStage("main", GPUShaderStage.COMPUTE );
bitonicSortGridHashShader.setVariable( "totalCount", particleCount );
bitonicSortGridHashShader.setBuffer( "compare", localSortShader.getBuffer("compare") );
bitonicSortGridHashShader.setVariable( "jArray", jArray );
bitonicSortGridHashShader.setVariable( "kArray", kArray );
const measure = new Measure();
measure.writeToPage = true;
measure.element = document.querySelector(".result");
for (var i = 3; i < 59; i++) {
bitonicSortGridHashShader.setVariable("threadPassIndices", threadPassIndices);
measure.start( "sort"+i );
//await localSortShader.execute( numWorkgroups );
const commandEncoder = device.createCommandEncoder();
// === Local sort ===
{
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline( localSortShader.pipeline );
for ( const [ index, bindGroup ] of localSortShader.bindGroups.entries() ) {
passEncoder.setBindGroup( index, bindGroup );
}
passEncoder.dispatchWorkgroups( numWorkgroups );
passEncoder.end();
}
// === Bitonic global merge ===
{
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline( bitonicSortGridHashShader.pipeline );
for ( const [ index, bindGroup ] of bitonicSortGridHashShader.bindGroups.entries() ) {
passEncoder.setBindGroup( index, bindGroup );
}
var numberOfPasses = 0;
for ( let k = 512; k <= particleCount; k <<= 1 ) {
for ( let j = k >> 1; j > 0; j >>= 1 ) {
numberOfPasses++;
passEncoder.dispatchWorkgroups( numWorkgroups );
}
}
passEncoder.end();
}
const commandBuffer = commandEncoder.finish();
device.queue.submit( [ commandBuffer ] );
await device.queue.onSubmittedWorkDone();
measure.end( "sort"+i );
}
var result = await bitonicSortGridHashShader.debugBuffer("compare");
await bitonicSortGridHashShader.debugBuffer("threadPassIndices");
var error = false;
for (var i = 0; i < particleCount; i++ ) {
if( result[i] != i ) {
error = true;
}
}
console.log("There is error?", error, particleCount);

206
Demos/sort/index.html Normal file
View File

@@ -0,0 +1,206 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" href="data:;base64,iVBORw0KGgo=" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>WebGPU Attention Example</title>
<style>
html, body {
margin: 0;
padding: 0;
height: 100%;
}
.result{
color: #d9d9d9;
margin:0 auto;
width: 400px;
padding: 100px;
text-align: center;
}
button {
top: 10px;
left: 10px;
z-index: 10;
padding: 10px 20px;
font-size: 16px;
font-weight: 600;
color: #f0f0f5;
background: rgba(28, 28, 30, 0.8);
border: none;
border-radius: 12px;
box-shadow:
0 1px 3px rgba(0, 0, 0, 0.4),
0 0 8px rgba(28, 28, 30, 0.7);
backdrop-filter: blur(12px);
-webkit-backdrop-filter: blur(12px);
cursor: pointer;
transition: background-color 0.3s ease, box-shadow 0.3s ease;
}
button:hover {
background: rgba(28, 28, 30, 0.95);
box-shadow:
0 4px 12px rgba(0, 0, 0, 0.6),
0 0 15px rgba(28, 28, 30, 0.9);
}
button:active {
background: rgba(28, 28, 30, 1);
box-shadow:
0 2px 6px rgba(0, 0, 0, 0.8),
0 0 10px rgba(28, 28, 30, 1);
}
#controlPanel {
position: absolute;
top: 10px;
left: 10px;
display: flex;
flex-direction: column;
gap: 12px;
padding: 20px;
background: rgba(30, 30, 30, 0.6);
backdrop-filter: blur(18px);
-webkit-backdrop-filter: blur(18px);
border-radius: 14px;
box-shadow:
0 4px 12px rgba(0, 0, 0, 0.5),
0 0 0 1px rgba(255, 255, 255, 0.05);
z-index: 1000;
box-sizing: border-box;
}
.inputRow {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
width: 100%;
}
.inputRow label {
color: #ccc;
font-size: 14px;
white-space: nowrap;
width: 70px;
}
.inputRow button,
.inputRow input {
flex-grow: 1;
font-size: 14px;
font-weight: 600;
color: #f0f0f5;
background: rgba(28, 28, 30, 0.9);
border: none;
border-radius: 8px;
padding: 8px 10px;
cursor: pointer;
box-shadow:
0 1px 3px rgba(0, 0, 0, 0.4),
0 0 6px rgba(28, 28, 30, 0.5);
transition: background-color 0.2s ease, box-shadow 0.2s ease;
}
.inputRow button:hover {
background: rgba(40, 40, 44, 0.95);
}
.inputRow input {
background: rgba(20, 20, 20, 0.8);
border: 1px solid rgba(255, 255, 255, 0.1);
outline: none;
box-shadow: inset 0 1px 2px rgba(0, 0, 0, 0.6);
}
.inputRow select {
flex-grow: 1;
font-size: 14px;
font-weight: 600;
color: #f0f0f5;
background: rgba(28, 28, 30, 0.9);
border: none;
border-radius: 8px;
padding: 8px 10px;
cursor: pointer;
box-shadow:
0 1px 3px rgba(0, 0, 0, 0.4),
0 0 6px rgba(28, 28, 30, 0.5);
transition: background-color 0.2s ease, box-shadow 0.2s ease;
appearance: none; /* Remove default arrow */
-webkit-appearance: none;
-moz-appearance: none;
background-image:
linear-gradient(45deg, transparent 50%, #f0f0f5 50%),
linear-gradient(135deg, #f0f0f5 50%, transparent 50%);
background-position:
calc(100% - 20px) calc(50% - 3px),
calc(100% - 15px) calc(50% - 3px);
background-size: 5px 5px;
background-repeat: no-repeat;
}
.inputRow select:hover {
background: rgba(40, 40, 44, 0.95);
}
.inputRow option {
background: rgba(28, 28, 30, 0.95);
color: #f0f0f5;
}
.top-panel{
justify-content: flex-start;
border-bottom: 1px solid #292929;
display: flex;
gap: 2.5rem;
padding: 1.75rem 2rem;
background: rgb(20 20 20 / 95%);
backdrop-filter: blur(35px);
-webkit-backdrop-filter: blur(35px);
color: white;
box-shadow: 0 4px 30px rgba(0, 0, 0, 0.7);
font-weight: 600;
font-size: 0.95rem;
user-select: none;
width: calc(100vw);
z-index: 10;
}
.top-panel button{
cursor:pointer;
}
body {
background: #111111;
font-family: "Inter", sans-serif;
}
</style>
<base href="sort/" />
</head>
<body>
<div class="top-panel"></div>
<div class="result">
<p>LocalSort -> Bitonic Sort Multi step</p>
<p> 131 072 items under 10 ms in WebGpu</p>
<p>Open Your console to see the results.</p>
</div>
<canvas width="1000" height="1000"></canvas>
<script type="module" src="demo.js"></script>
</body>
</html>

View File

@@ -0,0 +1,64 @@
@group(0) @binding(0)
var<storage, read_write> compare: array<f32>;
@group(0) @binding(1)
var<storage, read_write> indices: array<u32>;
@group(0) @binding(2)
var<uniform> k: u32; // current stage size (power of two)
@group(0) @binding(3)
var<uniform> j: u32; // current subsequence size
@group(0) @binding(4)
var<uniform> totalCount: u32;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
let ixj = idx ^ j;
if (idx >= totalCount || ixj >= totalCount) {
return;
}
if (ixj > idx && ixj < totalCount) {
let ascending = (idx & k) == 0u;
let dist_idx = compare[idx];
let dist_ixj = compare[ixj];
var swap = false;
if (ascending) {
if (dist_idx < dist_ixj) {
swap = true;
}
} else {
if (dist_idx > dist_ixj) {
swap = true;
}
}
if (swap) {
let tempDist = compare[idx];
let tempDist2 = compare[ixj];
let tempIndex = indices[idx];
let tempIndex2 = indices[ixj];
compare[idx] = tempDist2;
compare[ixj] = tempDist;
indices[idx] = tempIndex2;
indices[ixj] = tempIndex;
}
}
}

View File

@@ -0,0 +1,53 @@
@group(0) @binding(0)
var<storage, read_write> compare: array<u32>;
@group(0) @binding(1)
var<storage, read_write> indices: array<u32>;
@group(0) @binding(2)
var<uniform> k: u32;
@group(0) @binding(3)
var<uniform> j: u32;
@group(0) @binding(4)
var<uniform> totalCount: u32;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
let ixj = idx ^ j;
if (idx >= totalCount || ixj <= idx || ixj >= totalCount) {
return;
}
if (ixj > idx) {
let ascending = (idx & k) == 0u;
let dist_idx = compare[idx];
let dist_ixj = compare[ixj];
var swap = false;
if (ascending) {
if (dist_idx > dist_ixj) {
swap = true;
}
} else {
if (dist_idx < dist_ixj) {
swap = true;
}
}
if (swap) {
let tempDist = compare[idx];
compare[idx] = compare[ixj];
compare[ixj] = tempDist;
}
}
}

View File

@@ -0,0 +1,68 @@
@group(0) @binding(0)
var<storage, read_write> compare: array<u32>;
@group(0) @binding(1)
var<storage, read_write> threadPassIndices: array<u32>;
@group(0) @binding(2)
var<storage, read_write> kArray: array<u32>;
@group(0) @binding(3)
var<storage, read_write> jArray: array<u32>;
@group(0) @binding(4)
var<uniform> totalCount: u32;
@compute @workgroup_size( 256 )
fn main( @builtin(global_invocation_id) global_id : vec3<u32> ) {
let idx = global_id.x;
let threadPassIndex = threadPassIndices[ idx ];
threadPassIndices[idx] = threadPassIndices[idx] + 1;
let j = jArray[ threadPassIndex ];
let k = kArray[ threadPassIndex ];
let ixj = idx ^ j;
if (idx >= totalCount || ixj <= idx || ixj >= totalCount) {
return;
}
if (ixj > idx) {
let ascending = (idx & k) == 0u;
let dist_idx = compare[idx];
let dist_ixj = compare[ixj];
var swap = false;
if (ascending) {
if (dist_idx > dist_ixj) {
swap = true;
}
} else {
if (dist_idx < dist_ixj) {
swap = true;
}
}
if (swap) {
let tempDist = compare[idx];
compare[idx] = compare[ixj];
compare[ixj] = tempDist;
}
}
}

View File

@@ -0,0 +1,68 @@
@group(0) @binding(0)
var<storage, read_write> compare: array<u32>;
@group(0) @binding(1)
var<uniform> totalCount: u32;
var<workgroup> sharedData: array<u32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(global_invocation_id) global_id : vec3<u32>) {
let localIndex = local_id.x;
let globalIndex = global_id.x;
// Load element from global memory into shared memory if in range
if (globalIndex < totalCount) {
sharedData[localIndex] = compare[globalIndex];
} else {
sharedData[localIndex] = 0xffffffffu; // Max uint to push invalid values to the end
}
workgroupBarrier();
// Bitonic sort in shared memory on 256 elements
var size = 2u;
while (size <= 256u) {
var stride = size >> 1u;
var j = stride;
while (j > 0u) {
let ixj = localIndex ^ j;
if (ixj > localIndex) {
let ascending = ((localIndex & size) == 0u);
let valLocal = sharedData[localIndex];
let valIxj = sharedData[ixj];
var swap = false;
if (ascending) {
if (valLocal > valIxj) {
swap = true;
}
} else {
if (valLocal < valIxj) {
swap = true;
}
}
if (swap) {
sharedData[localIndex] = valIxj;
sharedData[ixj] = valLocal;
}
}
workgroupBarrier();
j = j >> 1u;
}
size = size << 1u;
}
// Write sorted results back to global memory
if (globalIndex < totalCount) {
compare[globalIndex] = sharedData[localIndex];
}
}