// WebGPU Shading Language (WGSL)
// gradients_dense_layer2.wgsl
//
// Backward pass for the second (output) dense layer of an MLP.
// Produces three outputs from the logit gradients:
//   (1) dW2   -> gradientWeightLayer2
//   (2) db2   -> gradientBiasLayer2
//   (3) dHidden (gradient w.r.t. the hidden activations, consumed by
//       layer 1's backward pass) -> gradientHidden

// Forward-pass hidden-layer outputs, row-major [batch][hiddenDim].
@group(0) @binding(0)
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]

// Upstream gradient w.r.t. the logits, row-major [batch][numClasses].
@group(0) @binding(1)
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]

// Layer-2 weights. Indexed in main() as weightLayer2[c * hiddenDim + h],
// i.e. class-major [numClasses][hiddenDim].
@group(0) @binding(2)
var<storage, read> weightLayer2 : array<f32>; // [hiddenDim * numClasses]

// Output: gradient w.r.t. W2, same class-major layout as weightLayer2.
@group(0) @binding(3)
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]

// Output: gradient w.r.t. the layer-2 bias, one entry per class.
@group(0) @binding(4)
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]

// Output: gradient w.r.t. the hidden activations, same layout as
// hiddenActivations.
@group(0) @binding(5)
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]

// Problem sizes packed into one uniform (see component meanings below).
@group(0) @binding(6)
var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
// Backward pass for the output dense layer. The 1-D dispatch is partitioned
// by flat thread index:
//   [0, totalWeights)                              -> compute dW2 (and db2)
//   [totalWeights, totalWeights + batch*hiddenDim) -> compute dHidden
// Callers must dispatch at least
// ceil((hiddenDim*numClasses + batch*hiddenDim) / 64) workgroups.
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let globalIndex = id.x;

    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    let totalWeights = hiddenDim * numClasses;

    if (globalIndex < totalWeights) {
        //
        // (1) Gradients for W2 and b2.
        //
        // Class-major weight layout: flat index = classIndex * hiddenDim + hiddenIndex.
        let classIndex = globalIndex / hiddenDim;
        let hiddenIndex = globalIndex % hiddenDim;

        // dW2[c][h] = sum_b a[b][h] * dLogits[b][c]
        var sum : f32 = 0.0;
        for (var b : u32 = 0u; b < batchSize; b++) {
            let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
            let gradLogitValue = gradientLogits[b * numClasses + classIndex];
            sum += activationValue * gradLogitValue;
        }
        gradientWeightLayer2[globalIndex] = sum;

        // One thread per class (the hiddenIndex == 0 thread) reduces the bias
        // gradient so each element of gradientBiasLayer2 has a single writer:
        // db2[c] = sum_b dLogits[b][c]
        if (hiddenIndex == 0u) {
            var biasSum : f32 = 0.0;
            for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
                biasSum += gradientLogits[b2 * numClasses + classIndex];
            }
            gradientBiasLayer2[classIndex] = biasSum;
        }
    } else {
        //
        // (2) Gradient w.r.t. the hidden activations, for layer 1.
        //
        // FIX: the original computed `globalIndex - totalWeights`
        // unconditionally. For weight-gradient threads that u32 subtraction
        // wraps (WGSL integer arithmetic is modular), and correctness relied
        // on the wrapped value happening to fail the bounds check below.
        // Guarding with `else` makes the subtraction well-defined by
        // construction.
        let dhIndex = globalIndex - totalWeights;

        if (dhIndex < batchSize * hiddenDim) {
            let b = dhIndex / hiddenDim;
            let hiddenIndex = dhIndex % hiddenDim;

            // dHidden[b][h] = sum_c dLogits[b][c] * W2[c][h]
            var sum2 : f32 = 0.0;
            for (var c : u32 = 0u; c < numClasses; c++) {
                let gradLogitValue = gradientLogits[b * numClasses + c];
                let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
                sum2 += gradLogitValue * weightValue;
            }
            gradientHidden[dhIndex] = sum2;
        }
    }
}