From c75eb60c5bf64ec0f41080cf61ab0745223ff0d4 Mon Sep 17 00:00:00 2001 From: kaj dijkstra Date: Fri, 2 Jan 2026 11:44:05 +0100 Subject: [PATCH] added preloader --- index.html | 48 ++++++++++++++++++++++++++++ mnist_loader.js | 83 ++++++++++++++++++++++++++++++++++++------------- training.js | 58 +++++++++++++++++++++++++++++++++- 3 files changed, 167 insertions(+), 22 deletions(-) diff --git a/index.html b/index.html index cec2867..3457249 100644 --- a/index.html +++ b/index.html @@ -10,9 +10,57 @@ canvas { image-rendering: pixelated; background:#000; border:1px solid #555; } button { font-size:16px; padding:8px 14px; cursor:pointer; } #log { white-space:pre; font-size:14px; max-height:240px; overflow-y:auto; background:#222; padding:8px; } + +#preloader { + position: fixed; + inset: 0; + background: #0b0b0b; + display: flex; + align-items: center; + justify-content: center; + z-index: 9999; +} + +#preloaderBox { + width: 320px; + text-align: center; +} + +#preloaderText { + margin-bottom: 10px; + font-size: 14px; + color: #ccc; +} + +#preloaderBarOuter { + width: 100%; + height: 8px; + background: #222; + border-radius: 4px; + overflow: hidden; +} + +#preloaderBarInner { + height: 100%; + width: 0%; + background: #4ade80; + transition: width 0.2s linear; +} + + + +
+
+
Loading MNIST…
+
+
+
+
+
+

WebGPU MNIST Training — Batch 64

diff --git a/mnist_loader.js b/mnist_loader.js index c12685e..4b41040 100644 --- a/mnist_loader.js +++ b/mnist_loader.js @@ -1,27 +1,68 @@ -export async function loadMNIST(device, sampleCount = 1000) { - async function loadFile(path) { - const res = await fetch(path); - return new Uint8Array(await res.arrayBuffer()); - } +export async function loadMNIST( device, sampleCount = 1000, preloader ) { - const imgRaw = await loadFile("./models/train-images-idx3-ubyte"); - const lblRaw = await loadFile("./models/train-labels-idx1-ubyte"); + async function loadFile( path, progress ) { - const header = 16; - const fullCount = lblRaw.length - 8; - const count = Math.min(sampleCount, fullCount); + const res = await fetch( path ); + const buffer = await res.arrayBuffer(); - const images = new Float32Array(count * 28 * 28); - const labels = new Uint32Array(count); + if ( preloader ) { - for (let i = 0; i < count; i++) { - labels[i] = lblRaw[8 + i]; - const src = header + i * 784; - const dst = i * 784; - for (let j = 0; j < 784; j++) { - images[dst + j] = imgRaw[src + j] / 255; - } - } + preloader.setProgress( progress ); + + } + + return new Uint8Array( buffer ); + + } + + if ( preloader ) preloader.setText( "Loading MNIST images…" ); + const imgRaw = await loadFile( + "./models/train-images-idx3-ubyte", + 40 + ); + + if ( preloader ) preloader.setText( "Loading MNIST labels…" ); + const lblRaw = await loadFile( + "./models/train-labels-idx1-ubyte", + 60 + ); + + const header = 16; + const fullCount = lblRaw.length - 8; + const count = Math.min( sampleCount, fullCount ); + + const images = new Float32Array( count * 28 * 28 ); + const labels = new Uint32Array( count ); + + for ( let i = 0; i < count; i++ ) { + + labels[ i ] = lblRaw[ 8 + i ]; + + const src = header + i * 784; + const dst = i * 784; + + for ( let j = 0; j < 784; j++ ) { + + images[ dst + j ] = imgRaw[ src + j ] / 255; + + } + + if ( preloader && ( i % 64 === 0 ) ) { + + preloader.setProgress( + 60 + ( i / count ) * 40 + ); + + } + + } + + if ( preloader ) { + + preloader.setProgress( 100 ); + + } + + return { images, labels, count }; - return { images, labels, count }; } diff --git a/training.js b/training.js index 89860ee..f98669d 100644 --- a/training.js +++ b/training.js @@ -8,6 +8,46 @@ const numberOfClasses = 10; const learningRateValue = 0.05; const numberOfEpochs = 20; +class Preloader { + + constructor() { + + this.root = document.getElementById( "preloader" ); + this.text = document.getElementById( "preloaderText" ); + this.bar = document.getElementById( "preloaderBarInner" ); + + } + + setText( value ) { + + this.text.textContent = value; + + } + + setProgress( value ) { + + this.bar.style.width = Math.min( 100, value ) + "%"; + + } + + done() { + + this.root.style.opacity = "0"; + this.root.style.pointerEvents = "none"; + + setTimeout( () => { + + this.root.remove(); + + }, 300 ); + + } + +} + +const preloader = new Preloader(); + + function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) { const context = canvas.getContext("2d"); const image = context.createImageData(28, 28); @@ -36,7 +76,23 @@ async function main() { const device = await adapter.requestDevice(); // Load MNIST subset → inputImages & targetLabels are GPU buffers - const mnist = await loadMNIST(device, batchSize); + // const mnist = await loadMNIST(device, batchSize); + + +preloader.setText( "Initializing WebGPU…" ); +preloader.setProgress( 10 ); + +const mnist = await loadMNIST( + device, + batchSize, + preloader +); + +preloader.setText( "Finalizing setup…" ); +preloader.setProgress( 100 ); +preloader.done(); + + const inputImages = mnist.images; const targetLabels = mnist.labels;