// Second (output) dense layer of an MLP: logits = hidden * W^T + b.
//
// Buffer layouts, as implied by the indexing in `main`:
//   hiddenActivations : [batchSize, hiddenDim]   row-major, f32
//   weightLayer2      : [numClasses, hiddenDim]  row-major, f32 (one row per class)
//   biasLayer2        : [numClasses]             f32
//   outputLogits      : [batchSize, numClasses]  row-major, f32
//
// NOTE(review): the address spaces / access modes below were reconstructed.
// Runtime-sized arrays must live in `storage`, and `outputLogits` must be
// `read_write` because it is written. `read` for the three inputs and
// `uniform` for `layer2Sizes` are assumptions -- confirm that they match the
// host-side bind group layout ("read-only-storage" / "storage" / "uniform").
@group(0) @binding(0) var<storage, read> hiddenActivations : array<f32>;
@group(0) @binding(1) var<storage, read> weightLayer2 : array<f32>;
@group(0) @binding(2) var<storage, read> biasLayer2 : array<f32>;
@group(0) @binding(3) var<storage, read_write> outputLogits : array<f32>;

// x = hiddenDim, y = numClasses, z = batchSize
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;

// One invocation computes one logit, i.e. one (sample, class) pair.
// Only `id.x` is used, so the host is expected to dispatch along X with at
// least ceil(batchSize * numClasses / 64) workgroups.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let index = id.x;

    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    // Invocations in the padded tail of the last workgroup have nothing to do.
    // This also covers numClasses == 0 (total == 0), so the division and
    // modulo below never see a zero divisor.
    let total = batchSize * numClasses;
    if (index >= total) {
        return;
    }

    // Decompose the flat output index into (sample, class).
    let sampleIndex = index / numClasses;
    let classIndex = index % numClasses;

    // Start from the bias, then accumulate the dot product between this
    // sample's hidden vector and this class's weight row.
    var sum : f32 = biasLayer2[classIndex];

    let baseHiddenIndex = sampleIndex * hiddenDim;
    let baseWeightIndex = classIndex * hiddenDim;

    for (var i : u32 = 0u; i < hiddenDim; i++) {
        sum += hiddenActivations[baseHiddenIndex + i] * weightLayer2[baseWeightIndex + i];
    }

    outputLogits[index] = sum;
}