First commit.
shaders/neural/softmax_cross_entropy_backward.wgsl (new file, 39 lines)
@@ -0,0 +1,39 @@
// Softmax + cross-entropy backward pass: one invocation per sample.
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let sampleIndex = id.x;
    let batchSize = layer2Sizes.z;
    let numClasses = layer2Sizes.y;

    // Dispatches are rounded up to a multiple of the workgroup size,
    // so guard against out-of-range invocations.
    if (sampleIndex >= batchSize) { return; }

    let base = sampleIndex * numClasses;
    let label = targetLabels[sampleIndex];

    // Find the max logit so the exponentials below are numerically stable.
    var maxLogit : f32 = -1e9;
    for (var c : u32 = 0u; c < numClasses; c++) {
        let v = outputLogits[base + c];
        if (v > maxLogit) { maxLogit = v; }
    }

    var sumExp : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        sumExp += exp(outputLogits[base + c] - maxLogit);
    }

    // d(loss)/d(logit_c) = softmax_c - oneHot_c; loss = -log(softmax_label).
    var loss : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
        let oneHot = f32(c == label);
        gradientLogits[base + c] = sm - oneHot;
        if (c == label) { loss = -log(sm); }
    }

    crossEntropyLoss[sampleIndex] = loss;
}
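For context, here is a minimal host-side dispatch sketch (not part of this commit). It assumes a standard WebGPU setup with ambient types from @webgpu/types; the function name runSoftmaxBackward and the buffer names are hypothetical, and the five buffers are assumed to be created elsewhere with STORAGE/UNIFORM usage flags matching the bindings above.

function runSoftmaxBackward(
  device: GPUDevice,
  shaderSource: string,
  buffers: {
    logits: GPUBuffer; // binding 0: outputLogits
    labels: GPUBuffer; // binding 1: targetLabels
    grad: GPUBuffer;   // binding 2: gradientLogits
    loss: GPUBuffer;   // binding 3: crossEntropyLoss
    sizes: GPUBuffer;  // binding 4: layer2Sizes (a vec3<u32> uniform occupies 16 bytes)
  },
  batchSize: number,
): void {
  const module = device.createShaderModule({ code: shaderSource });
  const pipeline = device.createComputePipeline({
    layout: "auto",
    compute: { module, entryPoint: "main" },
  });
  const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: buffers.logits } },
      { binding: 1, resource: { buffer: buffers.labels } },
      { binding: 2, resource: { buffer: buffers.grad } },
      { binding: 3, resource: { buffer: buffers.loss } },
      { binding: 4, resource: { buffer: buffers.sizes } },
    ],
  });
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // One invocation per sample at workgroup size 64; this rounding-up is
  // why the shader guards with sampleIndex >= batchSize.
  pass.dispatchWorkgroups(Math.ceil(batchSize / 64));
  pass.end();
  device.queue.submit([encoder.finish()]);
}

After the submit completes, the per-sample losses sit in the loss buffer and the per-logit gradients in the grad buffer; averaging the loss over the batch, and scaling the gradient by 1/batchSize if the training loop uses a mean loss, is left to the caller.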