// WebGPU-Neural-Network-Trainer/shaders/neural/sgd_optimizer_update.wgsl
// Language: WebGPU Shading Language (WGSL) — 40 lines, 1.5 KiB
// Last modified: 2025-11-19 10:42:46 +01:00
// Resources for the SGD update performed by `main` (bind group 0, bindings 0-8).
// The host-side bind group must provide a buffer for every binding slot below.
// Parameter buffers are read_write because `main` updates them in place.
// Gradient buffers are read-only. Each one is indexed with the same index as its
// matching parameter buffer, so it must be at least as long as that buffer.
// NOTE(review): nothing in the shader checks that gradient and parameter lengths
// actually match — confirm the host allocates them with identical sizes.

// Layer-1 weights (updated in place).
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
// Layer-1 biases (updated in place).
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
// Values subtracted (scaled by learningRate) from weightLayer1.
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
// Values subtracted (scaled by learningRate) from biasLayer1.
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;
// Layer-2 weights (updated in place).
@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
// Layer-2 biases (updated in place).
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
// Values subtracted (scaled by learningRate) from weightLayer2.
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
// Values subtracted (scaled by learningRate) from biasLayer2.
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;
// Scalar step size; every gradient element is multiplied by it before subtraction.
@group(0) @binding(8) var<uniform> learningRate : f32;
// Plain SGD step over both layers: every parameter p becomes
//   p - grad(p) * learningRate
// and is written back in place.
//
// The four parameter buffers are addressed as one flat index space, laid out as
//   [ weightLayer1 | biasLayer1 | weightLayer2 | biasLayer2 ].
// Each invocation handles exactly one element of that space, selected by
// global_invocation_id.x. Only the x dimension of the dispatch is used.
// Segment sizes come from arrayLength(), i.e. from the sizes of the bound buffers.
// Invocations whose index lies past the end of all four segments write nothing
// (for example, the unused tail of the last 64-wide workgroup).
// NOTE(review): presumably the host dispatches ceil(totalParams / 64) workgroups
// along x — confirm against the caller.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let lr = learningRate;

  // Offset of this invocation relative to the start of the segment being tested.
  // Each segment is subtracted only after the bounds check fails, so the
  // unsigned subtraction never wraps.
  var cursor = gid.x;

  // Segment 1: layer-1 weights.
  let w1Count = arrayLength(&weightLayer1);
  if (cursor < w1Count) {
    weightLayer1[cursor] = weightLayer1[cursor] - gradientWeightLayer1[cursor] * lr;
    return;
  }
  cursor = cursor - w1Count;

  // Segment 2: layer-1 biases.
  let b1Count = arrayLength(&biasLayer1);
  if (cursor < b1Count) {
    biasLayer1[cursor] = biasLayer1[cursor] - gradientBiasLayer1[cursor] * lr;
    return;
  }
  cursor = cursor - b1Count;

  // Segment 3: layer-2 weights.
  let w2Count = arrayLength(&weightLayer2);
  if (cursor < w2Count) {
    weightLayer2[cursor] = weightLayer2[cursor] - gradientWeightLayer2[cursor] * lr;
    return;
  }
  cursor = cursor - w2Count;

  // Segment 4: layer-2 biases. Anything beyond this segment is out of range and ignored.
  if (cursor < arrayLength(&biasLayer2)) {
    biasLayer2[cursor] = biasLayer2[cursor] - gradientBiasLayer2[cursor] * lr;
  }
}