First commit.

This commit is contained in:
2025-11-19 10:42:46 +01:00
commit 22c29a8939
33 changed files with 5985 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
// gradients_dense_layer2.wgsl
@group(0) @binding(0)
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]
@group(0) @binding(1)
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]
@group(0) @binding(2)
var<storage, read> weightLayer2 : array<f32>; // [hiddenDim * numClasses] <-- Missing before
@group(0) @binding(3)
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]
@group(0) @binding(4)
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]
@group(0) @binding(5)
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]
@group(0) @binding(6)
var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
let globalIndex = id.x;
let hiddenDim = layer2Sizes.x;
let numClasses = layer2Sizes.y;
let batchSize = layer2Sizes.z;
//
// (1) Compute gradients for W2 and b2
//
let totalWeights = hiddenDim * numClasses;
if (globalIndex < totalWeights) {
let classIndex = globalIndex / hiddenDim;
let hiddenIndex = globalIndex % hiddenDim;
var sum : f32 = 0.0;
for (var b : u32 = 0u; b < batchSize; b++) {
let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
let gradLogitValue = gradientLogits[b * numClasses + classIndex];
sum += activationValue * gradLogitValue;
}
gradientWeightLayer2[globalIndex] = sum;
if (hiddenIndex == 0u) {
var biasSum : f32 = 0.0;
for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
biasSum += gradientLogits[b2 * numClasses + classIndex];
}
gradientBiasLayer2[classIndex] = biasSum;
}
}
//
// (2) Compute gradientHidden for layer1
//
let dhIndex = globalIndex - totalWeights;
if (dhIndex < batchSize * hiddenDim) {
let b = dhIndex / hiddenDim;
let hiddenIndex = dhIndex % hiddenDim;
var sum2 : f32 = 0.0;
for (var c : u32 = 0u; c < numClasses; c++) {
let gradLogitValue = gradientLogits[b * numClasses + c];
let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
sum2 += gradLogitValue * weightValue;
}
gradientHidden[dhIndex] = sum2;
}
}