First commit.
This commit is contained in:
39
shaders/neural/sgd_optimizer_update.wgsl
Normal file
39
shaders/neural/sgd_optimizer_update.wgsl
Normal file
@@ -0,0 +1,39 @@
|
||||
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
|
||||
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
|
||||
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
|
||||
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;
|
||||
|
||||
@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
|
||||
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
|
||||
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
|
||||
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;
|
||||
|
||||
@group(0) @binding(8) var<uniform> learningRate : f32;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
|
||||
let index = id.x;
|
||||
|
||||
if (index < arrayLength(&weightLayer1)) {
|
||||
weightLayer1[index] -= gradientWeightLayer1[index] * learningRate;
|
||||
return;
|
||||
}
|
||||
|
||||
let b1Index = index - arrayLength(&weightLayer1);
|
||||
if (b1Index < arrayLength(&biasLayer1)) {
|
||||
biasLayer1[b1Index] -= gradientBiasLayer1[b1Index] * learningRate;
|
||||
return;
|
||||
}
|
||||
|
||||
let w2Index = index - arrayLength(&weightLayer1) - arrayLength(&biasLayer1);
|
||||
if (w2Index < arrayLength(&weightLayer2)) {
|
||||
weightLayer2[w2Index] -= gradientWeightLayer2[w2Index] * learningRate;
|
||||
return;
|
||||
}
|
||||
|
||||
let b2Index = w2Index - arrayLength(&weightLayer2);
|
||||
if (b2Index < arrayLength(&biasLayer2)) {
|
||||
biasLayer2[b2Index] -= gradientBiasLayer2[b2Index] * learningRate;
|
||||
return;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user