Files
WebGPU-Neural-Network-Trainer/shaders/neural/gradients_dense_layer1.wgsl

86 lines
2.6 KiB
WebGPU Shading Language
Raw Permalink Normal View History

2025-11-19 10:42:46 +01:00
// Resource bindings for the layer-1 gradient pass (all in bind group 0).

// Read-only layer inputs, flattened; main reads element
// [sample * inputDimension + inputIndex].
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;

// Read-only gradients w.r.t. the hidden layer, flattened; main reads element
// [sample * hiddenDimension + hiddenIndex].
@group(0) @binding(1) var<storage, read> gradientHidden : array<f32>;

// Output: weight gradients, inputDimension * hiddenDimension elements.
// main writes one element per invocation at its flat invocation index.
@group(0) @binding(2) var<storage, read_write> gradientWeightLayer1 : array<f32>;

// Output: bias gradients, hiddenDimension elements.
@group(0) @binding(3) var<storage, read_write> gradientBiasLayer1 : array<f32>;

// Problem sizes: x = inputDimension, y = hiddenDimension, z = batchSize.
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;
// Layer-1 gradient kernel. Every invocation produces exactly one output value:
//   * invocations [0, inputDimension * hiddenDimension) each write one weight
//     gradient, the sum over the batch of input * hidden-gradient;
//   * the next hiddenDimension invocations each write one bias gradient, the
//     sum over the batch of the hidden gradient.
// Invocations past that range do nothing, so the dispatch may be rounded up to
// a whole number of 64-wide workgroups.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let inputDimension = layer1Sizes.x;
    let hiddenDimension = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    let flatIndex = id.x;
    let weightCount = inputDimension * hiddenDimension;
    let outputCount = weightCount + hiddenDimension;

    // Surplus invocations from rounding the dispatch up to the workgroup size.
    if (flatIndex >= outputCount) {
        return;
    }

    // ---- Bias gradients: invocations weightCount .. outputCount-1 ----
    if (flatIndex >= weightCount) {
        // The guard above already ensures unit < hiddenDimension.
        let unit = flatIndex - weightCount;

        // db1[unit] = sum over the batch of gradientHidden[sample, unit]
        var biasTotal : f32 = 0.0;
        for (var sample : u32 = 0u; sample < batchSize; sample++) {
            biasTotal += gradientHidden[sample * hiddenDimension + unit];
        }
        gradientBiasLayer1[unit] = biasTotal;
        return;
    }

    // ---- Weight gradients: invocations 0 .. weightCount-1 ----
    // The flat index is split so that
    //   flatIndex == hiddenIndex * inputDimension + inputIndex,
    // i.e. the input index varies fastest.
    // NOTE(review): confirm the forward pass and the weight-update step index
    // the layer-1 weights with this same layout.
    let hiddenIndex = flatIndex / inputDimension;
    let inputIndex = flatIndex % inputDimension;

    // dW1 element = sum over the batch of
    //   inputImages[sample, inputIndex] * gradientHidden[sample, hiddenIndex]
    var weightTotal : f32 = 0.0;
    for (var sample : u32 = 0u; sample < batchSize; sample++) {
        let activation = inputImages[sample * inputDimension + inputIndex];
        let upstream = gradientHidden[sample * hiddenDimension + hiddenIndex];
        weightTotal += activation * upstream;
    }
    gradientWeightLayer1[flatIndex] = weightTotal;
}