Initial commit
This commit is contained in:
64
Demos/sort/shaders/bitonicSort.wgsl
Normal file
64
Demos/sort/shaders/bitonicSort.wgsl
Normal file
@@ -0,0 +1,64 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> compare: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> indices: array<u32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> k: u32; // current stage size (power of two)
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> j: u32; // current subsequence size
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<uniform> totalCount: u32;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let idx = global_id.x;
|
||||
|
||||
|
||||
let ixj = idx ^ j;
|
||||
|
||||
|
||||
if (idx >= totalCount || ixj >= totalCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
if (ixj > idx && ixj < totalCount) {
|
||||
let ascending = (idx & k) == 0u;
|
||||
|
||||
let dist_idx = compare[idx];
|
||||
let dist_ixj = compare[ixj];
|
||||
|
||||
var swap = false;
|
||||
|
||||
if (ascending) {
|
||||
if (dist_idx < dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
} else {
|
||||
if (dist_idx > dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (swap) {
|
||||
|
||||
let tempDist = compare[idx];
|
||||
let tempDist2 = compare[ixj];
|
||||
|
||||
let tempIndex = indices[idx];
|
||||
let tempIndex2 = indices[ixj];
|
||||
|
||||
compare[idx] = tempDist2;
|
||||
compare[ixj] = tempDist;
|
||||
|
||||
indices[idx] = tempIndex2;
|
||||
indices[ixj] = tempIndex;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
53
Demos/sort/shaders/bitonicSortUInt.wgsl
Normal file
53
Demos/sort/shaders/bitonicSortUInt.wgsl
Normal file
@@ -0,0 +1,53 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> compare: array<u32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> indices: array<u32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> k: u32;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> j: u32;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<uniform> totalCount: u32;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let idx = global_id.x;
|
||||
|
||||
let ixj = idx ^ j;
|
||||
|
||||
if (idx >= totalCount || ixj <= idx || ixj >= totalCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (ixj > idx) {
|
||||
let ascending = (idx & k) == 0u;
|
||||
|
||||
let dist_idx = compare[idx];
|
||||
let dist_ixj = compare[ixj];
|
||||
|
||||
var swap = false;
|
||||
|
||||
if (ascending) {
|
||||
if (dist_idx > dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
} else {
|
||||
if (dist_idx < dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (swap) {
|
||||
|
||||
let tempDist = compare[idx];
|
||||
|
||||
compare[idx] = compare[ixj];
|
||||
compare[ixj] = tempDist;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
68
Demos/sort/shaders/bitonicSortUIntMultiPass.wgsl
Normal file
68
Demos/sort/shaders/bitonicSortUIntMultiPass.wgsl
Normal file
@@ -0,0 +1,68 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> compare: array<u32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> threadPassIndices: array<u32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> kArray: array<u32>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> jArray: array<u32>;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<uniform> totalCount: u32;
|
||||
|
||||
|
||||
|
||||
@compute @workgroup_size( 256 )
|
||||
fn main( @builtin(global_invocation_id) global_id : vec3<u32> ) {
|
||||
|
||||
let idx = global_id.x;
|
||||
|
||||
let threadPassIndex = threadPassIndices[ idx ];
|
||||
|
||||
threadPassIndices[idx] = threadPassIndices[idx] + 1;
|
||||
|
||||
let j = jArray[ threadPassIndex ];
|
||||
|
||||
let k = kArray[ threadPassIndex ];
|
||||
|
||||
let ixj = idx ^ j;
|
||||
|
||||
|
||||
|
||||
if (idx >= totalCount || ixj <= idx || ixj >= totalCount) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (ixj > idx) {
|
||||
let ascending = (idx & k) == 0u;
|
||||
|
||||
let dist_idx = compare[idx];
|
||||
let dist_ixj = compare[ixj];
|
||||
|
||||
var swap = false;
|
||||
|
||||
if (ascending) {
|
||||
if (dist_idx > dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
} else {
|
||||
if (dist_idx < dist_ixj) {
|
||||
swap = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (swap) {
|
||||
|
||||
let tempDist = compare[idx];
|
||||
|
||||
compare[idx] = compare[ixj];
|
||||
compare[ixj] = tempDist;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
68
Demos/sort/shaders/localSort.wgsl
Normal file
68
Demos/sort/shaders/localSort.wgsl
Normal file
@@ -0,0 +1,68 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> compare: array<u32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> totalCount: u32;
|
||||
|
||||
var<workgroup> sharedData: array<u32, 256>;
|
||||
|
||||
@compute @workgroup_size(256)
|
||||
fn main(@builtin(local_invocation_id) local_id : vec3<u32>,
|
||||
@builtin(global_invocation_id) global_id : vec3<u32>) {
|
||||
|
||||
let localIndex = local_id.x;
|
||||
let globalIndex = global_id.x;
|
||||
|
||||
// Load element from global memory into shared memory if in range
|
||||
if (globalIndex < totalCount) {
|
||||
sharedData[localIndex] = compare[globalIndex];
|
||||
} else {
|
||||
sharedData[localIndex] = 0xffffffffu; // Max uint to push invalid values to the end
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// Bitonic sort in shared memory on 256 elements
|
||||
var size = 2u;
|
||||
while (size <= 256u) {
|
||||
var stride = size >> 1u;
|
||||
|
||||
var j = stride;
|
||||
while (j > 0u) {
|
||||
let ixj = localIndex ^ j;
|
||||
|
||||
if (ixj > localIndex) {
|
||||
let ascending = ((localIndex & size) == 0u);
|
||||
|
||||
let valLocal = sharedData[localIndex];
|
||||
let valIxj = sharedData[ixj];
|
||||
|
||||
var swap = false;
|
||||
if (ascending) {
|
||||
if (valLocal > valIxj) {
|
||||
swap = true;
|
||||
}
|
||||
} else {
|
||||
if (valLocal < valIxj) {
|
||||
swap = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (swap) {
|
||||
sharedData[localIndex] = valIxj;
|
||||
sharedData[ixj] = valLocal;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
j = j >> 1u;
|
||||
}
|
||||
|
||||
size = size << 1u;
|
||||
}
|
||||
|
||||
// Write sorted results back to global memory
|
||||
if (globalIndex < totalCount) {
|
||||
compare[globalIndex] = sharedData[localIndex];
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user