/*
 * Repository-viewer metadata captured with this listing (not part of the program):
 * Files · 2025-11-19 10:42:46 +01:00 · 112 lines · 3.3 KiB · JavaScript
 */

import Shader from "../../framework/WebGpu.js";
// --- Hyper-parameters for the toy linear-regression run ---

/** Number of synthetic training samples. */
const N = 32;
/** Features per sample; the synthetic target in generateTrainingData uses three. */
const inputSize = 3;
/** Outputs per sample (one regression target each). */
const outputSize = 1;
/** SGD step size, handed to the sgd_update shader as "lr". */
const lr = 1e-2;
/** Training passes over the whole dataset per click of the train button. */
const epochs = 50;
/**
 * Builds a synthetic linear-regression dataset with a known ground truth,
 * y = 3·x1 + 2·x2 − 1·x3 + 0.5, where every feature is drawn uniformly
 * from [-1, 1).
 *
 * Features are stored row-major in one flat array: sample i occupies
 * X[i * inputSize] .. X[i * inputSize + inputSize - 1].
 *
 * @param {number} [count=N] - number of samples to generate.
 * @param {() => number} [rng=Math.random] - source of uniform values in [0, 1);
 *   pass a deterministic generator for reproducible data.
 * @returns {{X: Float32Array, y: Float32Array}} flattened features
 *   (count × inputSize) and one target per sample.
 * @throws {Error} if inputSize does not match the number of ground-truth weights.
 */
function generateTrainingData(count = N, rng = Math.random) {
  // Ground-truth parameters the model is expected to recover.
  const trueW = [3, 2, -1];
  const trueB = 0.5;
  if (trueW.length !== inputSize) {
    throw new Error(
      `generateTrainingData expects inputSize === ${trueW.length}, got ${inputSize}`
    );
  }

  const X = new Float32Array(count * inputSize);
  const y = new Float32Array(count);
  for (let i = 0; i < count; i++) {
    // Derive the row offset from inputSize so the layout cannot drift from
    // the buffer size computed above.
    const row = i * inputSize;
    let label = 0;
    for (let j = 0; j < inputSize; j++) {
      const xj = rng() * 2 - 1;
      X[row + j] = xj;
      label += trueW[j] * xj;
    }
    // Bias added last, matching the evaluation order 3·x1 + 2·x2 − x3 + 0.5.
    y[i] = label + trueB;
  }
  return { X, y };
}
/**
 * Creates a Shader on `device` and loads its WGSL program from `url`.
 * @param {GPUDevice} device - device the shader is created on.
 * @param {string} url - path to the .wgsl source.
 * @returns {Promise<Shader>} the shader after setup() has completed.
 */
async function loadShader(device, url) {
  const shader = new Shader(device);
  await shader.setup(url);
  return shader;
}

/**
 * Joins numeric values as fixed-decimal strings separated by ", ".
 *
 * Array.from is used instead of `.map` deliberately: on a typed array (which
 * debugBuffer presumably returns — TODO confirm), TypedArray.prototype.map
 * writes the callback's strings back into a typed array, converting them to
 * numbers again and discarding the rounding.
 *
 * @param {ArrayLike<number>} values - values to format.
 * @param {number} digits - decimals passed to toFixed.
 * @returns {string}
 */
function formatValues(values, digits) {
  return Array.from(values, (v) => v.toFixed(digits)).join(", ");
}

/**
 * Entry point: acquires a WebGPU device, loads the four training shaders,
 * uploads the synthetic dataset and initial parameters, links the shaders'
 * buffers, and installs the "trainBtn" click handler. Each click runs `epochs`
 * rounds of forward → loss/backward → weight-gradient → SGD update, logging
 * the mean loss per epoch to the "log" element.
 *
 * Returns early (after an alert) if WebGPU or a GPU adapter is unavailable.
 */
async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported!");
    return;
  }
  const adapter = await navigator.gpu.requestAdapter();
  // requestAdapter() resolves to null (it does not reject) when no suitable
  // adapter exists; without this guard requestDevice() would throw a TypeError.
  if (!adapter) {
    alert("No suitable GPU adapter found!");
    return;
  }
  const device = await adapter.requestDevice();
  const { X, y } = generateTrainingData();

  // --- 1: Create shaders ---
  const forward = await loadShader(device, "../../shaders/neural/forward.wgsl");
  const lossBackward = await loadShader(device, "../../shaders/neural/loss_backward.wgsl");
  const weightGrad = await loadShader(device, "../../shaders/neural/weight_grad.wgsl");
  const sgdUpdate = await loadShader(device, "../../shaders/neural/sgd_update.wgsl");

  // --- 2: Initialize model parameters ---
  // Small random weights and a zero bias. They are uploaded once here and are
  // not re-uploaded by the click handler.
  const W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
  const b = new Float32Array([0]);
  forward.setVariable("input", X);
  forward.setVariable("W", W);
  forward.setVariable("b", b);
  forward.setVariable("N", N);
  forward.setVariable("y", y);

  // link GPU buffers between shaders: each pass is handed the buffer objects
  // created by an earlier one, presumably so all passes share the same storage.
  lossBackward.setBuffer("y", forward.getBuffer("y"));
  lossBackward.setBuffer("pred", forward.getBuffer("pred"));
  lossBackward.setVariable("N", N);
  weightGrad.setBuffer("input", forward.getBuffer("input"));
  weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
  weightGrad.setVariable("N", N);
  sgdUpdate.setBuffer("W", forward.getBuffer("W"));
  sgdUpdate.setBuffer("b", forward.getBuffer("b"));
  sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
  sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
  sgdUpdate.setVariable("lr", lr);

  const log = document.getElementById("log");
  const trainBtn = document.getElementById("trainBtn");
  let isTraining = false;

  trainBtn.onclick = async () => {
    // Refuse re-entry: a second click mid-run would interleave dispatches
    // that read and write the same W/b buffers as the run in progress.
    if (isTraining) return;
    isTraining = true;
    trainBtn.disabled = true;
    try {
      for (let epoch = 0; epoch < epochs; epoch++) {
        await forward.execute(N);
        await lossBackward.execute(N);
        await weightGrad.execute(N);
        await sgdUpdate.execute(inputSize);

        const lossArr = await lossBackward.debugBuffer("loss");
        const loss = lossArr.reduce((sum, v) => sum + v, 0) / N;
        log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;

        // Debug learned weights every 5 epochs
        if (epoch % 5 === 0) {
          const w = await forward.debugBuffer("W");
          console.log(`Epoch ${epoch}: W = ${formatValues(w, 2)}`);
        }
      }

      const learnedW = await forward.debugBuffer("W");
      const learnedB = await forward.debugBuffer("b");
      log.textContent += `\nLearned W: ${formatValues(learnedW, 3)}\n`;
      log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
    } catch (err) {
      // Report in the page as well. Otherwise the failure would surface only
      // as an unhandled rejection from the click handler.
      console.error(err);
      log.textContent += `\nTraining failed: ${err?.message ?? err}\n`;
    } finally {
      isTraining = false;
      trainBtn.disabled = false;
    }
  };
}
// main() is async: attach a handler so a failing setup step (e.g.
// requestDevice() rejecting or a shader setup throwing) is reported instead
// of surfacing only as an unhandled promise rejection.
main().catch((err) => {
  console.error("Failed to initialise WebGPU training demo:", err);
});