// Parameter-update compute shader.
//
// Each invocation updates at most one element of one of the four parameter
// buffers by subtracting `gradient * learningRate` from it. The flat
// invocation index `id.x` is mapped onto the buffers laid end to end in the
// order: weightLayer1, biasLayer1, weightLayer2, biasLayer2.
//
// The dispatch must cover at least
//   len(weightLayer1) + len(biasLayer1) + len(weightLayer2) + len(biasLayer2)
// invocations along x for every element to be updated; invocations past that
// total fall through all four range checks and do nothing.
//
// NOTE(review): the address-space / access-mode and element-type template
// arguments below were missing from the source as received (likely stripped
// as markup). They are reconstructed from usage:
//   - `arrayLength` requires runtime-sized arrays in the `storage` space;
//   - weights and biases are written, so they must be `read_write`;
//   - gradients are only read here, so `read` is sufficient -- switch to
//     `read_write` if the host bind-group layout declares them as "storage"
//     rather than "read-only-storage";
//   - `learningRate` is assumed to be a uniform buffer -- confirm with host;
//   - elements are assumed to be `f32`, matching `learningRate`'s type.

@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;

@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;

@group(0) @binding(8) var<uniform> learningRate : f32;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    // Offset into the current buffer. After each failed range check the
    // buffer's length is subtracted; the failed check guarantees
    // `offset >= length`, so the unsigned subtraction cannot wrap.
    var offset : u32 = id.x;

    let weight1Count = arrayLength(&weightLayer1);
    if (offset < weight1Count) {
        weightLayer1[offset] -= gradientWeightLayer1[offset] * learningRate;
        return;
    }
    offset -= weight1Count;

    let bias1Count = arrayLength(&biasLayer1);
    if (offset < bias1Count) {
        biasLayer1[offset] -= gradientBiasLayer1[offset] * learningRate;
        return;
    }
    offset -= bias1Count;

    let weight2Count = arrayLength(&weightLayer2);
    if (offset < weight2Count) {
        weightLayer2[offset] -= gradientWeightLayer2[offset] * learningRate;
        return;
    }
    offset -= weight2Count;

    if (offset < arrayLength(&biasLayer2)) {
        biasLayer2[offset] -= gradientBiasLayer2[offset] * learningRate;
    }
    // Otherwise this invocation lies beyond the last parameter: nothing to do.
}