First commit.
This commit adds one new file: shaders/neural/gradients_dense_layer2.wgsl (81 lines).
|
||||
// gradients_dense_layer2.wgsl
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read> weightLayer2 : array<f32>; // [hiddenDim * numClasses] <-- Missing before
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]
|
||||
|
||||
@group(0) @binding(5)
|
||||
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]
|
||||
|
||||
@group(0) @binding(6)
|
||||
var<uniform> layer2Sizes : vec3<u32>;
|
||||
// x = hiddenDim, y = numClasses, z = batchSize
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
|
||||
|
||||
let globalIndex = id.x;
|
||||
|
||||
let hiddenDim = layer2Sizes.x;
|
||||
let numClasses = layer2Sizes.y;
|
||||
let batchSize = layer2Sizes.z;
|
||||
|
||||
//
|
||||
// (1) Compute gradients for W2 and b2
|
||||
//
|
||||
let totalWeights = hiddenDim * numClasses;
|
||||
|
||||
if (globalIndex < totalWeights) {
|
||||
let classIndex = globalIndex / hiddenDim;
|
||||
let hiddenIndex = globalIndex % hiddenDim;
|
||||
|
||||
var sum : f32 = 0.0;
|
||||
|
||||
for (var b : u32 = 0u; b < batchSize; b++) {
|
||||
let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
|
||||
let gradLogitValue = gradientLogits[b * numClasses + classIndex];
|
||||
sum += activationValue * gradLogitValue;
|
||||
}
|
||||
|
||||
gradientWeightLayer2[globalIndex] = sum;
|
||||
|
||||
if (hiddenIndex == 0u) {
|
||||
var biasSum : f32 = 0.0;
|
||||
for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
|
||||
biasSum += gradientLogits[b2 * numClasses + classIndex];
|
||||
}
|
||||
gradientBiasLayer2[classIndex] = biasSum;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// (2) Compute gradientHidden for layer1
|
||||
//
|
||||
let dhIndex = globalIndex - totalWeights;
|
||||
|
||||
if (dhIndex < batchSize * hiddenDim) {
|
||||
let b = dhIndex / hiddenDim;
|
||||
let hiddenIndex = dhIndex % hiddenDim;
|
||||
|
||||
var sum2 : f32 = 0.0;
|
||||
|
||||
for (var c : u32 = 0u; c < numClasses; c++) {
|
||||
let gradLogitValue = gradientLogits[b * numClasses + c];
|
||||
let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
|
||||
sum2 += gradLogitValue * weightValue;
|
||||
}
|
||||
|
||||
gradientHidden[dhIndex] = sum2;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user