diff --git a/index.html b/index.html
index cec2867..3457249 100644
--- a/index.html
+++ b/index.html
@@ -10,9 +10,57 @@
canvas { image-rendering: pixelated; background:#000; border:1px solid #555; }
button { font-size:16px; padding:8px 14px; cursor:pointer; }
#log { white-space:pre; font-size:14px; max-height:240px; overflow-y:auto; background:#222; padding:8px; }
+
+#preloader {
+ position: fixed;
+ inset: 0;
+ background: #0b0b0b;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ z-index: 9999;
+}
+
+#preloaderBox {
+ width: 320px;
+ text-align: center;
+}
+
+#preloaderText {
+ margin-bottom: 10px;
+ font-size: 14px;
+ color: #ccc;
+}
+
+#preloaderBarOuter {
+ width: 100%;
+ height: 8px;
+ background: #222;
+ border-radius: 4px;
+ overflow: hidden;
+}
+
+#preloaderBarInner {
+ height: 100%;
+ width: 0%;
+ background: #4ade80;
+ transition: width 0.2s linear;
+}
+
+
+
+
+
WebGPU MNIST Training — Batch 64
diff --git a/mnist_loader.js b/mnist_loader.js
index c12685e..4b41040 100644
--- a/mnist_loader.js
+++ b/mnist_loader.js
@@ -1,27 +1,68 @@
-export async function loadMNIST(device, sampleCount = 1000) {
- async function loadFile(path) {
- const res = await fetch(path);
- return new Uint8Array(await res.arrayBuffer());
- }
+export async function loadMNIST( device, sampleCount = 1000, preloader ) {
- const imgRaw = await loadFile("./models/train-images-idx3-ubyte");
- const lblRaw = await loadFile("./models/train-labels-idx1-ubyte");
+ async function loadFile( path, progress ) {
- const header = 16;
- const fullCount = lblRaw.length - 8;
- const count = Math.min(sampleCount, fullCount);
+ const res = await fetch( path );
+ const buffer = await res.arrayBuffer();
- const images = new Float32Array(count * 28 * 28);
- const labels = new Uint32Array(count);
+ if ( preloader ) {
- for (let i = 0; i < count; i++) {
- labels[i] = lblRaw[8 + i];
- const src = header + i * 784;
- const dst = i * 784;
- for (let j = 0; j < 784; j++) {
- images[dst + j] = imgRaw[src + j] / 255;
- }
- }
+ preloader.setProgress( progress );
+
+ }
+
+ return new Uint8Array( buffer );
+
+ }
+
+ if ( preloader ) preloader.setText( "Loading MNIST images…" );
+ const imgRaw = await loadFile(
+ "./models/train-images-idx3-ubyte",
+ 40
+ );
+
+ if ( preloader ) preloader.setText( "Loading MNIST labels…" );
+ const lblRaw = await loadFile(
+ "./models/train-labels-idx1-ubyte",
+ 60
+ );
+
+ const header = 16;
+ const fullCount = lblRaw.length - 8;
+ const count = Math.min( sampleCount, fullCount );
+
+ const images = new Float32Array( count * 28 * 28 );
+ const labels = new Uint32Array( count );
+
+ for ( let i = 0; i < count; i++ ) {
+
+ labels[ i ] = lblRaw[ 8 + i ];
+
+ const src = header + i * 784;
+ const dst = i * 784;
+
+ for ( let j = 0; j < 784; j++ ) {
+
+ images[ dst + j ] = imgRaw[ src + j ] / 255;
+
+ }
+
+ if ( preloader && ( i % 64 === 0 ) ) {
+
+ preloader.setProgress(
+ 60 + ( i / count ) * 40
+ );
+
+ }
+
+ }
+
+ if ( preloader ) {
+
+ preloader.setProgress( 100 );
+
+ }
+
+ return { images, labels, count };
- return { images, labels, count };
}
diff --git a/training.js b/training.js
index 89860ee..f98669d 100644
--- a/training.js
+++ b/training.js
@@ -8,6 +8,46 @@ const numberOfClasses = 10;
const learningRateValue = 0.05;
const numberOfEpochs = 20;
+class Preloader {
+
+ constructor() {
+
+ this.root = document.getElementById( "preloader" );
+ this.text = document.getElementById( "preloaderText" );
+ this.bar = document.getElementById( "preloaderBarInner" );
+
+ }
+
+ setText( value ) {
+
+ this.text.textContent = value;
+
+ }
+
+ setProgress( value ) {
+
+ this.bar.style.width = Math.min( 100, value ) + "%";
+
+ }
+
+ done() {
+
+ this.root.style.opacity = "0";
+ this.root.style.pointerEvents = "none";
+
+ setTimeout( () => {
+
+ this.root.remove();
+
+ }, 300 );
+
+ }
+
+}
+
+const preloader = new Preloader();
+
+
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
const context = canvas.getContext("2d");
const image = context.createImageData(28, 28);
@@ -36,7 +76,23 @@ async function main() {
const device = await adapter.requestDevice();
// Load MNIST subset → inputImages & targetLabels are GPU buffers
- const mnist = await loadMNIST(device, batchSize);
+ // const mnist = await loadMNIST(device, batchSize);
+
+
+preloader.setText( "Initializing WebGPU…" );
+preloader.setProgress( 10 );
+
+const mnist = await loadMNIST(
+ device,
+ batchSize,
+ preloader
+);
+
+preloader.setText( "Finalizing setup…" );
+preloader.setProgress( 100 );
+preloader.done();
+
+
const inputImages = mnist.images;
const targetLabels = mnist.labels;