// Softmax + cross-entropy for a batch of samples.
//
// One invocation handles one sample: it reads that sample's `numClasses`
// logits, writes d(loss)/d(logit) = softmax - oneHot(label) into
// `gradientLogits`, and writes the sample's cross-entropy loss into
// `crossEntropyLoss`.
//
// NOTE(review): the template arguments and address spaces below were missing
// from the flattened source and have been reconstructed from how each binding
// is used (f32 arithmetic on logits, labels compared against a u32 class
// index, sizes compared against the u32 invocation id). The address spaces
// (storage read / read_write, uniform) are an assumption -- confirm them
// against the host-side bind group layout.

// Logits, indexed as [sampleIndex * numClasses + classIndex].
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;

// Target class index for each sample.
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;

// Output: gradient of the loss w.r.t. each logit, same indexing as outputLogits.
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;

// Output: one loss value per sample.
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;

// x = hiddenDim, y = numClasses, z = batchSize
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let sampleIndex = id.x;
    let batchSize = layer2Sizes.z;
    let numClasses = layer2Sizes.y;

    // The dispatch may be rounded up to a multiple of the workgroup size;
    // surplus invocations have nothing to do.
    if (sampleIndex >= batchSize) {
        return;
    }

    // With no classes there are no logits to read or gradients to write.
    // Report a zero loss, which is what the per-class loops would produce.
    if (numClasses == 0u) {
        crossEntropyLoss[sampleIndex] = 0.0;
        return;
    }

    let base = sampleIndex * numClasses;
    let label = targetLabels[sampleIndex];

    // Pass 1: largest logit. Seeded from the first logit rather than a fixed
    // sentinel so the result is correct for any range of inputs.
    var maxLogit : f32 = outputLogits[base];
    for (var c : u32 = 1u; c < numClasses; c++) {
        maxLogit = max(maxLogit, outputLogits[base + c]);
    }

    // Pass 2: softmax denominator. Subtracting maxLogit keeps every exponent
    // <= 0, so exp() cannot overflow and the sum is at least 1.
    var sumExp : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        sumExp += exp(outputLogits[base + c] - maxLogit);
    }

    // Pass 3: gradient of cross-entropy w.r.t. each logit is softmax - oneHot.
    for (var c : u32 = 0u; c < numClasses; c++) {
        let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
        let oneHot = select(0.0, 1.0, c == label);
        gradientLogits[base + c] = sm - oneHot;
    }

    // Loss in log-sum-exp form:
    //   -log(softmax[label]) = log(sumExp) - (logit[label] - maxLogit)
    // This avoids -log(0) = +inf when the label's probability underflows.
    // A label outside [0, numClasses) matches no class, so the loss stays 0
    // and the gradient above is the plain softmax.
    var loss : f32 = 0.0;
    if (label < numClasses) {
        loss = log(sumExp) - (outputLogits[base + label] - maxLogit);
    }
    crossEntropyLoss[sampleIndex] = loss;
}