// Layer-1 backward pass: accumulates the weight and bias gradients of the
// first layer by summing per-sample contributions over the whole batch.
//
// NOTE(review): the source as received was missing every template argument
// (`var<...>`, `array<f32>`, `vec3<u32>`), which looks like an extraction
// artifact. The address spaces and access modes below are reconstructed from
// how each binding is used: inputs are only read, gradients are only written,
// and the sizes are constant. Confirm they match the host-side bind-group
// layout.

// [batchSize * inputDimension], sample-major: element (b, i) is at b * inputDimension + i.
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;
// [batchSize * hiddenDimension], sample-major: element (b, h) is at b * hiddenDimension + h.
@group(0) @binding(1) var<storage, read> gradientHidden : array<f32>;
// [hiddenDimension * inputDimension], hidden-major: element (h, i) is at h * inputDimension + i.
@group(0) @binding(2) var<storage, read_write> gradientWeightLayer1 : array<f32>;
// [hiddenDimension]
@group(0) @binding(3) var<storage, read_write> gradientBiasLayer1 : array<f32>;
// x = inputDimension, y = hiddenDimension, z = batchSize
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;

// One invocation per output element, laid out over a single flat index:
//   [0, inputDimension * hiddenDimension)           -> one weight gradient each
//   [inputDimension * hiddenDimension, + hidden)    -> one bias gradient each
// Only id.x is used, so this assumes a 1-D dispatch of at least
// ceil((inputDimension * hiddenDimension + hiddenDimension) / 64) workgroups
// along x; surplus invocations return immediately.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let index : u32 = id.x;

    let inputDimension : u32 = layer1Sizes.x;
    let hiddenDimension : u32 = layer1Sizes.y;
    let batchSize : u32 = layer1Sizes.z;

    let totalWeightElements : u32 = inputDimension * hiddenDimension;
    let totalElements : u32 = totalWeightElements + hiddenDimension;

    // Padding invocations from rounding the dispatch up to the workgroup size.
    if (index >= totalElements) {
        return;
    }

    // ---- Region 1: weight gradients, index in [0, totalWeightElements) ----
    if (index < totalWeightElements) {
        // The output is hidden-major, so index = hiddenIndex * inputDimension + inputIndex.
        // Neighbouring invocations therefore read neighbouring inputImages
        // entries and share the same gradientHidden entry.
        // NOTE(review): an earlier comment called this entry
        // dW1[inputIndex, hiddenIndex]. Confirm that the forward pass and the
        // weight update index W1 with this same hidden-major layout.
        let hiddenIndex : u32 = index / inputDimension;
        let inputIndex : u32 = index % inputDimension;

        // dW1(h, i) = sum over b of inputImages(b, i) * gradientHidden(b, h)
        var sum : f32 = 0.0;
        for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
            let inputValue = inputImages[batchIndex * inputDimension + inputIndex];
            let hiddenGradient = gradientHidden[batchIndex * hiddenDimension + hiddenIndex];
            sum = sum + inputValue * hiddenGradient;
        }
        gradientWeightLayer1[index] = sum;
        return;
    }

    // ---- Region 2: bias gradients, index in [totalWeightElements, totalElements) ----
    // The bounds guard above already guarantees biasIndex < hiddenDimension here.
    let biasIndex : u32 = index - totalWeightElements;

    // db1(h) = sum over b of gradientHidden(b, h)
    var sumBias : f32 = 0.0;
    for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
        sumBias = sumBias + gradientHidden[batchIndex * hiddenDimension + biasIndex];
    }
    gradientBiasLayer1[biasIndex] = sumBias;
}