First commit.
source/index_old.html (Normal file, 14 lines added)
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>WebGPU Training Demo</title>
    <base href="Classification/">
    <script type="module" src="./training.js"></script>
</head>
<body>
    <h1>WebGPU Linear Regression Training</h1>
    <button id="trainBtn">Train Model</button>
    <pre id="log"></pre>
</body>
</html>
source/loadMNIST.js (Normal file, 40 lines added)
@@ -0,0 +1,40 @@
export async function loadBinary(url) {
    const response = await fetch(url);
    if (!response.ok) throw new Error("Failed to load " + url);
    const buffer = await response.arrayBuffer();
    return buffer;
}

// Loads MNIST 28x28 grayscale images + labels
export async function loadMNIST() {
    const imgBuf = await loadBinary("/models/MNIST-2k-images.bin");
    const lblBuf = await loadBinary("/models/MNIST-2k-labels.bin");

    const images = new Float32Array(imgBuf); // already normalized [0..1]
    const labels = new Uint8Array(lblBuf);

    const sampleCount = labels.length;
    console.log("MNIST Loaded:", sampleCount, "samples");

    return { images, labels, sampleCount };
}

// Utility visualization helper for debugging
export function showImage(img784, canvas) {
    const size = 28;
    canvas.width = size;
    canvas.height = size;

    const ctx = canvas.getContext("2d");
    const imgData = ctx.createImageData(size, size);

    for (let i = 0; i < size * size; i++) {
        const v = Math.floor(img784[i] * 255);
        imgData.data[i * 4 + 0] = v;
        imgData.data[i * 4 + 1] = v;
        imgData.data[i * 4 + 2] = v;
        imgData.data[i * 4 + 3] = 255;
    }

    ctx.putImageData(imgData, 0, 0);
}
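Note: a quick way to exercise this loader together with showImage from a module script; the canvas id and the sample index below are illustrative assumptions, not part of the commit:

import { loadMNIST, showImage } from "./loadMNIST.js";

// Assumes a <canvas id="debugCanvas"> element exists on the page (hypothetical id).
const { images, labels } = await loadMNIST();
const canvas = document.getElementById("debugCanvas");
showImage(images.subarray(0, 28 * 28), canvas); // draw sample 0 as a 28x28 grayscale image
console.log("sample 0 label:", labels[0]);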
source/mnist_loader.js (Normal file, 27 lines added)
@@ -0,0 +1,27 @@
export async function loadMNIST(device, sampleCount = 1000) {
    async function loadFile(path) {
        const res = await fetch(path);
        return new Uint8Array(await res.arrayBuffer());
    }

    const imgRaw = await loadFile("../../models/train-images-idx3-ubyte");
    const lblRaw = await loadFile("../../models/train-labels-idx1-ubyte");

    const header = 16;
    const fullCount = lblRaw.length - 8;
    const count = Math.min(sampleCount, fullCount);

    const images = new Float32Array(count * 28 * 28);
    const labels = new Uint32Array(count);

    for (let i = 0; i < count; i++) {
        labels[i] = lblRaw[8 + i];
        const src = header + i * 784;
        const dst = i * 784;
        for (let j = 0; j < 784; j++) {
            images[dst + j] = imgRaw[src + j] / 255;
        }
    }

    return { images, labels, count };
}
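Note: this loader hard-codes the IDX header sizes (16 bytes for the image file, 8 for the label file) and never validates the header fields. A minimal sketch of deriving those offsets from the big-endian IDX header instead; the helper name is an assumption, not part of the commit:

// Parses a big-endian IDX header; the low byte of the magic number is the
// dimension count (3 for images: count, rows, cols; 1 for labels: count).
function readIdxHeader(bytes) {
    const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
    const magic = view.getUint32(0); // 0x00000803 for images, 0x00000801 for labels
    const numDims = magic & 0xff;
    const dims = [];
    for (let d = 0; d < numDims; d++) dims.push(view.getUint32(4 + d * 4));
    return { dims, dataOffset: 4 + numDims * 4 }; // 16 for images, 8 for labels
}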
source/train_images.html (Normal file, 22 lines added)
@@ -0,0 +1,22 @@
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>WebGPU MNIST Training (Batch 64)</title>
    <style>
        body { font-family: sans-serif; background:#111; color:#eee; }
        #preview { display:flex; gap:12px; margin-top:1rem; flex-wrap:wrap; }
        .item { text-align:center; font-size:14px; }
        canvas { image-rendering: pixelated; background:#000; border:1px solid #555; }
        button { font-size:16px; padding:8px 14px; cursor:pointer; }
        #log { white-space:pre; font-size:14px; max-height:240px; overflow-y:auto; background:#222; padding:8px; }
    </style>
</head>
<body>
    <h2>WebGPU MNIST Training — Batch 64</h2>
    <button id="trainBtn">Train</button>
    <div id="log"></div>
    <div id="preview"></div>
    <script type="module" src="./training2.js"></script>
</body>
</html>
source/training.js (Normal file, 111 lines added)
@@ -0,0 +1,111 @@
import Shader from "../../framework/WebGpu.js";

const N = 32; // training samples
const inputSize = 3;
const outputSize = 1;
const lr = 0.01;
const epochs = 50;

function generateTrainingData() {
    const X = new Float32Array(N * inputSize);
    const y = new Float32Array(N);

    for (let i = 0; i < N; i++) {
        const x1 = Math.random() * 2 - 1;
        const x2 = Math.random() * 2 - 1;
        const x3 = Math.random() * 2 - 1;

        const label = 3 * x1 + 2 * x2 - 1 * x3 + 0.5;

        X[i * 3 + 0] = x1;
        X[i * 3 + 1] = x2;
        X[i * 3 + 2] = x3;

        y[i] = label;
    }
    return { X, y };
}

async function main() {
    if (!navigator.gpu) {
        alert("WebGPU not supported!");
        return;
    }

    const adapter = await navigator.gpu.requestAdapter();
    const device = await adapter.requestDevice();

    const { X, y } = generateTrainingData();

    // --- 1: Create shaders ---
    const forward = new Shader(device);
    await forward.setup("../../shaders/neural/forward.wgsl");

    const lossBackward = new Shader(device);
    await lossBackward.setup("../../shaders/neural/loss_backward.wgsl");

    const weightGrad = new Shader(device);
    await weightGrad.setup("../../shaders/neural/weight_grad.wgsl");

    const sgdUpdate = new Shader(device);
    await sgdUpdate.setup("../../shaders/neural/sgd_update.wgsl");

    // --- 2: Initialize model parameters ---
    let W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
    let b = new Float32Array([0]);

    forward.setVariable("input", X);
    forward.setVariable("W", W);
    forward.setVariable("b", b);
    forward.setVariable("N", N);
    forward.setVariable("y", y);



    // link GPU buffers between shaders
    lossBackward.setBuffer("y", forward.getBuffer("y"));
    lossBackward.setBuffer("pred", forward.getBuffer("pred"));
    lossBackward.setVariable("N", N);

    weightGrad.setBuffer("input", forward.getBuffer("input"));
    weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
    weightGrad.setVariable("N", N);

    sgdUpdate.setBuffer("W", forward.getBuffer("W"));
    sgdUpdate.setBuffer("b", forward.getBuffer("b"));
    sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
    sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
    sgdUpdate.setVariable("lr", lr);

    const log = document.getElementById("log");

    document.getElementById("trainBtn").onclick = async () => {
        for (let epoch = 0; epoch < epochs; epoch++) {
            await forward.execute(N);
            await lossBackward.execute(N);
            await weightGrad.execute(N);
            await sgdUpdate.execute(inputSize);

            const lossArr = await lossBackward.debugBuffer("loss");
            const loss = lossArr.reduce((a, b) => a + b, 0) / N;

            log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;

            // Debug learned weights every 5 epochs
            if (epoch % 5 === 0) {
                const w = await forward.debugBuffer("W");
                console.log(
                    `Epoch ${epoch}: W = ${w.map(v => v.toFixed(2)).join(", ")}`
                );
            }
        }

        const learnedW = await forward.debugBuffer("W");
        const learnedB = await forward.debugBuffer("b");

        log.textContent += `\nLearned W: ${learnedW.map(v => v.toFixed(3)).join(", ")}\n`;
        log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
    };
}

main();
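Note: the four WGSL stages referenced above (forward.wgsl, loss_backward.wgsl, weight_grad.wgsl, sgd_update.wgsl) are not included in this commit. As a reference for what each stage is assumed to compute, here is a CPU sketch of one training step of the same linear model, pred = X·W + b with mean squared error:

// Mirrors the assumed pipeline: forward, loss_backward (dPred), weight_grad (dW, db), sgd_update.
function trainStepCPU(X, y, W, b, N, inputSize, lr) {
    // forward: pred[i] = dot(X[i], W) + b
    const pred = new Float32Array(N);
    for (let i = 0; i < N; i++) {
        let s = b[0];
        for (let j = 0; j < inputSize; j++) s += X[i * inputSize + j] * W[j];
        pred[i] = s;
    }

    // loss_backward: per-sample squared error and gradient of the mean loss w.r.t. pred
    const loss = new Float32Array(N);
    const dPred = new Float32Array(N);
    for (let i = 0; i < N; i++) {
        const diff = pred[i] - y[i];
        loss[i] = diff * diff;
        dPred[i] = 2 * diff / N;
    }

    // weight_grad: dW[j] = sum_i X[i][j] * dPred[i], db = sum_i dPred[i]
    const dW = new Float32Array(inputSize);
    let db = 0;
    for (let i = 0; i < N; i++) {
        for (let j = 0; j < inputSize; j++) dW[j] += X[i * inputSize + j] * dPred[i];
        db += dPred[i];
    }

    // sgd_update: gradient descent step
    for (let j = 0; j < inputSize; j++) W[j] -= lr * dW[j];
    b[0] -= lr * db;

    return loss.reduce((a, v) => a + v, 0) / N; // mean squared error for logging
}

With the synthetic data above, W should converge toward [3, 2, -1] and b toward 0.5.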
source/training2.js (Normal file, 223 lines added)
@@ -0,0 +1,223 @@
import Shader from "../../framework/WebGpu.js";
import { loadMNIST } from "./mnist_loader.js";

const batchSize = 1000;
const inputDimension = 28 * 28;
const hiddenDimension = 128;
const numberOfClasses = 10;
const learningRateValue = 0.05;
const numberOfEpochs = 20;

function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
    const context = canvas.getContext("2d");
    const image = context.createImageData(28, 28);

    for (let i = 0; i < inputDimension; i++) {
        const grayscale = pixelData[i] * 255;
        const pixelIndex = i * 4;
        image.data[pixelIndex + 0] = grayscale;
        image.data[pixelIndex + 1] = grayscale;
        image.data[pixelIndex + 2] = grayscale;
        image.data[pixelIndex + 3] = 255;
    }

    context.putImageData(image, 0, 0);
    canvas.nextLabel.textContent =
        `GT: ${trueLabel} → Pred: ${predictedLabel}`;
}

async function main() {
    if (!navigator.gpu) {
        alert("WebGPU not supported");
        return;
    }

    const adapter = await navigator.gpu.requestAdapter();
    const device = await adapter.requestDevice();

    // Load MNIST subset → inputImages & targetLabels are CPU-side typed arrays
    const mnist = await loadMNIST(device, batchSize);
    const inputImages = mnist.images;
    const targetLabels = mnist.labels;

    const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
    const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);

    //
    // === Create Shader Instances ===
    //
    const forwardDenseLayer1 = new Shader(device);
    await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl");

    const forwardDenseLayer2 = new Shader(device);
    await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl");

    const softmaxCrossEntropyBackward = new Shader(device);
    await softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl");

    const gradientsDenseLayer2 = new Shader(device);
    await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl");

    const gradientsDenseLayer1 = new Shader(device);
    await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl");

    const sgdOptimizerUpdate = new Shader(device);
    await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl");

    //
    // === Model Parameters ===
    //
    const weightLayer1 = new Float32Array(inputDimension * hiddenDimension)
        .map(() => (Math.random() * Math.sqrt(2.0 / inputDimension)) - (Math.sqrt(2.0 / inputDimension) / 2));


    const biasLayer1 = new Float32Array(hiddenDimension);

    const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses)
        .map(() => (Math.random() * Math.sqrt(2.0 / hiddenDimension)) - (Math.sqrt(2.0 / hiddenDimension) / 2));


    const biasLayer2 = new Float32Array(numberOfClasses);

    //
    // === Bind Buffers to Shaders ===
    //
    forwardDenseLayer1.setVariable("inputImages", inputImages);
    forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
    forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
    forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

    forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
    forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
    forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
    forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

    softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
    softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
    softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");

    gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
    gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
    //gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null);
    gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
    //gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null);
    //gradientsDenseLayer2.setBuffer("gradientHidden", null);
    gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

    gradientsDenseLayer1.setVariable("inputImages", inputImages);
    gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
    //gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null);
    //gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null);
    gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

    sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
    sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
    sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
    sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));

    sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
    sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
    sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
    sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));

    sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");

    //
    // === UI ===
    //
    const logOutput = document.getElementById("log");
    const previewContainer = document.getElementById("preview");
    const numberOfPreviewImages = 6;

    const previewIndices = Array.from({ length: numberOfPreviewImages },
        () => Math.floor(Math.random() * batchSize));

    const previewCanvases = previewIndices.map(() => {
        const container = document.createElement("div");
        container.className = "item";
        const canvas = document.createElement("canvas");
        canvas.width = canvas.height = 28;
        const label = document.createElement("div");
        canvas.nextLabel = label;
        container.appendChild(canvas);
        container.appendChild(label);
        previewContainer.appendChild(container);
        return canvas;
    });

    //
    // === Training ===
    //
    document.getElementById("trainBtn").onclick = async () => {
        for (let epoch = 0; epoch < numberOfEpochs; epoch++) {

            await forwardDenseLayer1.execute(batchSize);
            await forwardDenseLayer2.execute(batchSize);
            await softmaxCrossEntropyBackward.execute(batchSize);

            const wgSize = 64;

            // Grad Layer 2
            await gradientsDenseLayer2.execute(
                Math.ceil((batchSize * numberOfClasses + hiddenDimension * numberOfClasses) / wgSize)
            );

            // Grad Layer 1
            await gradientsDenseLayer1.execute(
                Math.ceil((batchSize * hiddenDimension + inputDimension * hiddenDimension) / wgSize)
            );

            // SGD
            await sgdOptimizerUpdate.execute(
                Math.ceil((weightLayer1.length + biasLayer1.length +
                    weightLayer2.length + biasLayer2.length) / wgSize)
            );

            // Loss
            const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
            const averageLoss = lossValues.reduce((a, b) => a + b, 0) / batchSize;
            logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;

            // Accuracy
            const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
            let correctCount = 0;

            for (let i = 0; i < batchSize; i++) {
                let bestValue = -1e9;
                let bestClass = 0;
                for (let c = 0; c < numberOfClasses; c++) {
                    const value = logits[i * numberOfClasses + c];
                    if (value > bestValue) { bestValue = value; bestClass = c; }
                }
                if (bestClass === targetLabels[i]) correctCount++;
            }

            const accuracy = correctCount / batchSize;
            logOutput.textContent += `Accuracy: ${(accuracy * 100).toFixed(2)}%\n`;

            // Update preview images
            for (let k = 0; k < numberOfPreviewImages; k++) {
                const index = previewIndices[k];
                const image = inputImages.subarray(
                    index * inputDimension, (index + 1) * inputDimension
                );

                let bestValue = -1e9;
                let bestClass = 0;
                for (let c = 0; c < numberOfClasses; c++) {
                    const value = logits[index * numberOfClasses + c];
                    if (value > bestValue) { bestValue = value; bestClass = c; }
                }

                drawMNISTPreview(
                    previewCanvases[k],
                    image,
                    targetLabels[index],
                    bestClass
                );
            }
        }
    };
}

main();
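Note: softmax_cross_entropy_backward.wgsl is likewise not part of this commit. Judging from the buffer names (outputLogits, targetLabels, crossEntropyLoss, gradientLogits), it presumably computes the standard softmax cross-entropy loss and its gradient per sample. A CPU sketch of that computation; the 1/batchSize scaling of the gradient is an assumption:

// For each sample i: p = softmax(logits[i]); loss[i] = -log(p[label]);
// gradLogits[i][c] = (p[c] - indicator(c == label)) / batchSize.
function softmaxCrossEntropyBackwardCPU(logits, labels, batchSize, numClasses) {
    const loss = new Float32Array(batchSize);
    const gradLogits = new Float32Array(batchSize * numClasses);
    for (let i = 0; i < batchSize; i++) {
        const offset = i * numClasses;
        // numerically stable softmax: subtract the row maximum before exponentiating
        let maxLogit = -Infinity;
        for (let c = 0; c < numClasses; c++) maxLogit = Math.max(maxLogit, logits[offset + c]);
        let sumExp = 0;
        for (let c = 0; c < numClasses; c++) sumExp += Math.exp(logits[offset + c] - maxLogit);
        for (let c = 0; c < numClasses; c++) {
            const p = Math.exp(logits[offset + c] - maxLogit) / sumExp;
            gradLogits[offset + c] = (p - (c === labels[i] ? 1 : 0)) / batchSize;
            if (c === labels[i]) loss[i] = -Math.log(p + 1e-12);
        }
    }
    return { loss, gradLogits };
}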
source/training2_working.js (Normal file, 212 lines added)
@@ -0,0 +1,212 @@
import Shader from "../../framework/WebGpu.js";
import { loadMNIST } from "./mnist_loader.js";

const batchSize = 1000;
const inputDimension = 28 * 28;
const hiddenDimension = 128;
const numberOfClasses = 10;
const learningRateValue = 0.05;
const numberOfEpochs = 20;

function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
    const context = canvas.getContext("2d");
    const image = context.createImageData(28, 28);

    for (let i = 0; i < inputDimension; i++) {
        const grayscale = pixelData[i] * 255;
        const pixelIndex = i * 4;
        image.data[pixelIndex + 0] = grayscale;
        image.data[pixelIndex + 1] = grayscale;
        image.data[pixelIndex + 2] = grayscale;
        image.data[pixelIndex + 3] = 255;
    }

    context.putImageData(image, 0, 0);
    canvas.nextLabel.textContent =
        `GT: ${trueLabel} → Pred: ${predictedLabel}`;
}

async function main() {
    if (!navigator.gpu) {
        alert("WebGPU not supported");
        return;
    }

    const adapter = await navigator.gpu.requestAdapter();
    const device = await adapter.requestDevice();

    // Load MNIST subset → inputImages & targetLabels are CPU-side typed arrays
    const mnist = await loadMNIST(device, batchSize);
    const inputImages = mnist.images;
    const targetLabels = mnist.labels;

    const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
    const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);

    //
    // === Create Shader Instances ===
    //
    const forwardDenseLayer1 = new Shader(device);
    await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl");

    const forwardDenseLayer2 = new Shader(device);
    await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl");

    const softmaxCrossEntropyBackward = new Shader(device);
    await softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl");

    const gradientsDenseLayer2 = new Shader(device);
    await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl");

    const gradientsDenseLayer1 = new Shader(device);
    await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl");

    const sgdOptimizerUpdate = new Shader(device);
    await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl");

    //
    // === Model Parameters ===
    //
    const weightLayer1 = new Float32Array(inputDimension * hiddenDimension).map(() => (Math.random() - 0.5) * 0.05);
    const biasLayer1 = new Float32Array(hiddenDimension);

    const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses).map(() => (Math.random() - 0.5) * 0.05);
    const biasLayer2 = new Float32Array(numberOfClasses);

    //
    // === Bind Buffers to Shaders ===
    //
    forwardDenseLayer1.setVariable("inputImages", inputImages);
    forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
    forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
    forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

    forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
    forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
    forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
    forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

    softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
    softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
    softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");

    gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
    gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
    //gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null);
    gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
    //gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null);
    //gradientsDenseLayer2.setBuffer("gradientHidden", null);
    gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

    gradientsDenseLayer1.setVariable("inputImages", inputImages);
    gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
    //gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null);
    //gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null);
    gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

    sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
    sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
    sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
    sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));

    sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
    sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
    sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
    sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));

    sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");

    //
    // === UI ===
    //
    const logOutput = document.getElementById("log");
    const previewContainer = document.getElementById("preview");
    const numberOfPreviewImages = 6;

    const previewIndices = Array.from({ length: numberOfPreviewImages },
        () => Math.floor(Math.random() * batchSize));

    const previewCanvases = previewIndices.map(() => {
        const container = document.createElement("div");
        container.className = "item";
        const canvas = document.createElement("canvas");
        canvas.width = canvas.height = 28;
        const label = document.createElement("div");
        canvas.nextLabel = label;
        container.appendChild(canvas);
        container.appendChild(label);
        previewContainer.appendChild(container);
        return canvas;
    });

    //
    // === Training ===
    //
    document.getElementById("trainBtn").onclick = async () => {
        for (let epoch = 0; epoch < numberOfEpochs; epoch++) {

            await forwardDenseLayer1.execute(batchSize);
            await forwardDenseLayer2.execute(batchSize);
            await softmaxCrossEntropyBackward.execute(batchSize);

            await gradientsDenseLayer2.execute(
                hiddenDimension * numberOfClasses + numberOfClasses + batchSize * hiddenDimension
            );

            await gradientsDenseLayer1.execute(
                inputDimension * hiddenDimension + hiddenDimension
            );

            await sgdOptimizerUpdate.execute(
                weightLayer1.length + biasLayer1.length +
                weightLayer2.length + biasLayer2.length
            );

            // Loss
            const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
            const averageLoss = lossValues.reduce((a, b) => a + b, 0) / batchSize;
            logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;

            // Accuracy
            const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
            let correctCount = 0;

            for (let i = 0; i < batchSize; i++) {
                let bestValue = -1e9;
                let bestClass = 0;
                for (let c = 0; c < numberOfClasses; c++) {
                    const value = logits[i * numberOfClasses + c];
                    if (value > bestValue) { bestValue = value; bestClass = c; }
                }
                if (bestClass === targetLabels[i]) correctCount++;
            }

            const accuracy = correctCount / batchSize;
            logOutput.textContent += `Accuracy: ${(accuracy * 100).toFixed(2)}%\n`;

            // Update preview images
            for (let k = 0; k < numberOfPreviewImages; k++) {
                const index = previewIndices[k];
                const image = inputImages.subarray(
                    index * inputDimension, (index + 1) * inputDimension
                );

                let bestValue = -1e9;
                let bestClass = 0;
                for (let c = 0; c < numberOfClasses; c++) {
                    const value = logits[index * numberOfClasses + c];
                    if (value > bestValue) { bestValue = value; bestClass = c; }
                }

                drawMNISTPreview(
                    previewCanvases[k],
                    image,
                    targetLabels[index],
                    bestClass
                );
            }
        }
    };
}

main();