/**
 * Fetch the MNIST training set (IDX files) and decode the first
 * `sampleCount` samples into flat typed arrays.
 *
 * Images are read from "./models/train-images-idx3-ubyte" (a 16-byte header
 * followed by 28×28 unsigned bytes per image). Labels are read from
 * "./models/train-labels-idx1-ubyte" (an 8-byte header followed by one byte
 * per label). Pixel bytes are divided by 255, so they land in [0, 1].
 *
 * @param {*} device - Not used by this function; kept for caller compatibility.
 * @param {number} [sampleCount=1000] - Maximum number of samples to decode.
 *   Floored, clamped at 0, and capped at the number of samples that are
 *   fully present in BOTH files.
 * @param {{ setText: function(string): void, setProgress: function(number): void }} [preloader]
 *   Optional progress reporter. It receives status text and progress values
 *   from 0 to 100.
 * @returns {Promise<{ images: Float32Array, labels: Uint32Array, count: number }>}
 *   `images` holds `count * 784` normalized pixels, image after image.
 *   `labels[ i ]` is the raw label byte for image `i`.
 * @throws {Error} If either file responds with a non-OK HTTP status.
 */
export async function loadMNIST( device, sampleCount = 1000, preloader ) {

	const IMAGE_HEADER_BYTES = 16;
	const LABEL_HEADER_BYTES = 8;
	const PIXELS_PER_IMAGE = 28 * 28;

	// Fetch `path` as raw bytes, then report `progress` to the preloader.
	async function loadFile( path, progress ) {

		const res = await fetch( path );

		// fetch() only rejects on network failure. Without this check, a 404/500
		// error page would be decoded as if it were MNIST data.
		if ( ! res.ok ) {

			throw new Error( `loadMNIST: failed to fetch ${ path } (HTTP ${ res.status })` );

		}

		const buffer = await res.arrayBuffer();

		if ( preloader ) {

			preloader.setProgress( progress );

		}

		return new Uint8Array( buffer );

	}

	if ( preloader ) preloader.setText( "Loading MNIST images…" );
	const imgRaw = await loadFile( "./models/train-images-idx3-ubyte", 40 );

	if ( preloader ) preloader.setText( "Loading MNIST labels…" );
	const lblRaw = await loadFile( "./models/train-labels-idx1-ubyte", 60 );

	// Only decode samples that are complete in both files. Otherwise a shorter
	// image file would be read out of range and produce NaN pixels.
	const labelsAvailable = Math.max( 0, lblRaw.length - LABEL_HEADER_BYTES );
	const imagesAvailable = Math.max( 0, Math.floor( ( imgRaw.length - IMAGE_HEADER_BYTES ) / PIXELS_PER_IMAGE ) );

	// Floor and clamp so typed-array sizes and the loop bound stay valid integers.
	const count = Math.max( 0, Math.min( Math.floor( sampleCount ), labelsAvailable, imagesAvailable ) );

	const images = new Float32Array( count * PIXELS_PER_IMAGE );
	const labels = new Uint32Array( count );

	for ( let i = 0; i < count; i ++ ) {

		labels[ i ] = lblRaw[ LABEL_HEADER_BYTES + i ];

		const src = IMAGE_HEADER_BYTES + i * PIXELS_PER_IMAGE;
		const dst = i * PIXELS_PER_IMAGE;

		for ( let j = 0; j < PIXELS_PER_IMAGE; j ++ ) {

			images[ dst + j ] = imgRaw[ src + j ] / 255;

		}

		// Throttle progress updates; decoding covers the 60–100 range.
		if ( preloader && ( i % 64 === 0 ) ) {

			preloader.setProgress( 60 + ( i / count ) * 40 );

		}

	}

	if ( preloader ) {

		preloader.setProgress( 100 );

	}

	return { images, labels, count };

}