// WebGPU Shading Language (WGSL) compute shader.
// Resource bindings for the compute entry point in this file (all in bind group 0).

// Input logits, read as a flat array at index sampleIndex * numClasses + classIndex.
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;

// Target class index for each sample, read at index sampleIndex.
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;

// Output gradient with respect to the logits; written with the same flat
// indexing as outputLogits.
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;

// Output loss, one value written per sample at index sampleIndex.
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;

// x = hiddenDim, y = numClasses, z = batchSize
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;

// Softmax cross-entropy, one invocation per sample.
//
// For sample `id.x` this reads `numClasses` logits starting at
// `id.x * numClasses`, writes `softmax(logits) - oneHot(label)` into
// gradientLogits at the same indices, and writes `-log(softmax[label])`
// into crossEntropyLoss[id.x].
//
// If the label is not a valid class index (label >= numClasses), the gradient
// is the plain softmax and the stored loss is 0.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let sampleIndex = id.x;
    let batchSize = layer2Sizes.z;
    let numClasses = layer2Sizes.y;

    // Invocations past the end of the batch have no sample to process.
    if (sampleIndex >= batchSize) { return; }

    // No classes: nothing to normalize and no gradient entries to write.
    if (numClasses == 0u) {
        crossEntropyLoss[sampleIndex] = 0.0;
        return;
    }

    let base = sampleIndex * numClasses;
    let label = targetLabels[sampleIndex];

    // Seed the running maximum from the first logit instead of a sentinel,
    // so rows whose logits are all extremely negative still normalize
    // against their true maximum.
    var maxLogit : f32 = outputLogits[base];
    for (var c : u32 = 1u; c < numClasses; c++) {
        let v = outputLogits[base + c];
        if (v > maxLogit) { maxLogit = v; }
    }

    // Subtracting the maximum keeps every exponent <= 0, so exp() cannot
    // overflow, and the maximal element contributes exp(0) = 1 to the sum.
    var sumExp : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        sumExp += exp(outputLogits[base + c] - maxLogit);
    }

    // Gradient of the loss with respect to each logit: softmax - oneHot.
    for (var c : u32 = 0u; c < numClasses; c++) {
        let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
        let oneHot = f32(c == label);
        gradientLogits[base + c] = sm - oneHot;
    }

    // -log(softmax[label]) rewritten as log(sumExp) - (logit[label] - maxLogit).
    // This form stays finite when the target-class probability underflows
    // to zero, where -log(sm) would give +inf.
    var loss : f32 = 0.0;
    if (label < numClasses) {
        loss = log(sumExp) - (outputLogits[base + label] - maxLogit);
    }

    crossEntropyLoss[sampleIndex] = loss;
}