// WebGPU Shading Language (WGSL) compute shader.
// Resource bindings for this compute pass; everything lives in bind group 0.

// Input activations. main() reads element (sample, k) at
// sample * hiddenDim + k, i.e. a row-major [batchSize][hiddenDim] layout.
@binding(0) @group(0)
var<storage, read> hiddenActivations : array<f32>;

// Weight matrix. main() reads element (class, k) at class * hiddenDim + k,
// i.e. a row-major [numClasses][hiddenDim] layout.
// NOTE(review): main() only reads this buffer, yet it is declared read_write.
// Switching to `read` would change the bind-group layout binding type
// (storage -> read-only-storage), so it is left as-is; confirm with host code.
@binding(1) @group(0)
var<storage, read_write> weightLayer2 : array<f32>;

// One bias value per class, indexed by class.
// NOTE(review): also only read by main() despite the read_write access mode.
@binding(2) @group(0)
var<storage, read_write> biasLayer2 : array<f32>;

// Result buffer. main() writes element (sample, class) at
// sample * numClasses + class, i.e. a row-major [batchSize][numClasses] layout.
@binding(3) @group(0)
var<storage, read_write> outputLogits : array<f32>;

// Layer dimensions: x = hiddenDim, y = numClasses, z = batchSize.
@binding(4) @group(0)
var<uniform> layer2Sizes : vec3<u32>;

// Computes one output element per invocation.
//
// The flat invocation index id.x is split into (sample, class) with
// numClasses as the row width. The result written is
// biasLayer2[class] + dot(hiddenActivations row `sample`,
// weightLayer2 row `class`), with both rows of length hiddenDim.
// An invocation whose index is >= batchSize * numClasses writes nothing.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    // dims.x = hiddenDim, dims.y = numClasses, dims.z = batchSize.
    let dims = layer2Sizes;
    let flat = id.x;

    // Guard: indices beyond the output have nothing to compute. This covers
    // e.g. a dispatch rounded up to whole 64-wide workgroups. When
    // numClasses == 0 the bound is 0, so the division below is never reached.
    if (flat >= dims.z * dims.y) {
        return;
    }

    // Row-major [batchSize][numClasses] output position.
    let row = flat / dims.y;
    let col = flat % dims.y;

    // Start offsets of this sample's activation row and this class's weight row.
    let actOffset = row * dims.x;
    let wOffset = col * dims.x;

    var acc : f32 = biasLayer2[col];
    var k = 0u;
    while (k < dims.x) {
        acc += hiddenActivations[actOffset + k] * weightLayer2[wOffset + k];
        k++;
    }

    outputLogits[flat] = acc;
}