// gradients_dense_layer2.wgsl
//
// Backward pass of the second dense layer. A single 1-D dispatch covers two
// independent jobs, selected by the global invocation index:
//
//   [0, hiddenDim * numClasses)
//       one invocation per weight: accumulates the weight gradient over the
//       batch; the invocation with hiddenIndex == 0 additionally accumulates
//       the bias gradient of its class.
//   [hiddenDim * numClasses, hiddenDim * numClasses + batchSize * hiddenDim)
//       one invocation per (sample, hidden unit): back-propagates the logit
//       gradients through weightLayer2 into gradientHidden.
//
// The host must therefore dispatch at least
//   ceil((hiddenDim * numClasses + batchSize * hiddenDim) / 64)
// workgroups along x. Invocations beyond that range do nothing.
//
// Weights are indexed as classIndex * hiddenDim + hiddenIndex; per-sample
// buffers are indexed as sample * rowWidth + column.
//
// NOTE(review): the address spaces / access modes and element types below were
// reconstructed from how each binding is used (the declarations had lost their
// `<...>` parts). Confirm them -- in particular `uniform` vs `storage` for
// layer2Sizes -- against the host-side bind group layout.

@group(0) @binding(0) var<storage, read> hiddenActivations : array<f32>;          // [batch * hiddenDim]
@group(0) @binding(1) var<storage, read> gradientLogits : array<f32>;             // [batch * numClasses]
// Layer-2 weights; required to back-propagate into the hidden layer (phase 2).
@group(0) @binding(2) var<storage, read> weightLayer2 : array<f32>;               // [hiddenDim * numClasses]
@group(0) @binding(3) var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]
@group(0) @binding(4) var<storage, read_write> gradientBiasLayer2 : array<f32>;   // [numClasses]
@group(0) @binding(5) var<storage, read_write> gradientHidden : array<f32>;       // [batch * hiddenDim]
@group(0) @binding(6) var<uniform> layer2Sizes : vec3<u32>;                       // x = hiddenDim, y = numClasses, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let globalIndex = id.x;
    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    let totalWeights = hiddenDim * numClasses;

    if (globalIndex < totalWeights) {
        //
        // (1) Gradients for W2 and b2
        //
        let classIndex = globalIndex / hiddenDim;
        let hiddenIndex = globalIndex % hiddenDim;

        // dW2[class, hidden] = sum over batch of activation * dLogit
        var sum : f32 = 0.0;
        for (var b : u32 = 0u; b < batchSize; b++) {
            let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
            let gradLogitValue = gradientLogits[b * numClasses + classIndex];
            sum += activationValue * gradLogitValue;
        }
        gradientWeightLayer2[globalIndex] = sum;

        // Exactly one invocation per class (the one with hiddenIndex == 0)
        // writes the bias gradient, so no two invocations write the same slot.
        if (hiddenIndex == 0u) {
            var biasSum : f32 = 0.0;
            for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
                biasSum += gradientLogits[b2 * numClasses + classIndex];
            }
            gradientBiasLayer2[classIndex] = biasSum;
        }
    } else {
        //
        // (2) gradientHidden for layer 1
        //
        // The subtraction lives in the else-branch so it can never wrap
        // around for invocations that belong to phase (1).
        let dhIndex = globalIndex - totalWeights;
        if (dhIndex < batchSize * hiddenDim) {
            let sampleIndex = dhIndex / hiddenDim;
            let hiddenIndex = dhIndex % hiddenDim;

            // dHidden[sample, hidden] = sum over classes of dLogit * W2.
            // No activation derivative is applied here.
            var sum2 : f32 = 0.0;
            for (var c : u32 = 0u; c < numClasses; c++) {
                let gradLogitValue = gradientLogits[sampleIndex * numClasses + c];
                let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
                sum2 += gradLogitValue * weightValue;
            }
            gradientHidden[dhIndex] = sum2;
        }
    }
}