// WebGPU Shading Language (WGSL) compute shader (original listing: 40 lines, 1.5 KiB).
// ---------------------------------------------------------------------------
// Layer 1: parameters (read_write, updated in place by `main`) and their
// gradients (read-only). All buffers are accessed as flat f32 arrays.
// NOTE(review): the element layout of the weight buffers (row- vs column-major)
// is not visible in this shader — confirm against the forward/backward passes.
// ---------------------------------------------------------------------------
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;

// ---------------------------------------------------------------------------
// Layer 2: parameters (read_write) and their gradients (read-only), same
// conventions as layer 1.
// ---------------------------------------------------------------------------
@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;

// Step size applied as `param -= gradient * learningRate`.
@group(0) @binding(8) var<uniform> learningRate : f32;

// Invocations per workgroup along x. Also used by `main` to flatten a 2D
// dispatch into one linear index, so both uses stay in sync via this constant.
const WORKGROUP_SIZE : u32 = 64u;

// Applies one gradient-descent step, `param -= gradient * learningRate`, to
// every parameter of both layers.
//
// The four parameter buffers are treated as one virtual array laid out as
//   [weightLayer1 | biasLayer1 | weightLayer2 | biasLayer2]
// and each invocation updates at most the single element at its linear index.
// Indices past the end of biasLayer2 do nothing, so the dispatch may be
// rounded up.
//
// Dispatch: at least (sum of the four buffer lengths) invocations are needed.
// A 1D dispatch of ceil(total / WORKGROUP_SIZE) workgroups behaves exactly as
// before. When that exceeds the per-dimension workgroup limit, the host may
// split the work across y, i.e. dispatchWorkgroups(x, y). Each y row then
// continues where the previous one ended.
//
// NOTE(review): assumes each gradient buffer is at least as long as the
// parameter buffer it pairs with — TODO confirm on the host side.
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(
    @builtin(global_invocation_id) id : vec3<u32>,
    @builtin(num_workgroups) numWorkgroups : vec3<u32>,
) {
    // Row-major flatten of the (x, y) invocation grid. With a 1D dispatch
    // id.y == 0, so this reduces to id.x.
    var index = id.x + id.y * (numWorkgroups.x * WORKGROUP_SIZE);

    // Each segment test is followed by subtracting that segment's length, so
    // `index` becomes an offset into the next buffer. The u32 subtraction
    // cannot wrap: it is only reached when index >= that segment's length.
    let w1Len = arrayLength(&weightLayer1);
    if (index < w1Len) {
        weightLayer1[index] -= gradientWeightLayer1[index] * learningRate;
        return;
    }
    index -= w1Len;

    let b1Len = arrayLength(&biasLayer1);
    if (index < b1Len) {
        biasLayer1[index] -= gradientBiasLayer1[index] * learningRate;
        return;
    }
    index -= b1Len;

    let w2Len = arrayLength(&weightLayer2);
    if (index < w2Len) {
        weightLayer2[index] -= gradientWeightLayer2[index] * learningRate;
        return;
    }
    index -= w2Len;

    if (index < arrayLength(&biasLayer2)) {
        biasLayer2[index] -= gradientBiasLayer2[index] * learningRate;
    }
}