import Shader from "../../framework/WebGpu.js";

// Ground-truth linear model the synthetic labels are drawn from:
//   y = 3*x1 + 2*x2 - 1*x3 + 0.5
// Training should drive the learned W / b toward these values.
const TRUE_W = [3, 2, -1];
const TRUE_B = 0.5;

const N = 32; // training samples
const inputSize = TRUE_W.length; // features per sample (3)
const outputSize = 1; // single regression output; not referenced in this file
const lr = 0.01; // SGD learning rate passed to the update shader
const epochs = 50;

/**
 * Builds a synthetic, noise-free regression dataset from TRUE_W / TRUE_B.
 * Each feature is drawn from Math.random() * 2 - 1, i.e. the range [-1, 1).
 *
 * @returns {{ X: Float32Array, y: Float32Array }} X is row-major with N rows
 *   of `inputSize` features (sample i occupies X[i*inputSize .. i*inputSize+inputSize-1]);
 *   y holds the N matching labels.
 */
function generateTrainingData() {
  const X = new Float32Array(N * inputSize);
  const y = new Float32Array(N);
  for (let i = 0; i < N; i++) {
    let label = 0;
    for (let j = 0; j < inputSize; j++) {
      const x = Math.random() * 2 - 1;
      X[i * inputSize + j] = x;
      // Accumulated in double precision from the unrounded feature, then
      // rounded once when stored into the Float32Array below.
      label += TRUE_W[j] * x;
    }
    y[i] = label + TRUE_B;
  }
  return { X, y };
}

/**
 * Formats numeric values as a ", "-separated list with `digits` decimals.
 *
 * Array.from is used instead of .map because .map on a typed array (e.g. a
 * Float32Array, which debugBuffer presumably returns) produces another typed
 * array, coercing every toFixed() string back into a float and losing the
 * rounding.
 *
 * @param {ArrayLike<number>} values - plain or typed array of numbers.
 * @param {number} digits - decimal places passed to Number#toFixed.
 * @returns {string} the formatted list.
 */
function formatValues(values, digits) {
  return Array.from(values, (v) => v.toFixed(digits)).join(", ");
}

/**
 * Creates a Shader bound to `device` and runs its setup() with `path`.
 *
 * @param {GPUDevice} device - device returned by adapter.requestDevice().
 * @param {string} path - location of the WGSL source.
 * @returns {Promise<Shader>} the set-up shader.
 */
async function loadShader(device, path) {
  const shader = new Shader(device);
  await shader.setup(path);
  return shader;
}

/**
 * Entry point: acquires a WebGPU device, builds the four-stage linear-regression
 * pipeline (forward, loss/backward, weight gradient, SGD update), links their
 * GPU buffers, and installs a click handler on #trainBtn that trains for
 * `epochs` epochs, logging the loss to #log.
 */
async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported!");
    return;
  }
  const adapter = await navigator.gpu.requestAdapter();
  // requestAdapter() resolves to null (it does not reject) when no suitable
  // adapter exists; calling requestDevice() on null would throw.
  if (!adapter) {
    alert("No suitable GPU adapter found!");
    return;
  }
  const device = await adapter.requestDevice();

  const { X, y } = generateTrainingData();

  // --- 1: Create shaders ---
  // Loaded one after another, as in the original; Shader.setup's internals are
  // not visible here, so concurrent setups on one device are avoided.
  const forward = await loadShader(device, "../../shaders/neural/forward.wgsl");
  const lossBackward = await loadShader(device, "../../shaders/neural/loss_backward.wgsl");
  const weightGrad = await loadShader(device, "../../shaders/neural/weight_grad.wgsl");
  const sgdUpdate = await loadShader(device, "../../shaders/neural/sgd_update.wgsl");

  // --- 2: Initialize model parameters ---
  const W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
  const b = new Float32Array([0]);

  forward.setVariable("input", X);
  forward.setVariable("W", W);
  forward.setVariable("b", b);
  forward.setVariable("N", N);
  // y is uploaded through `forward` so its GPU buffer can be shared below.
  forward.setVariable("y", y);

  // link GPU buffers between shaders
  lossBackward.setBuffer("y", forward.getBuffer("y"));
  lossBackward.setBuffer("pred", forward.getBuffer("pred"));
  lossBackward.setVariable("N", N);

  weightGrad.setBuffer("input", forward.getBuffer("input"));
  weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
  weightGrad.setVariable("N", N);

  // The update shader writes into the same W / b buffers `forward` reads.
  sgdUpdate.setBuffer("W", forward.getBuffer("W"));
  sgdUpdate.setBuffer("b", forward.getBuffer("b"));
  sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
  sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
  sgdUpdate.setVariable("lr", lr);

  const log = document.getElementById("log");
  const trainBtn = document.getElementById("trainBtn");

  trainBtn.onclick = async () => {
    // All passes share the same GPU buffers; a second click mid-run would
    // interleave dispatches from two loops, so block re-entry until done.
    // (Assumes #trainBtn is a <button> — TODO confirm against the HTML.)
    trainBtn.disabled = true;
    try {
      for (let epoch = 0; epoch < epochs; epoch++) {
        await forward.execute(N);
        await lossBackward.execute(N);
        await weightGrad.execute(N);
        await sgdUpdate.execute(inputSize);

        // Read back after the update, but this loss comes from the stage that
        // ran before sgdUpdate, i.e. it reflects this epoch's pre-update W / b.
        // Presumably the buffer holds one loss value per sample — verify
        // against loss_backward.wgsl.
        const lossArr = await lossBackward.debugBuffer("loss");
        const loss = lossArr.reduce((sum, v) => sum + v, 0) / N;
        log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;

        // Debug learned weights every 5 epochs
        if (epoch % 5 === 0) {
          const w = await forward.debugBuffer("W");
          console.log(`Epoch ${epoch}: W = ${formatValues(w, 2)}`);
        }
      }

      const learnedW = await forward.debugBuffer("W");
      const learnedB = await forward.debugBuffer("b");
      log.textContent += `\nLearned W: ${formatValues(learnedW, 3)}\n`;
      log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
    } catch (err) {
      console.error("Training failed:", err);
    } finally {
      trainBtn.disabled = false;
    }
  };
}

main().catch((err) => {
  console.error("Initialization failed:", err);
});