First commit.
This commit is contained in:
85
shaders/neural/gradients_dense_layer1.wgsl
Normal file
85
shaders/neural/gradients_dense_layer1.wgsl
Normal file
@@ -0,0 +1,85 @@
|
||||
// Resource bindings for the layer-1 gradient pass (bind group 0).
// All buffers are flat f32 arrays; the layouts below are the ones `main`
// actually indexes with.

// Input batch, row-major [batchSize][inputDimension]:
//   element (b, i) lives at b * inputDimension + i.
@group(0) @binding(0)
var<storage, read> inputImages : array<f32>;

// Gradient w.r.t. the hidden layer, row-major [batchSize][hiddenDimension]:
//   element (b, h) lives at b * hiddenDimension + h.
@group(0) @binding(1)
var<storage, read> gradientHidden : array<f32>;

// Output: weight gradients, inputDimension * hiddenDimension elements.
// `main` writes element (h, i) at h * inputDimension + i, i.e. hidden-major.
// NOTE(review): confirm this matches the layout the forward pass and the
// weight-update pass use for the layer-1 weights.
@group(0) @binding(2)
var<storage, read_write> gradientWeightLayer1 : array<f32>;

// Output: bias gradients, hiddenDimension elements.
@group(0) @binding(3)
var<storage, read_write> gradientBiasLayer1 : array<f32>;

// Problem sizes: x = inputDimension, y = hiddenDimension, z = batchSize.
@group(0) @binding(4)
var<uniform> layer1Sizes : vec3<u32>;

// Computes the layer-1 weight and bias gradients for one batch.
//
// Each invocation produces exactly one output element, selected by the x
// component of its global invocation id:
//   [0, totalWeightElements)             -> one entry of gradientWeightLayer1
//   [totalWeightElements, totalElements) -> one entry of gradientBiasLayer1
// Invocations with an id at or beyond totalElements exit without writing, so
// the dispatch only has to cover at least totalElements invocations along x.
//
// Results overwrite the output buffers (no accumulation across dispatches)
// and are plain sums over the batch; no division by batchSize happens here.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let index = id.x;

    let inputDimension = layer1Sizes.x;
    let hiddenDimension = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    let totalWeightElements = inputDimension * hiddenDimension;
    let totalElements = totalWeightElements + hiddenDimension;

    // Padding invocations from the last workgroup have nothing to do.
    if (index >= totalElements) {
        return;
    }

    // ---- Region 1: weight gradients dW1 (0 .. totalWeightElements-1) ----
    if (index < totalWeightElements) {
        // Flat index decomposes as index = hiddenIndex * inputDimension + inputIndex.
        // inputDimension cannot be zero here: that would make
        // totalWeightElements zero and this branch unreachable.
        // NOTE(review): this is a hidden-major layout; confirm it matches how
        // the layer-1 weights are indexed elsewhere.
        let hiddenIndex = index / inputDimension;
        let inputIndex = index % inputDimension;

        // dW1 entry = sum over batch of
        //   inputImages[b, inputIndex] * gradientHidden[b, hiddenIndex]
        var sum : f32 = 0.0;
        for (var batchIndex = 0u; batchIndex < batchSize; batchIndex++) {
            let inputValue = inputImages[batchIndex * inputDimension + inputIndex];
            let hiddenGradient = gradientHidden[batchIndex * hiddenDimension + hiddenIndex];
            sum += inputValue * hiddenGradient;
        }

        gradientWeightLayer1[index] = sum;
        return;
    }

    // ---- Region 2: bias gradients db1 (totalWeightElements .. end) ----
    // The guard above ensures index < totalWeightElements + hiddenDimension,
    // so biasIndex is always within [0, hiddenDimension).
    let biasIndex = index - totalWeightElements;

    // db1[biasIndex] = sum over batch of gradientHidden[b, biasIndex]
    var sumBias : f32 = 0.0;
    for (var batchIndex = 0u; batchIndex < batchSize; batchIndex++) {
        sumBias += gradientHidden[batchIndex * hiddenDimension + biasIndex];
    }

    gradientBiasLayer1[biasIndex] = sumBias;
}
|
||||
Reference in New Issue
Block a user