First commit.
source/training.js (new file)
import Shader from "../../framework/WebGpu.js";

// Hyperparameters for a small linear model (3 inputs -> 1 output) trained with SGD.
const N = 32;          // training samples
const inputSize = 3;   // features per sample
const outputSize = 1;  // predictions per sample
const lr = 0.01;       // SGD learning rate
const epochs = 50;

// Generate N random samples of the known linear target
//   y = 3*x1 + 2*x2 - 1*x3 + 0.5
// so the learned parameters can be checked against [3, 2, -1] and 0.5.
function generateTrainingData() {
  const X = new Float32Array(N * inputSize);
  const y = new Float32Array(N);

  for (let i = 0; i < N; i++) {
    const x1 = Math.random() * 2 - 1;
    const x2 = Math.random() * 2 - 1;
    const x3 = Math.random() * 2 - 1;

    const label = 3 * x1 + 2 * x2 - 1 * x3 + 0.5;

    X[i * inputSize + 0] = x1;
    X[i * inputSize + 1] = x2;
    X[i * inputSize + 2] = x3;

    y[i] = label;
  }

  return { X, y };
}
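
// A minimal CPU sketch of one full-batch training step, showing the math the
// four WGSL kernels are assumed to implement (plain MSE loss and SGD). The
// actual shader sources under ../../shaders/neural/ are not shown in this
// commit and may differ; this helper is not called by main() and exists only
// as a debugging reference.
function cpuReferenceStep(X, y, W, b) {
  const dW = new Float32Array(inputSize);
  let db = 0;
  let loss = 0;

  for (let i = 0; i < N; i++) {
    // forward pass: pred = dot(W, x_i) + b
    let pred = b[0];
    for (let j = 0; j < inputSize; j++) pred += W[j] * X[i * inputSize + j];

    // MSE loss and its gradient w.r.t. the prediction
    const err = pred - y[i];
    loss += err * err;
    const dPred = (2 * err) / N;

    // accumulate parameter gradients
    for (let j = 0; j < inputSize; j++) dW[j] += dPred * X[i * inputSize + j];
    db += dPred;
  }

  // SGD update (in place)
  for (let j = 0; j < inputSize; j++) W[j] -= lr * dW[j];
  b[0] -= lr * db;

  return loss / N; // mean squared error for this step
}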

async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported!");
    return;
  }

  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    alert("No WebGPU adapter available!");
    return;
  }
  const device = await adapter.requestDevice();

  const { X, y } = generateTrainingData();

  // --- 1: Create shaders ---
  const forward = new Shader(device);
  await forward.setup("../../shaders/neural/forward.wgsl");

  const lossBackward = new Shader(device);
  await lossBackward.setup("../../shaders/neural/loss_backward.wgsl");

  const weightGrad = new Shader(device);
  await weightGrad.setup("../../shaders/neural/weight_grad.wgsl");

  const sgdUpdate = new Shader(device);
  await sgdUpdate.setup("../../shaders/neural/sgd_update.wgsl");
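
  // Assumed division of work across the kernels (the WGSL sources are not
  // shown here; see cpuReferenceStep above for the math they are expected to
  // implement):
  //   forward.wgsl        pred[i] = dot(W, x_i) + b
  //   loss_backward.wgsl  loss[i] = (pred[i] - y[i])^2,  dPred[i] = 2*(pred[i] - y[i]) / N
  //   weight_grad.wgsl    dW[j] = sum_i x[i][j] * dPred[i],  db = sum_i dPred[i]
  //   sgd_update.wgsl     W[j] -= lr * dW[j],  b -= lr * db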

  // --- 2: Initialize model parameters ---
  let W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
  let b = new Float32Array([0]);

  forward.setVariable("input", X);
  forward.setVariable("W", W);
  forward.setVariable("b", b);
  forward.setVariable("N", N);
  forward.setVariable("y", y);

  // link GPU buffers between shaders
  lossBackward.setBuffer("y", forward.getBuffer("y"));
  lossBackward.setBuffer("pred", forward.getBuffer("pred"));
  lossBackward.setVariable("N", N);

  weightGrad.setBuffer("input", forward.getBuffer("input"));
  weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
  weightGrad.setVariable("N", N);

  sgdUpdate.setBuffer("W", forward.getBuffer("W"));
  sgdUpdate.setBuffer("b", forward.getBuffer("b"));
  sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
  sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
  sgdUpdate.setVariable("lr", lr);
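
  // The setBuffer calls are assumed to alias the same GPU buffers across
  // pipelines, so inputs, predictions, gradients and parameters stay on the
  // GPU between passes; only the debugBuffer calls below read data back.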

  const log = document.getElementById("log");

  document.getElementById("trainBtn").onclick = async () => {
    for (let epoch = 0; epoch < epochs; epoch++) {
      // one dispatch per pipeline per epoch (full-batch training)
      await forward.execute(N);
      await lossBackward.execute(N);
      await weightGrad.execute(N);
      await sgdUpdate.execute(inputSize);

      // read the scalar loss back for logging (assumes the "loss" buffer holds
      // one squared error per sample, so the mean is the MSE)
      const lossArr = await lossBackward.debugBuffer("loss");
      const loss = lossArr.reduce((a, b) => a + b, 0) / N;

      log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;

      // Debug learned weights every 5 epochs
      if (epoch % 5 === 0) {
        const w = await forward.debugBuffer("W");
        // Array.from keeps the formatted strings even if debugBuffer returns a
        // Float32Array (a typed array's .map would coerce them back to numbers)
        console.log(
          `Epoch ${epoch}: W = ${Array.from(w, v => v.toFixed(2)).join(", ")}`
        );
      }
    }
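
    // On this noiseless linear target, full-batch SGD with lr = 0.01 for 50
    // epochs should move W toward the generating weights [3, 2, -1] and b
    // toward 0.5, assuming the kernels implement the MSE gradients above.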
    const learnedW = await forward.debugBuffer("W");
    const learnedB = await forward.debugBuffer("b");

    log.textContent += `\nLearned W: ${Array.from(learnedW, v => v.toFixed(3)).join(", ")}\n`;
    log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
  };
}

main();
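
// This script expects to run as an ES module in a page that provides a
// "trainBtn" button and a "log" element; a minimal host page might look like
// the following (the script path is an assumption):
//   <pre id="log"></pre>
//   <button id="trainBtn">Train</button>
//   <script type="module" src="./source/training.js"></script>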