First commit.
shaders/neural/forward_dense_layer1.wgsl (new file)
@@ -0,0 +1,33 @@
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;
@group(0) @binding(1) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read_write> hiddenActivations : array<f32>;
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;
// x = inputDim, y = hiddenDim, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    let index = global_id.x;

    let inputDim = layer1Sizes.x;
    let hiddenDim = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    let totalElements = batchSize * hiddenDim;
    if (index >= totalElements) { return; }

    let sampleIndex = index / hiddenDim;
    let hiddenIndex = index % hiddenDim;

    var sum : f32 = biasLayer1[hiddenIndex];

    let baseInputIndex = sampleIndex * inputDim;
    let baseWeightIndex = hiddenIndex * inputDim;

    for (var i : u32 = 0u; i < inputDim; i++) {
        sum += inputImages[baseInputIndex + i] *
               weightLayer1[baseWeightIndex + i];
    }

    hiddenActivations[index] = max(sum, 0.0);
}
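A minimal host-side sketch of how this kernel could be dispatched. The names encoder, pipeline, and bindGroup are placeholders for objects assumed to be created elsewhere (they are not part of this commit); the only fixed requirement from the shader is one invocation per (sample, hidden unit) pair at workgroup size 64.

// Hypothetical dispatch helper for forward_dense_layer1.wgsl (workgroup size 64).
function dispatchForwardLayer1(
  encoder: GPUCommandEncoder,
  pipeline: GPUComputePipeline,
  bindGroup: GPUBindGroup,
  batchSize: number,
  hiddenDim: number,
): void {
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  // One thread per output element, rounded up to whole workgroups of 64.
  pass.dispatchWorkgroups(Math.ceil((batchSize * hiddenDim) / 64));
  pass.end();
}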
shaders/neural/forward_dense_layer2.wgsl (new file)
@@ -0,0 +1,33 @@
@group(0) @binding(0) var<storage, read> hiddenActivations : array<f32>;
@group(0) @binding(1) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(2) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(3) var<storage, read_write> outputLogits : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let index = id.x;

    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    let total = batchSize * numClasses;
    if (index >= total) { return; }

    let sampleIndex = index / numClasses;
    let classIndex = index % numClasses;

    var sum : f32 = biasLayer2[classIndex];

    let baseHiddenIndex = sampleIndex * hiddenDim;
    let baseWeightIndex = classIndex * hiddenDim;

    for (var i : u32 = 0u; i < hiddenDim; i++) {
        sum += hiddenActivations[baseHiddenIndex + i] *
               weightLayer2[baseWeightIndex + i];
    }

    outputLogits[index] = sum;
}
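Together, the two forward kernels implement a standard two-layer MLP. In equation form, using the buffer names above, for each sample b:

    h_b = \max(W_1 x_b + b_1,\, 0)
    z_b = W_2 h_b + b_2

where x_b is one row of inputImages, W_1 is stored hidden-major (row hiddenIndex has stride inputDim), W_2 is stored class-major (row classIndex has stride hiddenDim), and z_b is the corresponding row of outputLogits.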
shaders/neural/gradients_dense_layer1.wgsl (new file)
@@ -0,0 +1,85 @@
@group(0) @binding(0)
var<storage, read> inputImages : array<f32>;
// [batchSize * inputDimension]

@group(0) @binding(1)
var<storage, read> gradientHidden : array<f32>;
// [batchSize * hiddenDimension]

@group(0) @binding(2)
var<storage, read_write> gradientWeightLayer1 : array<f32>;
// [inputDimension * hiddenDimension]

@group(0) @binding(3)
var<storage, read_write> gradientBiasLayer1 : array<f32>;
// [hiddenDimension]

@group(0) @binding(4)
var<uniform> layer1Sizes : vec3<u32>;
// x = inputDimension, y = hiddenDimension, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let index = id.x;

    let inputDimension = layer1Sizes.x;
    let hiddenDimension = layer1Sizes.y;
    let batchSize = layer1Sizes.z;

    // We launch:
    //   inputDimension * hiddenDimension threads for weight gradients
    // + hiddenDimension threads for bias gradients
    let totalWeightElements = inputDimension * hiddenDimension;
    let totalBiasElements = hiddenDimension;
    let totalElements = totalWeightElements + totalBiasElements;

    if (index >= totalElements) {
        return;
    }

    // ---- Region 1: Weight gradients dW1 (0 .. totalWeightElements-1) ----
    if (index < totalWeightElements) {
        // Map the flat index into (hiddenIndex, inputIndex); the layout is
        // hidden-major, matching weightLayer1 in the forward pass.
        let hiddenIndex : u32 = index / inputDimension;
        let inputIndex : u32 = index % inputDimension;

        var sum : f32 = 0.0;

        // dW1[hiddenIndex, inputIndex] = sum over batch of
        //   inputImages[b, inputIndex] * gradientHidden[b, hiddenIndex]
        for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
            let inputValue = inputImages[batchIndex * inputDimension + inputIndex];
            let hiddenGradient = gradientHidden[batchIndex * hiddenDimension + hiddenIndex];
            sum = sum + inputValue * hiddenGradient;
        }

        gradientWeightLayer1[index] = sum;
        return;
    }

    // ---- Region 2: Bias gradients db1 (totalWeightElements .. end) ----
    let biasIndex : u32 = index - totalWeightElements;

    if (biasIndex < hiddenDimension) {
        var sumBias : f32 = 0.0;

        // db1[biasIndex] = sum over batch of gradientHidden[b, biasIndex]
        for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
            let hiddenGradient = gradientHidden[batchIndex * hiddenDimension + biasIndex];
            sumBias = sumBias + hiddenGradient;
        }

        gradientBiasLayer1[biasIndex] = sumBias;
    }
}
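Per the launch comment in the kernel, the host must dispatch inputDimension * hiddenDimension + hiddenDimension invocations so that both the weight and bias regions are covered. A minimal sketch, reusing the same hypothetical host objects as the earlier snippet:

// Hypothetical dispatch for gradients_dense_layer1.wgsl: one thread per dW1
// element plus one thread per db1 element, in a single flat index space.
function dispatchGradientsLayer1(
  encoder: GPUCommandEncoder,
  pipeline: GPUComputePipeline,
  bindGroup: GPUBindGroup,
  inputDim: number,
  hiddenDim: number,
): void {
  const totalThreads = inputDim * hiddenDim + hiddenDim;
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(totalThreads / 64));
  pass.end();
}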
shaders/neural/gradients_dense_layer2.wgsl (new file)
@@ -0,0 +1,81 @@
// gradients_dense_layer2.wgsl

@group(0) @binding(0)
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]

@group(0) @binding(1)
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]

@group(0) @binding(2)
var<storage, read> weightLayer2 : array<f32>; // [hiddenDim * numClasses]

@group(0) @binding(3)
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]

@group(0) @binding(4)
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]

@group(0) @binding(5)
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]

@group(0) @binding(6)
var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let globalIndex = id.x;

    let hiddenDim = layer2Sizes.x;
    let numClasses = layer2Sizes.y;
    let batchSize = layer2Sizes.z;

    //
    // (1) Compute gradients for W2 and b2
    //
    let totalWeights = hiddenDim * numClasses;

    if (globalIndex < totalWeights) {
        let classIndex = globalIndex / hiddenDim;
        let hiddenIndex = globalIndex % hiddenDim;

        var sum : f32 = 0.0;

        for (var b : u32 = 0u; b < batchSize; b++) {
            let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
            let gradLogitValue = gradientLogits[b * numClasses + classIndex];
            sum += activationValue * gradLogitValue;
        }

        gradientWeightLayer2[globalIndex] = sum;

        // One thread per class (hiddenIndex == 0) also accumulates the bias gradient.
        if (hiddenIndex == 0u) {
            var biasSum : f32 = 0.0;
            for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
                biasSum += gradientLogits[b2 * numClasses + classIndex];
            }
            gradientBiasLayer2[classIndex] = biasSum;
        }

        return;
    }

    //
    // (2) Compute gradientHidden for layer1
    //
    let dhIndex = globalIndex - totalWeights;

    if (dhIndex < batchSize * hiddenDim) {
        let b = dhIndex / hiddenDim;
        let hiddenIndex = dhIndex % hiddenDim;

        var sum2 : f32 = 0.0;

        for (var c : u32 = 0u; c < numClasses; c++) {
            let gradLogitValue = gradientLogits[b * numClasses + c];
            let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
            sum2 += gradLogitValue * weightValue;
        }

        // Apply the ReLU derivative so layer 1 receives the pre-activation gradient:
        // hiddenActivations stores max(pre, 0), so the derivative is 1 where the
        // activation is positive and 0 elsewhere.
        let reluMask = select(0.0, 1.0, hiddenActivations[dhIndex] > 0.0);
        gradientHidden[dhIndex] = sum2 * reluMask;
    }
}
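In equation form, writing \delta z_{b,c} for gradientLogits and h_{b,j} for hiddenActivations, the three outputs of this kernel are:

    \partial L / \partial W_2[c, j] = \sum_b \delta z_{b,c} \, h_{b,j}
    \partial L / \partial b_2[c] = \sum_b \delta z_{b,c}
    gradientHidden_{b,j} = \big( \sum_c \delta z_{b,c} \, W_2[c, j] \big) \cdot \mathbf{1}[h_{b,j} > 0]

The last factor is the ReLU derivative from the layer 1 forward pass, applied here so that gradients_dense_layer1.wgsl receives the pre-activation gradient.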
shaders/neural/sgd_optimizer_update.wgsl (new file)
@@ -0,0 +1,39 @@
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;

@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;

@group(0) @binding(8) var<uniform> learningRate : f32;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let index = id.x;

    // One thread per parameter: the flat index walks W1, then b1, then W2, then b2.
    if (index < arrayLength(&weightLayer1)) {
        weightLayer1[index] -= gradientWeightLayer1[index] * learningRate;
        return;
    }

    let b1Index = index - arrayLength(&weightLayer1);
    if (b1Index < arrayLength(&biasLayer1)) {
        biasLayer1[b1Index] -= gradientBiasLayer1[b1Index] * learningRate;
        return;
    }

    let w2Index = index - arrayLength(&weightLayer1) - arrayLength(&biasLayer1);
    if (w2Index < arrayLength(&weightLayer2)) {
        weightLayer2[w2Index] -= gradientWeightLayer2[w2Index] * learningRate;
        return;
    }

    let b2Index = w2Index - arrayLength(&weightLayer2);
    if (b2Index < arrayLength(&biasLayer2)) {
        biasLayer2[b2Index] -= gradientBiasLayer2[b2Index] * learningRate;
        return;
    }
}
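Since learningRate is a single f32 uniform and the kernel walks all four parameter buffers with one flat index, the host needs to upload the learning rate and dispatch one thread per parameter. A minimal sketch under the same assumptions as the earlier host-side snippets (all names below are placeholders, not part of this commit):

// Hypothetical SGD step: upload the learning rate, then launch one thread per parameter.
function dispatchSgdUpdate(
  device: GPUDevice,
  encoder: GPUCommandEncoder,
  pipeline: GPUComputePipeline,
  bindGroup: GPUBindGroup,
  learningRateBuffer: GPUBuffer, // uniform buffer backing binding(8)
  totalParameterCount: number,   // |W1| + |b1| + |W2| + |b2|
  learningRate: number,
): void {
  device.queue.writeBuffer(learningRateBuffer, 0, new Float32Array([learningRate]));
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(Math.ceil(totalParameterCount / 64));
  pass.end();
}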
shaders/neural/softmax_cross_entropy_backward.wgsl (new file)
@@ -0,0 +1,39 @@
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
// x = hiddenDim, y = numClasses, z = batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let sampleIndex = id.x;
    let batchSize = layer2Sizes.z;
    let numClasses = layer2Sizes.y;

    if (sampleIndex >= batchSize) { return; }

    let base = sampleIndex * numClasses;
    let label = targetLabels[sampleIndex];

    var maxLogit : f32 = -1e9;
    for (var c : u32 = 0u; c < numClasses; c++) {
        let v = outputLogits[base + c];
        if (v > maxLogit) { maxLogit = v; }
    }

    var sumExp : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        sumExp += exp(outputLogits[base + c] - maxLogit);
    }

    var loss : f32 = 0.0;
    for (var c : u32 = 0u; c < numClasses; c++) {
        let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
        let oneHot = f32(c == label);
        gradientLogits[base + c] = sm - oneHot;
        if (oneHot == 1.0) { loss = -log(sm); }
    }

    crossEntropyLoss[sampleIndex] = loss;
}
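The per-class gradient written above follows the standard softmax cross-entropy identity; the max-logit subtraction only improves numerical stability and cancels out of the result:

    m_b = \max_k z_{b,k}
    p_{b,c} = \frac{\exp(z_{b,c} - m_b)}{\sum_k \exp(z_{b,k} - m_b)}
    \partial L_b / \partial z_{b,c} = p_{b,c} - \mathbf{1}[c = y_b]
    L_b = -\log p_{b, y_b}

Note that the gradient is per sample; any 1/batchSize normalization would have to happen elsewhere, for example folded into the learning rate.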