First commit.

This commit is contained in:
2025-11-19 10:42:46 +01:00
commit 22c29a8939
33 changed files with 5985 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let sampleIndex = id.x;
let batchSize = layer2Sizes.z;
let numClasses = layer2Sizes.y;
if (sampleIndex >= batchSize) { return; }
let base = sampleIndex * numClasses;
let label = targetLabels[sampleIndex];
var maxLogit : f32 = -1e9;
for (var c : u32 = 0u; c < numClasses; c++) {
let v = outputLogits[base + c];
if (v > maxLogit) { maxLogit = v; }
}
var sumExp : f32 = 0.0;
for (var c : u32 = 0u; c < numClasses; c++) {
sumExp += exp(outputLogits[base + c] - maxLogit);
}
var loss : f32 = 0.0;
for (var c : u32 = 0u; c < numClasses; c++) {
let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
let oneHot = f32(c == label);
gradientLogits[base + c] = sm - oneHot;
if (oneHot == 1.0) { loss = -log(sm); }
}
crossEntropyLoss[sampleIndex] = loss;
}