112 lines
3.3 KiB
JavaScript
112 lines
3.3 KiB
JavaScript
|
|
import Shader from "../../framework/WebGpu.js";
|
||
|
|
|
||
|
|
// Sample count: sizes the training arrays and is passed to the shaders as "N".
const N = 32; // training samples
|
||
|
|
// Features per sample: sizes the input matrix X and the weight vector W.
const inputSize = 3;
|
||
|
|
// Model outputs per sample.
// NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
const outputSize = 1;
|
||
|
|
// Learning rate; handed to the SGD update shader as its "lr" variable.
const lr = 0.01;
|
||
|
|
// Number of training-loop iterations run for each click on the train button.
const epochs = 50;
|
||
|
|
|
||
|
|
/**
 * Builds a noise-free synthetic data set for a linear model.
 *
 * Every sample has three features, each computed as `random() * 2 - 1`
 * (so in [-1, 1) for a standard [0, 1) source), and the label
 * `3*x1 + 2*x2 - 1*x3 + 0.5`.
 *
 * The feature count is derived from the ground-truth weight list below rather
 * than a separate literal, so the row stride and the label formula cannot
 * drift apart. It must match the model's input size used by the caller.
 *
 * @param {number} [sampleCount=N] - Number of samples to generate.
 * @param {() => number} [random=Math.random] - Source of uniform values in
 *   [0, 1); inject a deterministic function for reproducible data.
 * @returns {{X: Float32Array, y: Float32Array}} `X` holds the features
 *   row-major (three consecutive values per sample); `y` holds one label per
 *   sample.
 */
function generateTrainingData(sampleCount = N, random = Math.random) {
  // Ground-truth parameters the trainer is expected to recover.
  const trueWeights = [3, 2, -1];
  const trueBias = 0.5;
  const featureCount = trueWeights.length;

  const X = new Float32Array(sampleCount * featureCount);
  const y = new Float32Array(sampleCount);

  for (let i = 0; i < sampleCount; i++) {
    const rowStart = i * featureCount;
    let label = 0;

    // Features are drawn in order (x1, x2, x3) so a seeded source yields
    // the same layout on every run.
    for (let j = 0; j < featureCount; j++) {
      const feature = random() * 2 - 1;
      X[rowStart + j] = feature;
      label += trueWeights[j] * feature;
    }

    y[i] = label + trueBias;
  }

  return { X, y };
}
|
||
|
|
|
||
|
|
/**
 * Sets up the WebGPU training demo and wires it to the page.
 *
 * Steps:
 *   1. Acquire a GPU adapter and device (alerts and returns if unavailable).
 *   2. Generate the training data and load the four WGSL shaders.
 *   3. Upload the model parameters and share GPU buffers between shaders.
 *   4. Install the click handler on `#trainBtn` that runs `epochs` training
 *      iterations and writes progress to `#log`.
 *
 * Training itself only starts when the button is clicked.
 *
 * @returns {Promise<void>} Resolves once setup is finished. Rejects if device
 *   creation or shader setup fails.
 */
async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported!");
    return;
  }

  const adapter = await navigator.gpu.requestAdapter();
  // requestAdapter() resolves with null (it does not reject) when no suitable
  // adapter exists, so guard before dereferencing it.
  if (!adapter) {
    alert("No suitable WebGPU adapter found!");
    return;
  }
  const device = await adapter.requestDevice();

  const { X, y } = generateTrainingData();

  // --- 1: Create shaders ---
  const forward = new Shader(device);
  await forward.setup("../../shaders/neural/forward.wgsl");

  const lossBackward = new Shader(device);
  await lossBackward.setup("../../shaders/neural/loss_backward.wgsl");

  const weightGrad = new Shader(device);
  await weightGrad.setup("../../shaders/neural/weight_grad.wgsl");

  const sgdUpdate = new Shader(device);
  await sgdUpdate.setup("../../shaders/neural/sgd_update.wgsl");

  // --- 2: Initialize model parameters ---
  // Small random initial weights, zero bias.
  const W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
  const b = new Float32Array([0]);

  forward.setVariable("input", X);
  forward.setVariable("W", W);
  forward.setVariable("b", b);
  forward.setVariable("N", N);
  forward.setVariable("y", y);

  // Link GPU buffers between shaders so each pass reads the previous pass's
  // buffers directly.
  lossBackward.setBuffer("y", forward.getBuffer("y"));
  lossBackward.setBuffer("pred", forward.getBuffer("pred"));
  lossBackward.setVariable("N", N);

  weightGrad.setBuffer("input", forward.getBuffer("input"));
  weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
  weightGrad.setVariable("N", N);

  sgdUpdate.setBuffer("W", forward.getBuffer("W"));
  sgdUpdate.setBuffer("b", forward.getBuffer("b"));
  sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
  sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
  sgdUpdate.setVariable("lr", lr);

  const log = document.getElementById("log");
  const trainBtn = document.getElementById("trainBtn");

  trainBtn.onclick = async () => {
    // All runs share the same GPU buffers; block re-entry so a second click
    // cannot interleave its passes with a run that is still in flight.
    trainBtn.disabled = true;
    try {
      for (let epoch = 0; epoch < epochs; epoch++) {
        await forward.execute(N);
        await lossBackward.execute(N);
        await weightGrad.execute(N);
        await sgdUpdate.execute(inputSize);

        // Read back the "loss" buffer and average it over the sample count.
        const lossArr = await lossBackward.debugBuffer("loss");
        const loss = lossArr.reduce((sum, value) => sum + value, 0) / N;

        log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;

        // Debug learned weights every 5 epochs
        if (epoch % 5 === 0) {
          const w = await forward.debugBuffer("W");
          console.log(
            `Epoch ${epoch}: W = ${w.map(v => v.toFixed(2)).join(", ")}`
          );
        }
      }

      const learnedW = await forward.debugBuffer("W");
      const learnedB = await forward.debugBuffer("b");

      log.textContent += `\nLearned W: ${learnedW.map(v=>v.toFixed(3)).join(", ")}\n`;
      log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
    } finally {
      // Re-enable even when a pass throws, so the user can retry.
      trainBtn.disabled = false;
    }
  };
}
|
||
|
|
|
||
|
|
// Start setup when the module loads. main() is async, so attach a handler
// rather than leaving its promise floating: a rejection during setup is
// reported here.
main().catch((err) => {
  console.error("WebGPU training demo failed to initialize:", err);
});
|