diff --git a/mnist_loader.js b/mnist_loader.js index 4b41040..b2c081a 100644 --- a/mnist_loader.js +++ b/mnist_loader.js @@ -1,30 +1,94 @@ export async function loadMNIST( device, sampleCount = 1000, preloader ) { - async function loadFile( path, progress ) { + async function loadFile( path, progressStart, progressEnd, label ) { - const res = await fetch( path ); - const buffer = await res.arrayBuffer(); + const response = await fetch( path ); - if ( preloader ) { + const contentLength = Number( + response.headers.get( "Content-Length" ) + ); - preloader.setProgress( progress ); + const totalMB = contentLength + ? ( contentLength / ( 1024 * 1024 ) ) + : 0; + + const reader = response.body.getReader(); + + let receivedLength = 0; + let chunks = new Array(); + + while ( true ) { + + const result = await reader.read(); + + if ( result.done ) break; + + chunks.push( result.value ); + receivedLength += result.value.length; + + if ( preloader && contentLength ) { + + const receivedMB = receivedLength / ( 1024 * 1024 ); + const ratio = receivedLength / contentLength; + + preloader.setText( + label + + " (" + + receivedMB.toFixed( 2 ) + + " MB / " + + totalMB.toFixed( 2 ) + + " MB)" + ); + + preloader.setProgress( + progressStart + ratio * ( progressEnd - progressStart ) + ); + + } } - return new Uint8Array( buffer ); + const buffer = new Uint8Array( receivedLength ); + + let offset = 0; + + for ( let i = 0; i < chunks.length; i++ ) { + + buffer.set( chunks[ i ], offset ); + offset += chunks[ i ].length; + + } + + return buffer; + + } + + if ( preloader ) { + + preloader.setText( "Loading MNIST images…" ); + preloader.setProgress( 5 ); } - if ( preloader ) preloader.setText( "Loading MNIST images…" ); const imgRaw = await loadFile( "./models/train-images-idx3-ubyte", - 40 + 5, + 40, + "Loading MNIST images" ); - if ( preloader ) preloader.setText( "Loading MNIST labels…" ); + if ( preloader ) { + + preloader.setText( "Loading MNIST labels…" ); + preloader.setProgress( 40 ); + + } + const lblRaw = await loadFile( "./models/train-labels-idx1-ubyte", - 60 + 40, + 55, + "Loading MNIST labels" ); const header = 16; @@ -34,6 +98,13 @@ export async function loadMNIST( device, sampleCount = 1000, preloader ) { const images = new Float32Array( count * 28 * 28 ); const labels = new Uint32Array( count ); + if ( preloader ) { + + preloader.setText( "Decoding MNIST data…" ); + preloader.setProgress( 55 ); + + } + for ( let i = 0; i < count; i++ ) { labels[ i ] = lblRaw[ 8 + i ]; @@ -50,9 +121,11 @@ export async function loadMNIST( device, sampleCount = 1000, preloader ) { if ( preloader && ( i % 64 === 0 ) ) { preloader.setProgress( - 60 + ( i / count ) * 40 + 55 + ( i / count ) * 45 ); + await new Promise( requestAnimationFrame ); + } }