import Shader from "../../framework/WebGpu.js";
import { loadMNIST } from "./mnist_loader.js";

// Network dimensions and training hyper-parameters for a two-layer dense
// classifier (784 -> 128 -> 10) trained on a single fixed batch.
const batchSize = 1000;
const inputDimension = 28 * 28;
const hiddenDimension = 128;
const numberOfClasses = 10;
const learningRateValue = 0.05;
const numberOfEpochs = 20;

/**
 * Draws one 28x28 grayscale sample onto a canvas and updates its caption.
 *
 * @param {HTMLCanvasElement} canvas - Target canvas. It must carry a
 *   `nextLabel` property referencing the element whose text is the caption
 *   (attached when the preview canvases are created in `main`).
 * @param {ArrayLike<number>} pixelData - `inputDimension` intensities; each is
 *   multiplied by 255 before being written (presumably values are in [0, 1] —
 *   confirm against mnist_loader.js).
 * @param {number} trueLabel - Ground-truth class shown after "GT:".
 * @param {number} predictedLabel - Predicted class shown after "Pred:".
 */
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
  const context = canvas.getContext("2d");
  const image = context.createImageData(28, 28);

  for (let i = 0; i < inputDimension; i++) {
    const grayscale = pixelData[i] * 255;
    const pixelIndex = i * 4;
    // Same value in R, G and B gives a gray pixel; alpha is fully opaque.
    image.data[pixelIndex + 0] = grayscale;
    image.data[pixelIndex + 1] = grayscale;
    image.data[pixelIndex + 2] = grayscale;
    image.data[pixelIndex + 3] = 255;
  }

  context.putImageData(image, 0, 0);
  canvas.nextLabel.textContent = `GT: ${trueLabel} → Pred: ${predictedLabel}`;
}

/**
 * Returns the column index of the largest value in one row of a row-major
 * matrix stored in a flat array. Ties resolve to the lowest index; a row with
 * no value greater than -Infinity (e.g. all NaN) yields 0.
 *
 * @param {ArrayLike<number>} values - Flat row-major matrix.
 * @param {number} rowIndex - Row to scan.
 * @param {number} rowLength - Number of columns per row.
 * @returns {number} Index of the maximum within the row.
 */
function argmaxRow(values, rowIndex, rowLength) {
  const offset = rowIndex * rowLength;
  let bestValue = -Infinity;
  let bestColumn = 0;

  for (let column = 0; column < rowLength; column++) {
    const value = values[offset + column];
    if (value > bestValue) {
      bestValue = value;
      bestColumn = column;
    }
  }

  return bestColumn;
}

/**
 * Constructs a Shader for `device` and waits for its WGSL source to be set up.
 *
 * @param {GPUDevice} device - Device handed to the Shader constructor.
 * @param {string} path - Path passed to `Shader#setup`.
 * @returns {Promise<Shader>} The shader once `setup` has resolved.
 */
async function createShader(device, path) {
  const shader = new Shader(device);
  await shader.setup(path);
  return shader;
}

/**
 * Entry point: acquires a WebGPU device, loads the MNIST batch, wires the
 * forward / backward / optimizer shaders together, builds the preview UI and
 * installs the click handler that runs the training loop.
 *
 * @returns {Promise<void>}
 */
async function main() {
  if (!navigator.gpu) {
    alert("WebGPU not supported");
    return;
  }

  const adapter = await navigator.gpu.requestAdapter();
  // requestAdapter() resolves to null when no suitable adapter exists; without
  // this guard the next line would throw a TypeError.
  if (!adapter) {
    alert("No suitable WebGPU adapter found");
    return;
  }
  const device = await adapter.requestDevice();

  // Load MNIST subset.
  // NOTE(review): an earlier comment described images/labels as GPU buffers,
  // but below they are indexed and sliced with subarray(), i.e. used as typed
  // arrays — confirm against mnist_loader.js.
  const mnist = await loadMNIST(device, batchSize);
  const inputImages = mnist.images;
  const targetLabels = mnist.labels;

  // Per-layer size triples: [input width, output width, batch size].
  const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
  const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);

  //
  // === Create Shader Instances ===
  //
  // Loaded one after another, in the same order as they are executed.
  const forwardDenseLayer1 = await createShader(
    device,
    "../../shaders/neural/forward_dense_layer1.wgsl",
  );
  const forwardDenseLayer2 = await createShader(
    device,
    "../../shaders/neural/forward_dense_layer2.wgsl",
  );
  const softmaxCrossEntropyBackward = await createShader(
    device,
    "../../shaders/neural/softmax_cross_entropy_backward.wgsl",
  );
  const gradientsDenseLayer2 = await createShader(
    device,
    "../../shaders/neural/gradients_dense_layer2.wgsl",
  );
  const gradientsDenseLayer1 = await createShader(
    device,
    "../../shaders/neural/gradients_dense_layer1.wgsl",
  );
  const sgdOptimizerUpdate = await createShader(
    device,
    "../../shaders/neural/sgd_optimizer_update.wgsl",
  );

  //
  // === Model Parameters ===
  //
  // Weights start as small uniform noise in [-0.025, 0.025); biases start at 0.
  const weightLayer1 = new Float32Array(inputDimension * hiddenDimension).map(
    () => (Math.random() - 0.5) * 0.05,
  );
  const biasLayer1 = new Float32Array(hiddenDimension);
  const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses).map(
    () => (Math.random() - 0.5) * 0.05,
  );
  const biasLayer2 = new Float32Array(numberOfClasses);

  //
  // === Bind Buffers to Shaders ===
  //
  // setVariable() is given host-side typed arrays; setBuffer() is given a
  // buffer obtained from another shader via getBuffer(), so that later stages
  // read what earlier stages wrote. Statement order matters: a buffer must be
  // bound on its producer before it is fetched for a consumer.
  forwardDenseLayer1.setVariable("inputImages", inputImages);
  forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
  forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
  forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

  forwardDenseLayer2.setBuffer(
    "hiddenActivations",
    forwardDenseLayer1.getBuffer("hiddenActivations"),
  );
  forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
  forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
  forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

  softmaxCrossEntropyBackward.setBuffer(
    "outputLogits",
    forwardDenseLayer2.getBuffer("outputLogits"),
  );
  softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
  softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");

  // gradientWeightLayer2, gradientBiasLayer2 and gradientHidden are not bound
  // explicitly; they are only fetched later via getBuffer(), so presumably the
  // Shader allocates them itself — TODO confirm in WebGpu.js.
  gradientsDenseLayer2.setBuffer(
    "hiddenActivations",
    forwardDenseLayer1.getBuffer("hiddenActivations"),
  );
  gradientsDenseLayer2.setBuffer(
    "gradientLogits",
    softmaxCrossEntropyBackward.getBuffer("gradientLogits"),
  );
  gradientsDenseLayer2.setBuffer(
    "weightLayer2",
    forwardDenseLayer2.getBuffer("weightLayer2"),
  );
  gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");

  // gradientWeightLayer1 and gradientBiasLayer1 are likewise left unbound here
  // and only fetched via getBuffer() below — same assumption as above.
  gradientsDenseLayer1.setVariable("inputImages", inputImages);
  gradientsDenseLayer1.setBuffer(
    "gradientHidden",
    gradientsDenseLayer2.getBuffer("gradientHidden"),
  );
  gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");

  sgdOptimizerUpdate.setBuffer(
    "weightLayer1",
    forwardDenseLayer1.getBuffer("weightLayer1"),
  );
  sgdOptimizerUpdate.setBuffer(
    "biasLayer1",
    forwardDenseLayer1.getBuffer("biasLayer1"),
  );
  sgdOptimizerUpdate.setBuffer(
    "gradientWeightLayer1",
    gradientsDenseLayer1.getBuffer("gradientWeightLayer1"),
  );
  sgdOptimizerUpdate.setBuffer(
    "gradientBiasLayer1",
    gradientsDenseLayer1.getBuffer("gradientBiasLayer1"),
  );
  sgdOptimizerUpdate.setBuffer(
    "weightLayer2",
    forwardDenseLayer2.getBuffer("weightLayer2"),
  );
  sgdOptimizerUpdate.setBuffer(
    "biasLayer2",
    forwardDenseLayer2.getBuffer("biasLayer2"),
  );
  sgdOptimizerUpdate.setBuffer(
    "gradientWeightLayer2",
    gradientsDenseLayer2.getBuffer("gradientWeightLayer2"),
  );
  sgdOptimizerUpdate.setBuffer(
    "gradientBiasLayer2",
    gradientsDenseLayer2.getBuffer("gradientBiasLayer2"),
  );
  sgdOptimizerUpdate.setVariable(
    "learningRate",
    new Float32Array([learningRateValue]),
    "uniform",
  );

  //
  // === UI ===
  //
  const logOutput = document.getElementById("log");
  const previewContainer = document.getElementById("preview");
  const numberOfPreviewImages = 6;

  // Random sample indices, chosen once so the same digits are tracked across
  // epochs. Duplicates are possible.
  const previewIndices = Array.from({ length: numberOfPreviewImages }, () =>
    Math.floor(Math.random() * batchSize),
  );

  const previewCanvases = previewIndices.map(() => {
    const container = document.createElement("div");
    container.className = "previewItem";

    const canvas = document.createElement("canvas");
    canvas.width = canvas.height = 28;

    const label = document.createElement("div");
    // drawMNISTPreview() finds the caption element through this property.
    canvas.nextLabel = label;

    container.appendChild(canvas);
    container.appendChild(label);
    previewContainer.appendChild(container);
    return canvas;
  });

  //
  // === Training ===
  //
  const trainButton = document.getElementById("trainBtn");
  let isTraining = false;

  trainButton.onclick = async () => {
    // Ignore clicks while a run is in flight: two interleaved loops would
    // dispatch against the same shader buffers concurrently.
    if (isTraining) {
      return;
    }
    isTraining = true;
    trainButton.disabled = true;

    try {
      for (let epoch = 0; epoch < numberOfEpochs; epoch++) {
        // Forward pass, loss gradient, layer gradients, then parameter update.
        // The execute() argument is presumably the number of invocations —
        // confirm against WebGpu.js.
        await forwardDenseLayer1.execute(batchSize);
        await forwardDenseLayer2.execute(batchSize);
        await softmaxCrossEntropyBackward.execute(batchSize);
        await gradientsDenseLayer2.execute(
          hiddenDimension * numberOfClasses +
            numberOfClasses +
            batchSize * hiddenDimension,
        );
        await gradientsDenseLayer1.execute(
          inputDimension * hiddenDimension + hiddenDimension,
        );
        await sgdOptimizerUpdate.execute(
          weightLayer1.length +
            biasLayer1.length +
            weightLayer2.length +
            biasLayer2.length,
        );

        // Loss: mean of the values read back from "crossEntropyLoss".
        const lossValues =
          await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
        const averageLoss = lossValues.reduce((a, b) => a + b, 0) / batchSize;
        logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;

        // Accuracy: argmax over each sample's logits versus its label. These
        // logits come from this epoch's forward pass, i.e. before the SGD step.
        const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
        const predictedClasses = new Array(batchSize);
        let correctCount = 0;
        for (let i = 0; i < batchSize; i++) {
          const predicted = argmaxRow(logits, i, numberOfClasses);
          predictedClasses[i] = predicted;
          // assumes targetLabels holds one class index per sample — TODO confirm
          if (predicted === targetLabels[i]) {
            correctCount++;
          }
        }
        const accuracy = correctCount / batchSize;
        logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`;

        // Update preview images with the current predictions.
        for (let k = 0; k < numberOfPreviewImages; k++) {
          const index = previewIndices[k];
          const image = inputImages.subarray(
            index * inputDimension,
            (index + 1) * inputDimension,
          );
          drawMNISTPreview(
            previewCanvases[k],
            image,
            targetLabels[index],
            predictedClasses[index],
          );
        }
      }
    } finally {
      // Re-enable even when an epoch throws, so the user can retry.
      isTraining = false;
      trainButton.disabled = false;
    }
  };
}

// Surface initialisation failures instead of leaving the promise floating.
main().catch((error) => {
  console.error("MNIST training demo failed to start:", error);
});