First commit.

This commit is contained in:
2025-11-19 10:42:46 +01:00
commit 22c29a8939
33 changed files with 5985 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
// Dense layer 1 forward pass: hiddenActivations = relu(inputImages · W1 + b1).
// One invocation computes one element of the [batchSize x hiddenDim] output.
// NOTE(review): weightLayer1/biasLayer1 are only read here; they are declared
// read_write, presumably to share a bind group layout with the update pass — confirm host side.
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;
@group(0) @binding(1) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read_write> hiddenActivations : array<f32>;
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;
// x = inputDim, y = hiddenDim, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    let outIndex = global_id.x;
    let inputDim = layer1Sizes.x;
    let hiddenDim = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    // One thread per output element; excess threads exit early.
    if (outIndex >= batchSize * hiddenDim) {
        return;
    }

    // Decompose the flat output index into (sample, hidden unit).
    let sample = outIndex / hiddenDim;
    let unit = outIndex % hiddenDim;

    // Dot product of one input row with one weight row, plus the unit's bias.
    let inputRowBase = sample * inputDim;
    let weightRowBase = unit * inputDim;
    var accum : f32 = biasLayer1[unit];
    for (var k : u32 = 0u; k < inputDim; k = k + 1u) {
        accum = accum + inputImages[inputRowBase + k] * weightLayer1[weightRowBase + k];
    }

    // ReLU activation.
    hiddenActivations[outIndex] = max(accum, 0.0);
}

View File

@@ -0,0 +1,33 @@
// Dense layer 2 forward pass: outputLogits = hiddenActivations · W2 + b2.
// No activation — raw logits are fed to the softmax/loss shader.
// One invocation computes one element of the [batchSize x numClasses] output.
@group(0) @binding(0) var<storage, read> hiddenActivations : array<f32>;
@group(0) @binding(1) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(2) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(3) var<storage, read_write> outputLogits : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let outIndex = id.x;
    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    // One thread per logit; excess threads exit early.
    if (outIndex >= batchSize * numClasses) {
        return;
    }

    // Decompose the flat output index into (sample, class).
    let sample = outIndex / numClasses;
    let class = outIndex % numClasses;

    // Dot product of one activation row with one weight row, plus bias.
    let activationRowBase = sample * hiddenDim;
    let weightRowBase = class * hiddenDim;
    var accum : f32 = biasLayer2[class];
    for (var k : u32 = 0u; k < hiddenDim; k = k + 1u) {
        accum = accum + hiddenActivations[activationRowBase + k] * weightLayer2[weightRowBase + k];
    }

    outputLogits[outIndex] = accum;
}

View File

@@ -0,0 +1,85 @@
// Dense layer 1 backward pass: accumulates dW1 and db1 over the whole batch.
// Thread layout (flat dispatch):
//   [0, inputDim*hiddenDim)                    -> one weight gradient each
//   [inputDim*hiddenDim, + hiddenDim)          -> one bias gradient each
@group(0) @binding(0)
var<storage, read> inputImages : array<f32>;
// [batchSize * inputDimension]
@group(0) @binding(1)
var<storage, read> gradientHidden : array<f32>;
// [batchSize * hiddenDimension]
@group(0) @binding(2)
var<storage, read_write> gradientWeightLayer1 : array<f32>;
// [inputDimension * hiddenDimension]
@group(0) @binding(3)
var<storage, read_write> gradientBiasLayer1 : array<f32>;
// [hiddenDimension]
@group(0) @binding(4)
var<uniform> layer1Sizes : vec3<u32>;
// x = inputDimension, y = hiddenDimension, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let threadId = id.x;
    let inputDim = layer1Sizes.x;
    let hiddenDim = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    let weightCount = inputDim * hiddenDim;

    if (threadId < weightCount) {
        // Weight region: dW1[unit, feature] = sum over batch of
        // inputImages[b, feature] * gradientHidden[b, unit].
        // Flat index layout matches the forward pass: unit * inputDim + feature.
        let unit = threadId / inputDim;
        let feature = threadId % inputDim;
        var accum : f32 = 0.0;
        for (var b : u32 = 0u; b < batchSize; b = b + 1u) {
            accum = accum + inputImages[b * inputDim + feature]
                          * gradientHidden[b * hiddenDim + unit];
        }
        gradientWeightLayer1[threadId] = accum;
    } else if (threadId < weightCount + hiddenDim) {
        // Bias region: db1[unit] = sum over batch of gradientHidden[b, unit].
        let unit = threadId - weightCount;
        var accum : f32 = 0.0;
        for (var b : u32 = 0u; b < batchSize; b = b + 1u) {
            accum = accum + gradientHidden[b * hiddenDim + unit];
        }
        gradientBiasLayer1[unit] = accum;
    }
    // Threads past weightCount + hiddenDim fall through and do nothing.
}

View File

@@ -0,0 +1,81 @@
// gradients_dense_layer2.wgsl
// Dense layer 2 backward pass. Thread layout (flat dispatch):
//   [0, hiddenDim*numClasses)                     -> one dW2 element each
//     (the thread with hiddenIndex == 0 also accumulates db2 for its class)
//   [hiddenDim*numClasses, + batchSize*hiddenDim) -> one gradientHidden element each
@group(0) @binding(0)
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]
@group(0) @binding(1)
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]
@group(0) @binding(2)
var<storage, read> weightLayer2 : array<f32>; // [numClasses * hiddenDim], one row per class
@group(0) @binding(3)
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [numClasses * hiddenDim]
@group(0) @binding(4)
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]
@group(0) @binding(5)
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]
@group(0) @binding(6)
var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let globalIndex = id.x;
    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    let totalWeights = hiddenDim * numClasses;

    //
    // (1) Weight gradients dW2 (and bias gradients db2, once per class).
    //
    if (globalIndex < totalWeights) {
        // Flat index layout matches the forward pass: classIndex * hiddenDim + hiddenIndex.
        let classIndex = globalIndex / hiddenDim;
        let hiddenIndex = globalIndex % hiddenDim;

        // dW2[class, hidden] = sum over batch of a[b, hidden] * dZ[b, class].
        var sum : f32 = 0.0;
        for (var b : u32 = 0u; b < batchSize; b++) {
            sum += hiddenActivations[b * hiddenDim + hiddenIndex]
                 * gradientLogits[b * numClasses + classIndex];
        }
        gradientWeightLayer2[globalIndex] = sum;

        // Exactly one thread per class computes the bias gradient:
        // db2[class] = sum over batch of dZ[b, class].
        if (hiddenIndex == 0u) {
            var biasSum : f32 = 0.0;
            for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
                biasSum += gradientLogits[b2 * numClasses + classIndex];
            }
            gradientBiasLayer2[classIndex] = biasSum;
        }
        // Explicit return: previously this fell through into region (2) and
        // only skipped it because (globalIndex - totalWeights) wrapped to a
        // huge u32 that failed the bound check — defined but fragile.
        return;
    }

    //
    // (2) Gradient flowing back to layer 1's activations:
    // dH[b, hidden] = sum over classes of dZ[b, c] * W2[c, hidden].
    //
    let dhIndex = globalIndex - totalWeights; // no wraparound: globalIndex >= totalWeights here
    if (dhIndex < batchSize * hiddenDim) {
        let b = dhIndex / hiddenDim;
        let hiddenIndex = dhIndex % hiddenDim;
        var sum2 : f32 = 0.0;
        for (var c : u32 = 0u; c < numClasses; c++) {
            sum2 += gradientLogits[b * numClasses + c]
                  * weightLayer2[c * hiddenDim + hiddenIndex];
        }
        gradientHidden[dhIndex] = sum2;
    }
}

View File

@@ -0,0 +1,39 @@
// SGD parameter update: p -= learningRate * grad for all four parameter
// buffers, addressed through one contiguous virtual index space:
//   [ W1 | b1 | W2 | b2 ]
// Each invocation updates at most one scalar parameter.
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;
@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;
@group(0) @binding(8) var<uniform> learningRate : f32;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    // Walk a cursor through the four regions, subtracting each region's
    // length as we pass it; whichever region the cursor lands in is updated.
    var cursor = id.x;

    let w1Len = arrayLength(&weightLayer1);
    if (cursor < w1Len) {
        weightLayer1[cursor] -= gradientWeightLayer1[cursor] * learningRate;
        return;
    }
    cursor -= w1Len;

    let b1Len = arrayLength(&biasLayer1);
    if (cursor < b1Len) {
        biasLayer1[cursor] -= gradientBiasLayer1[cursor] * learningRate;
        return;
    }
    cursor -= b1Len;

    let w2Len = arrayLength(&weightLayer2);
    if (cursor < w2Len) {
        weightLayer2[cursor] -= gradientWeightLayer2[cursor] * learningRate;
        return;
    }
    cursor -= w2Len;

    if (cursor < arrayLength(&biasLayer2)) {
        biasLayer2[cursor] -= gradientBiasLayer2[cursor] * learningRate;
    }
    // Cursor past the end of b2: excess thread, nothing to do.
}

View File

@@ -0,0 +1,39 @@
// Softmax + cross-entropy, one thread per sample.
// Writes gradientLogits[b, c] = softmax(c) - onehot(label)[c] and
// crossEntropyLoss[b] = -log softmax(label), computed in log space so an
// extremely unlikely target class cannot underflow to log(0) = -inf.
// Assumes numClasses >= 1.
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let sampleIndex = id.x;
    let batchSize = layer2Sizes.z;
    let numClasses = layer2Sizes.y;
    if (sampleIndex >= batchSize) { return; }

    let base = sampleIndex * numClasses;
    let label = targetLabels[sampleIndex];

    // Max-shift for numerical stability. Seed from the first logit instead of
    // a -1e9 sentinel: the sentinel was wrong whenever ALL logits were below it.
    var maxLogit : f32 = outputLogits[base];
    for (var c : u32 = 1u; c < numClasses; c++) {
        maxLogit = max(maxLogit, outputLogits[base + c]);
    }

    var sumExp : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        sumExp += exp(outputLogits[base + c] - maxLogit);
    }

    // Gradient of cross-entropy w.r.t. logits: softmax - one-hot.
    for (var c : u32 = 0u; c < numClasses; c++) {
        let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
        gradientLogits[base + c] = sm - f32(c == label);
    }

    // Loss in log space: -log softmax(label) = log(sumExp) - (z_label - max).
    // The old form -log(exp(z - max) / sumExp) returned +inf once the
    // quotient rounded to 0. An out-of-range label leaves loss at 0, matching
    // the previous behavior where no class matched the one-hot.
    var loss : f32 = 0.0;
    if (label < numClasses) {
        loss = log(sumExp) - (outputLogits[base + label] - maxLogit);
    }
    crossEntropyLoss[sampleIndex] = loss;
}