213 lines
8.9 KiB
JavaScript
213 lines
8.9 KiB
JavaScript
|
|
import Shader from "../../framework/WebGpu.js";
|
||
|
|
import { loadMNIST } from "./mnist_loader.js";
|
||
|
|
|
||
|
|
// --- Network shape ---
const inputDimension = 28 * 28; // one flattened 28x28 grayscale digit
const hiddenDimension = 128;    // units in the single hidden dense layer
const numberOfClasses = 10;     // one output logit per digit class

// --- Training settings ---
const batchSize = 1000;         // MNIST samples loaded and processed per step
const learningRateValue = 0.05; // step size handed to the SGD update shader
const numberOfEpochs = 20;      // training steps run per click of the train button
|
||
|
|
|
||
|
|
/**
 * Paints one flattened grayscale digit onto a 28x28 canvas and writes the
 * "ground truth → prediction" caption into the element stored on the canvas
 * as `canvas.nextLabel`.
 *
 * @param {HTMLCanvasElement} canvas - Target canvas; must carry a `nextLabel` element.
 * @param {ArrayLike<number>} pixelData - At least `inputDimension` intensities; each is
 *   multiplied by 255 before drawing (presumably normalised to [0, 1] — TODO confirm
 *   against mnist_loader.js).
 * @param {number} trueLabel - Ground-truth class shown in the caption.
 * @param {number} predictedLabel - Predicted class shown in the caption.
 */
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
  const ctx = canvas.getContext("2d");
  const frame = ctx.createImageData(28, 28);
  const rgba = frame.data;

  // Replicate each intensity into R, G and B; alpha is always fully opaque.
  let offset = 0;
  for (let pixel = 0; pixel < inputDimension; pixel += 1) {
    const level = pixelData[pixel] * 255;
    rgba[offset] = level;
    rgba[offset + 1] = level;
    rgba[offset + 2] = level;
    rgba[offset + 3] = 255;
    offset += 4;
  }

  ctx.putImageData(frame, 0, 0);

  const caption = `GT: ${trueLabel} → Pred: ${predictedLabel}`;
  canvas.nextLabel.textContent = caption;
}
|
||
|
|
|
||
|
|
/**
 * Index of the largest entry in one row of a row-major matrix.
 *
 * @param {ArrayLike<number>} values - Flattened matrix with `width` entries per row.
 * @param {number} row - Row to scan.
 * @param {number} width - Entries per row.
 * @returns {number} Column of the first maximum; 0 if no entry exceeds -Infinity
 *   (e.g. the whole row is NaN).
 */
function argmaxRow(values, row, width) {
  const start = row * width;
  let bestValue = -Infinity;
  let bestIndex = 0;
  for (let column = 0; column < width; column++) {
    const value = values[start + column];
    if (value > bestValue) {
      bestValue = value;
      bestIndex = column;
    }
  }
  return bestIndex;
}

/**
 * Creates a `Shader` bound to `device` and awaits `setup(url)` on it.
 *
 * @param {GPUDevice} device - Device returned by `adapter.requestDevice()`.
 * @param {string} url - WGSL source path handed to `Shader#setup`.
 * @returns {Promise<Shader>} The set-up shader wrapper.
 */
async function createShader(device, url) {
  const shader = new Shader(device);
  await shader.setup(url);
  return shader;
}

/**
 * Appends `count` preview tiles (a 28x28 canvas plus a caption <div>) to
 * `container`. Each canvas stores its caption element as `canvas.nextLabel`,
 * which is where `drawMNISTPreview` writes its text.
 *
 * @param {HTMLElement} container - Element that receives the tiles.
 * @param {number} count - Number of tiles to create.
 * @returns {HTMLCanvasElement[]} The created canvases, in tile order.
 */
function createPreviewCanvases(container, count) {
  return Array.from({ length: count }, () => {
    const item = document.createElement("div");
    item.className = "previewItem";
    const canvas = document.createElement("canvas");
    canvas.width = canvas.height = 28;
    const label = document.createElement("div");
    canvas.nextLabel = label;
    item.appendChild(canvas);
    item.appendChild(label);
    container.appendChild(item);
    return canvas;
  });
}

/**
 * Entry point: acquires a WebGPU device, loads one MNIST batch, wires the
 * dense-layer compute shaders together and installs the "train" button handler.
 *
 * Each "epoch" is a single full-batch step on the same `batchSize` samples:
 * two forward dense layers → softmax cross-entropy backward → layer-2 and
 * layer-1 gradients → SGD update, followed by a loss/accuracy readback and a
 * refresh of the preview tiles.
 *
 * @returns {Promise<void>} Resolves once the UI is wired up; training runs on click.
 */
async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported");
    return;
  }

  // requestAdapter() resolves to null (it does not reject) when no suitable
  // adapter exists, so guard before calling requestDevice() on it.
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    alert("No WebGPU adapter available");
    return;
  }
  const device = await adapter.requestDevice();

  // Load the MNIST subset.
  // NOTE(review): an earlier comment called these GPU buffers, but they are
  // passed to setVariable() like the Float32Array parameters below and are
  // indexed / subarray()'d on the CPU for the preview, so they are presumably
  // typed arrays — confirm against mnist_loader.js.
  const mnist = await loadMNIST(device, batchSize);
  const inputImages = mnist.images;
  const targetLabels = mnist.labels;

  // Uniforms per dense layer: [input width, output width, batch size].
  const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
  const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);

  //
  // === Create Shader Instances ===
  //
  const forwardDenseLayer1 = await createShader(device, "../../shaders/neural/forward_dense_layer1.wgsl");
  const forwardDenseLayer2 = await createShader(device, "../../shaders/neural/forward_dense_layer2.wgsl");
  const softmaxCrossEntropyBackward = await createShader(device, "../../shaders/neural/softmax_cross_entropy_backward.wgsl");
  const gradientsDenseLayer2 = await createShader(device, "../../shaders/neural/gradients_dense_layer2.wgsl");
  const gradientsDenseLayer1 = await createShader(device, "../../shaders/neural/gradients_dense_layer1.wgsl");
  const sgdOptimizerUpdate = await createShader(device, "../../shaders/neural/sgd_optimizer_update.wgsl");

  //
  // === Model Parameters ===
  //
  // Weights start uniformly in roughly [-0.025, 0.025); biases start at zero.
  const weightLayer1 = new Float32Array(inputDimension * hiddenDimension).map(() => (Math.random() - 0.5) * 0.05);
  const biasLayer1 = new Float32Array(hiddenDimension);

  const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses).map(() => (Math.random() - 0.5) * 0.05);
  const biasLayer2 = new Float32Array(numberOfClasses);

  //
  // === Bind Buffers to Shaders ===
  //
  // setVariable() hands CPU data to a shader; setBuffer() shares a buffer that
  // another shader already owns (obtained via getBuffer()), which is how the
  // stages are chained together.
  forwardDenseLayer1.setVariable("inputImages", inputImages);
  forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
  forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
  forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

  forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
  forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
  forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
  forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

  softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
  softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
  softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");

  // gradientWeightLayer2, gradientBiasLayer2 and gradientHidden are not bound
  // here; they are fetched from this shader with getBuffer() further down
  // (presumably allocated by Shader from the WGSL declarations — confirm).
  gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
  gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
  gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
  gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

  // gradientWeightLayer1 and gradientBiasLayer1 are likewise read back below.
  gradientsDenseLayer1.setVariable("inputImages", inputImages);
  gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
  gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

  // The optimizer updates the very parameter buffers the forward shaders read.
  sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
  sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
  sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
  sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));

  sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
  sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
  sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
  sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));

  sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");

  //
  // === UI ===
  //
  const logOutput = document.getElementById("log");
  const previewContainer = document.getElementById("preview");
  const numberOfPreviewImages = 6;

  // Random sample indices within the loaded batch (duplicates are possible).
  const previewIndices = Array.from({ length: numberOfPreviewImages },
    () => Math.floor(Math.random() * batchSize));
  const previewCanvases = createPreviewCanvases(previewContainer, previewIndices.length);

  //
  // === Training ===
  //
  const trainButton = document.getElementById("trainBtn");
  let isTraining = false;

  trainButton.onclick = async () => {
    // A second click mid-run would drive the same Shader instances and buffers
    // from two loops at once and interleave the log, so ignore it.
    if (isTraining) {
      return;
    }
    isTraining = true;
    trainButton.disabled = true;

    try {
      for (let epoch = 0; epoch < numberOfEpochs; epoch++) {
        await forwardDenseLayer1.execute(batchSize);
        await forwardDenseLayer2.execute(batchSize);
        await softmaxCrossEntropyBackward.execute(batchSize);

        await gradientsDenseLayer2.execute(
          hiddenDimension * numberOfClasses + numberOfClasses + batchSize * hiddenDimension
        );

        await gradientsDenseLayer1.execute(
          inputDimension * hiddenDimension + hiddenDimension
        );

        await sgdOptimizerUpdate.execute(
          weightLayer1.length + biasLayer1.length +
          weightLayer2.length + biasLayer2.length
        );

        // Loss
        const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
        const averageLoss = lossValues.reduce((a, b) => a + b, 0) / batchSize;
        logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;

        // Accuracy. outputLogits was last written by this epoch's forward pass,
        // i.e. before the SGD update above.
        const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
        let correctCount = 0;
        for (let i = 0; i < batchSize; i++) {
          if (argmaxRow(logits, i, numberOfClasses) === targetLabels[i]) {
            correctCount++;
          }
        }

        const accuracy = correctCount / batchSize;
        logOutput.textContent += `Accuracy: ${(accuracy * 100).toFixed(2)}%\n`;

        // Update preview images
        previewIndices.forEach((index, k) => {
          const image = inputImages.subarray(
            index * inputDimension, (index + 1) * inputDimension
          );
          const predictedClass = argmaxRow(logits, index, numberOfClasses);
          drawMNISTPreview(previewCanvases[k], image, targetLabels[index], predictedClass);
        });
      }
    } catch (error) {
      // Without this the failure would only surface as an unhandled rejection.
      console.error(error);
      logOutput.textContent += `Training failed: ${error}\n`;
    } finally {
      isTraining = false;
      trainButton.disabled = false;
    }
  };
}
|
||
|
|
|
||
|
|
// Start the app. Report failures from the adapter/device request, MNIST
// loading or shader setup instead of leaving an unhandled promise rejection.
main().catch((error) => {
  console.error(error);
  alert(`Initialization failed: ${error}`);
});
|