Files
WebGPU-FFT-Ocean-Demo/shaders/fft_col.wgsl

95 lines
2.2 KiB
WebGPU Shading Language
Raw Normal View History

2025-12-31 14:22:45 +01:00
// 2*pi, used to build the twiddle-factor angles in main.
const PI2 : f32 = 6.28318530718;
// Input spectrum, split into real and imaginary planes. main indexes both as
// row * N + col.
@group(0) @binding(0) var<storage, read> inputReal : array<f32>;
@group(0) @binding(1) var<storage, read> inputImag : array<f32>;
// Result of the column transform, same row * N + col layout as the inputs.
@group(0) @binding(2) var<storage, read_write> outputReal : array<f32>;
@group(0) @binding(3) var<storage, read_write> outputImag : array<f32>;
// Receives a copy of the real part of the result (written by main).
@group(0) @binding(4) var<storage, read_write> heightField : array<f32>;
// Parameter block. main only reads element 1, which it converts to the
// transform size N. NOTE(review): meaning of the other elements is not
// visible from this shader - confirm against the host code.
@group(0) @binding(5) var<storage, read> params : array<f32>;
// Workgroup scratch holding the real / imaginary parts of the column being
// transformed; 64 entries, one per invocation of the 64-wide workgroup.
var<workgroup> wr : array<f32, 64>;
var<workgroup> wi : array<f32, 64>;
// Returns the lowest `bits` bits of `x` in reversed order.
// Bits of `x` above position `bits - 1` are discarded; with `bits == 0u`
// the result is 0u.
fn bitReverse( x : u32, bits : u32 ) -> u32 {
    var source : u32 = x;
    var reversed : u32 = 0u;
    var remaining : u32 = bits;
    loop {
        if ( remaining == 0u ) {
            break;
        }
        // Move the current lowest source bit onto the low end of the result.
        reversed = ( reversed << 1u ) | ( source & 1u );
        source = source >> 1u;

        continuing {
            remaining = remaining - 1u;
        }
    }
    return reversed;
}
// Column pass of a 64-point radix-2 FFT over an N x N grid.
//
// One workgroup handles one column (workgroup_id.x) and each of its 64
// invocations handles one row (local_invocation_id.x). The column is loaded
// into workgroup memory in bit-reversed order, transformed in place with
// log2(N) butterfly stages, scaled by 1/N and written to outputReal /
// outputImag; the real part is also copied to heightField.
//
// The shader does nothing unless params[1] converts to exactly 64, because
// the workgroup size, scratch arrays and bit count are fixed for that size.
//
// NOTE(review): the twiddle factors use a negative angle together with the
// 1/N scale; confirm this matches the convention expected by the row pass.
@compute @workgroup_size( 64 )
fn main( @builtin(local_invocation_id) lid : vec3<u32>,
         @builtin(workgroup_id) gid : vec3<u32> ) {
    let N : u32 = u32( params[ 1u ] );
    if ( N != 64u ) {
        return;
    }

    let col : u32 = gid.x;
    let row : u32 = lid.x;
    let idx : u32 = row * N + col;

    // log2(64) = 6 index bits; fetch the element from the bit-reversed row.
    let bits : u32 = 6u;
    let rev : u32 = bitReverse( row, bits );
    let srcIdx : u32 = rev * N + col;
    wr[ row ] = inputReal[ srcIdx ];
    wi[ row ] = inputImag[ srcIdx ];
    workgroupBarrier();

    // halfLen is the distance between the two inputs of a butterfly and
    // doubles every stage: 1, 2, 4, ..., N / 2.
    var halfLen : u32 = 1u;
    while ( halfLen < N ) {
        let blockLen : u32 = halfLen << 1u;
        let k : u32 = row % halfLen;
        let evenIdx : u32 = ( row / blockLen ) * blockLen + k;
        let oddIdx : u32 = evenIdx + halfLen;

        let twiddleAngle : f32 = -PI2 / f32( blockLen );
        let angle : f32 = twiddleAngle * f32( k );
        let c : f32 = cos( angle );
        let s : f32 = sin( angle );

        // Read phase: both butterfly inputs, one of which belongs to the
        // partner invocation.
        let evenRe : f32 = wr[ evenIdx ];
        let evenIm : f32 = wi[ evenIdx ];
        let oddRe : f32 = wr[ oddIdx ];
        let oddIm : f32 = wi[ oddIdx ];

        // Odd input multiplied by the twiddle factor (c + i*s).
        let tRe : f32 = c * oddRe - s * oddIm;
        let tIm : f32 = s * oddRe + c * oddIm;

        // Every invocation must finish reading this stage's inputs before any
        // invocation overwrites its slot; otherwise the partner may observe
        // the updated value instead of the previous stage's value.
        workgroupBarrier();

        // Write phase: the slot written is always this invocation's own row
        // (evenIdx == row in the upper half of a block, oddIdx == row in the
        // lower half).
        if ( row % blockLen < halfLen ) {
            wr[ evenIdx ] = evenRe + tRe;
            wi[ evenIdx ] = evenIm + tIm;
        } else {
            wr[ oddIdx ] = evenRe - tRe;
            wi[ oddIdx ] = evenIm - tIm;
        }

        // Publish this stage's results before the next stage reads them.
        workgroupBarrier();
        halfLen = blockLen;
    }

    let invScale : f32 = 1.0 / f32( N );
    let realVal : f32 = wr[ row ] * invScale;
    let imagVal : f32 = wi[ row ] * invScale;
    outputReal[ idx ] = realVal;
    outputImag[ idx ] = imagVal;
    // final height is real part
    heightField[ idx ] = realVal;
}