/**
 * Downloads the MNIST training set (IDX format) and decodes the first
 * `sampleCount` samples into flat typed arrays.
 *
 * The image file "./models/train-images-idx3-ubyte" is fetched first, then the
 * label file "./models/train-labels-idx1-ubyte". Both IDX headers are validated
 * (magic number, item count, 28×28 image dimensions). As a result, an HTTP
 * error, a still-compressed file or a truncated download fails loudly instead
 * of silently producing garbage samples.
 *
 * @param {*} device - Not used by this function; retained so existing call sites keep working.
 * @param {number} [sampleCount=1000] - Maximum number of samples to decode. Fractional
 *   values are floored and the result is clamped to [0, samples available in both files].
 * @param {{ setText: function(string): void, setProgress: function(number): void }} [preloader]
 *   Optional progress sink. It receives status text and a progress value from 5 to 100.
 * @returns {Promise<{ images: Float32Array, labels: Uint32Array, count: number }>}
 *   `images` holds `count * 784` pixel values scaled from bytes to [0, 1], one 28×28
 *   image after another in file order. `labels` holds the `count` matching label bytes.
 * @throws {Error} If a request returns a non-2xx status or a file is not a valid IDX
 *   file of the expected kind.
 */
export async function loadMNIST( device, sampleCount = 1000, preloader ) {

	const IMAGE_PATH = "./models/train-images-idx3-ubyte";
	const LABEL_PATH = "./models/train-labels-idx1-ubyte";
	const IMAGE_ROWS = 28;
	const IMAGE_COLS = 28;
	const IMAGE_PIXELS = IMAGE_ROWS * IMAGE_COLS;
	const IMAGE_HEADER_BYTES = 16; // magic, count, rows, cols (4 × uint32)
	const LABEL_HEADER_BYTES = 8; // magic, count (2 × uint32)
	const IMAGE_MAGIC = 0x00000803; // IDX: unsigned-byte data, 3 dimensions
	const LABEL_MAGIC = 0x00000801; // IDX: unsigned-byte data, 1 dimension
	const BYTES_PER_MB = 1024 * 1024;

	// Streams `path` into one Uint8Array, mapping download progress onto [progressStart, progressEnd].
	async function loadFile( path, progressStart, progressEnd, label ) {

		const response = await fetch( path );

		if ( ! response.ok ) {

			throw new Error( `loadMNIST: failed to load ${path} (HTTP ${response.status})` );

		}

		// Without a readable stream there is nothing to report progress on; read it in one go.
		if ( ! response.body ) {

			return new Uint8Array( await response.arrayBuffer() );

		}

		const contentLength = Number( response.headers.get( "Content-Length" ) );
		const totalMB = contentLength ? contentLength / BYTES_PER_MB : 0;
		const reader = response.body.getReader();
		const chunks = [];
		let receivedLength = 0;

		while ( true ) {

			const { done, value } = await reader.read();
			if ( done ) break;

			chunks.push( value );
			receivedLength += value.length;

			if ( preloader && contentLength ) {

				// With HTTP compression, Content-Length is the *encoded* size while the stream
				// yields decoded bytes, so the ratio can exceed 1. Clamp it to keep the bar in range.
				const ratio = Math.min( receivedLength / contentLength, 1 );
				const receivedMB = receivedLength / BYTES_PER_MB;
				preloader.setText( `${label} (${receivedMB.toFixed( 2 )} MB / ${totalMB.toFixed( 2 )} MB)` );
				preloader.setProgress( progressStart + ratio * ( progressEnd - progressStart ) );

			}

		}

		const buffer = new Uint8Array( receivedLength );
		let offset = 0;
		for ( const chunk of chunks ) {

			buffer.set( chunk, offset );
			offset += chunk.length;

		}

		return buffer;

	}

	// Checks the length and big-endian magic number of an IDX header; returns a DataView over the header.
	function readIdxHeader( bytes, headerBytes, expectedMagic, path ) {

		if ( bytes.length < headerBytes ) {

			throw new Error( `loadMNIST: ${path} is too short to be an IDX file (${bytes.length} bytes)` );

		}

		const view = new DataView( bytes.buffer, bytes.byteOffset, headerBytes );
		const magic = view.getUint32( 0 ); // DataView defaults to big-endian, as IDX requires

		if ( magic !== expectedMagic ) {

			throw new Error( `loadMNIST: ${path} has magic number 0x${magic.toString( 16 )}, expected 0x${expectedMagic.toString( 16 )}` );

		}

		return view;

	}

	if ( preloader ) {

		preloader.setText( "Loading MNIST images…" );
		preloader.setProgress( 5 );

	}

	const imgRaw = await loadFile( IMAGE_PATH, 5, 40, "Loading MNIST images" );

	// Validate before downloading labels so a bad image file fails fast.
	const imgHeader = readIdxHeader( imgRaw, IMAGE_HEADER_BYTES, IMAGE_MAGIC, IMAGE_PATH );
	const rows = imgHeader.getUint32( 8 );
	const cols = imgHeader.getUint32( 12 );

	if ( rows !== IMAGE_ROWS || cols !== IMAGE_COLS ) {

		throw new Error( `loadMNIST: ${IMAGE_PATH} has ${rows}×${cols} images, expected ${IMAGE_ROWS}×${IMAGE_COLS}` );

	}

	// Never trust the header count beyond the bytes actually received.
	const imageCount = Math.min(
		imgHeader.getUint32( 4 ),
		Math.floor( ( imgRaw.length - IMAGE_HEADER_BYTES ) / IMAGE_PIXELS )
	);

	if ( preloader ) {

		preloader.setText( "Loading MNIST labels…" );
		preloader.setProgress( 40 );

	}

	const lblRaw = await loadFile( LABEL_PATH, 40, 55, "Loading MNIST labels" );
	const lblHeader = readIdxHeader( lblRaw, LABEL_HEADER_BYTES, LABEL_MAGIC, LABEL_PATH );
	const labelCount = Math.min( lblHeader.getUint32( 4 ), lblRaw.length - LABEL_HEADER_BYTES );

	// Floor fractional requests and clamp to [0, available]. A NaN request collapses to 0 via `|| 0`.
	const count = Math.max( 0, Math.min( Math.floor( sampleCount ), imageCount, labelCount ) ) || 0;
	const images = new Float32Array( count * IMAGE_PIXELS );
	const labels = new Uint32Array( count );

	if ( preloader ) {

		preloader.setText( "Decoding MNIST data…" );
		preloader.setProgress( 55 );

	}

	for ( let i = 0; i < count; i++ ) {

		labels[ i ] = lblRaw[ LABEL_HEADER_BYTES + i ];

		const src = IMAGE_HEADER_BYTES + i * IMAGE_PIXELS;
		const dst = i * IMAGE_PIXELS;
		for ( let j = 0; j < IMAGE_PIXELS; j++ ) {

			images[ dst + j ] = imgRaw[ src + j ] / 255;

		}

		if ( preloader && i % 64 === 0 ) {

			preloader.setProgress( 55 + ( i / count ) * 45 );
			// Wait one animation frame so the progress UI can repaint during long decodes.
			await new Promise( requestAnimationFrame );

		}

	}

	if ( preloader ) {

		preloader.setProgress( 100 );

	}

	return { images, labels, count };

}