34 lines
1.1 KiB
WebGPU Shading Language
34 lines
1.1 KiB
WebGPU Shading Language
|
|
// Resources in bind group 0 used by the layer-1 compute entry point.

// Binding 0: flat f32 input buffer (read-only). The kernel reads it at
// sample * inputDim + i.
@group(0) @binding(0)
var<storage, read> inputImages : array<f32>;

// Binding 1: flat f32 weight buffer. The kernel reads it at
// hiddenIndex * inputDim + i.
@group(0) @binding(1)
var<storage, read_write> weightLayer1 : array<f32>;

// Binding 2: f32 bias buffer, read at hiddenIndex.
@group(0) @binding(2)
var<storage, read_write> biasLayer1 : array<f32>;

// Binding 3: f32 output buffer, written at sample * hiddenDim + hiddenIndex.
@group(0) @binding(3)
var<storage, read_write> hiddenActivations : array<f32>;

// Binding 4: problem dimensions packed into one vector:
//   x = inputDim, y = hiddenDim, z = batchSize
@group(0) @binding(4)
var<uniform> layer1Sizes : vec3<u32>;
|
||
|
|
|
||
|
|
// Layer-1 forward kernel. Each invocation handles one output element,
// identified by global_id.x and decomposed into (sample, hidden unit):
//
//   hiddenActivations[sample * hiddenDim + unit] =
//       max(biasLayer1[unit] + sum_i inputImages[sample * inputDim + i]
//                                  * weightLayer1[unit * inputDim + i], 0.0)
//
// Only the x component of the invocation id is used.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    // Unpack the dimensions from the uniform vector.
    let inputDim = layer1Sizes.x;
    let hiddenDim = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    // Invocations beyond the last output element do nothing. This also
    // covers hiddenDim == 0, so the division below never sees a zero divisor.
    let flatIndex = global_id.x;
    if (flatIndex >= batchSize * hiddenDim) {
        return;
    }

    // Split the flat output index into its sample and hidden-unit parts.
    let sample = flatIndex / hiddenDim;
    let unit = flatIndex % hiddenDim;

    // Running positions into the input row and the weight row; both advance
    // by one element per step.
    var inputCursor = sample * inputDim;
    var weightCursor = unit * inputDim;

    // Start from the bias, then accumulate products in ascending index order.
    var acc : f32 = biasLayer1[unit];
    var remaining = inputDim;
    loop {
        if (remaining == 0u) {
            break;
        }
        acc += inputImages[inputCursor] * weightLayer1[weightCursor];
        inputCursor++;
        weightCursor++;
        remaining--;
    }

    // Clamp negatives to zero before storing.
    hiddenActivations[flatIndex] = max(acc, 0.0);
}
|