// Layer-1 forward pass with ReLU: each invocation computes one output element,
//   hiddenActivations[s * hiddenDim + h] =
//       max(biasLayer1[h] + dot(inputImages row s, weightLayer1 row h), 0.0)
//
// NOTE(review): the original declarations had no address spaces, access modes,
// or template arguments (`var x : array;`, `vec3`), which is invalid WGSL.
// They have been restored as follows:
//   - element type f32, matching the f32 arithmetic below;
//   - `read` for buffers this shader only reads;
//   - `read_write` for the output buffer, which it writes.
// `var<uniform>` for layer1Sizes is an assumption. Confirm it matches the
// host-side bind group layout; use `var<storage, read>` if the host binds a
// storage buffer instead.

// Flattened input batch; sample s occupies [s * inputDim, (s + 1) * inputDim).
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;
// Weights: one contiguous row of inputDim values per hidden unit (row h starts at h * inputDim).
@group(0) @binding(1) var<storage, read> weightLayer1 : array<f32>;
// One bias value per hidden unit.
@group(0) @binding(2) var<storage, read> biasLayer1 : array<f32>;
// Output, laid out sample-major: element (s, h) lives at s * hiddenDim + h.
@group(0) @binding(3) var<storage, read_write> hiddenActivations : array<f32>;
// x = inputDim, y = hiddenDim, z = batchSize
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;

// Only global_id.x is used. The host must dispatch at least
// batchSize * hiddenDim invocations along x (ceil(total / 64) workgroups).
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    let index = global_id.x;

    let inputDim = layer1Sizes.x;
    let hiddenDim = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    // Surplus invocations in the last workgroup do nothing. This also covers
    // hiddenDim == 0 (totalElements == 0), so the division below never sees zero.
    let totalElements = batchSize * hiddenDim;
    if (index >= totalElements) {
        return;
    }

    // Decompose the flat output index into (sample, hidden unit).
    let sampleIndex = index / hiddenDim;
    let hiddenIndex = index % hiddenDim;

    var sum : f32 = biasLayer1[hiddenIndex];
    let baseInputIndex = sampleIndex * inputDim;
    let baseWeightIndex = hiddenIndex * inputDim;
    for (var i : u32 = 0u; i < inputDim; i++) {
        sum += inputImages[baseInputIndex + i] * weightLayer1[baseWeightIndex + i];
    }

    // ReLU activation.
    hiddenActivations[index] = max(sum, 0.0);
}