// WebGPU Shading Language (WGSL) compute shader.
// ---------------------------------------------------------------------------
// Bind group 0. All storage buffers hold f32 elements; sizes are given in
// elements, with the dimensions taken from layer1Sizes.
// ---------------------------------------------------------------------------

// Layer input, [batchSize * inputDimension]; sample b starts at b * inputDimension.
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;

// Incoming gradient, [batchSize * hiddenDimension]; sample b starts at b * hiddenDimension.
@group(0) @binding(1) var<storage, read> gradientHidden : array<f32>;

// Output: weight gradient, [inputDimension * hiddenDimension].
@group(0) @binding(2) var<storage, read_write> gradientWeightLayer1 : array<f32>;

// Output: bias gradient, [hiddenDimension].
@group(0) @binding(3) var<storage, read_write> gradientBiasLayer1 : array<f32>;

// Packed sizes: x = inputDimension, y = hiddenDimension, z = batchSize.
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;
// Layer-1 backward kernel: fills gradientWeightLayer1 (dW1) and
// gradientBiasLayer1 (db1) from inputImages and gradientHidden.
//
// One invocation writes one output element. With
//   I = layer1Sizes.x (inputDimension)
//   H = layer1Sizes.y (hiddenDimension)
//   B = layer1Sizes.z (batchSize)
// the flat invocation index maps to:
//   [0, I*H)        -> dW1 element, stored at hiddenIndex * I + inputIndex,
//                      = sum_b inputImages[b*I + inputIndex] * gradientHidden[b*H + hiddenIndex]
//   [I*H, I*H + H)  -> db1[hiddenIndex] = sum_b gradientHidden[b*H + hiddenIndex]
// Invocations at or past I*H + H exit immediately, so over-dispatch is harmless.
//
// Dispatch: a 1-D dispatch behaves exactly as before (id.y == 0). A single
// dispatch dimension is capped by maxComputeWorkgroupsPerDimension (65535 by
// default, about 4.19M invocations at 64 per workgroup). Larger layers may
// therefore spread their workgroups over x and y; the flat index is rebuilt
// from both coordinates.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>,
        @builtin(num_workgroups) workgroupCount : vec3<u32>) {

    // Row-major flattening of a (possibly 2-D) dispatch.
    // 64u must match @workgroup_size above.
    let invocationsPerRow = workgroupCount.x * 64u;
    let index = id.y * invocationsPerRow + id.x;

    let inputDimension = layer1Sizes.x;
    let hiddenDimension = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    // Threads [0, totalWeightElements) produce weight gradients.
    // The next hiddenDimension threads produce bias gradients.
    let totalWeightElements = inputDimension * hiddenDimension;
    let totalElements = totalWeightElements + hiddenDimension;

    if (index >= totalElements) {
        return;
    }

    // ---- Region 1: weight gradients dW1, index in [0, totalWeightElements) ----
    // When inputDimension == 0 this region is empty, so the division below
    // is never reached with a zero divisor.
    if (index < totalWeightElements) {
        // Flat index = hiddenIndex * inputDimension + inputIndex. Consecutive
        // indices step through inputIndex first, so neighbouring invocations
        // read neighbouring inputImages entries.
        let hiddenIndex : u32 = index / inputDimension;
        let inputIndex : u32 = index % inputDimension;

        var sum : f32 = 0.0;
        for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex += 1u) {
            let inputValue = inputImages[batchIndex * inputDimension + inputIndex];
            let hiddenGradient = gradientHidden[batchIndex * hiddenDimension + hiddenIndex];
            sum += inputValue * hiddenGradient;
        }

        gradientWeightLayer1[index] = sum;
        return;
    }

    // ---- Region 2: bias gradients db1, index in [totalWeightElements, totalElements) ----
    // The early return on totalElements guarantees biasIndex < hiddenDimension,
    // so no further bounds check is needed.
    let biasIndex : u32 = index - totalWeightElements;

    var sumBias : f32 = 0.0;
    for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex += 1u) {
        sumBias += gradientHidden[batchIndex * hiddenDimension + biasIndex];
    }

    gradientBiasLayer1[biasIndex] = sumBias;
}