142 lines
2.4 KiB
JavaScript
142 lines
2.4 KiB
JavaScript
/**
 * Fetch the MNIST training set (IDX format) and decode up to `sampleCount`
 * samples into normalized pixel data and integer labels.
 *
 * The files are fetched sequentially from the fixed relative paths
 * "./models/train-images-idx3-ubyte" and "./models/train-labels-idx1-ubyte".
 * Pixel bytes (0..255) are scaled to 0..1.
 *
 * @param {*} device - Not used by this function; kept for interface compatibility.
 * @param {number} [sampleCount=1000] - Maximum number of samples to decode.
 *   Floored and clamped to the number of samples present in both files;
 *   `Infinity` loads every sample, `NaN` or negative values load none.
 * @param {{ setText(text: string): void, setProgress(percent: number): void }} [preloader]
 *   Optional progress sink. It receives status text and a progress value
 *   that rises monotonically from 5 to 100.
 * @returns {Promise<{ images: Float32Array, labels: Uint32Array, count: number,
 *   rows: number, cols: number }>} `images` holds `count * rows * cols` values,
 *   one image after another; `labels[ i ]` is the label of image `i`.
 * @throws {Error} If a request returns a non-2xx status, or a file is not a
 *   well-formed IDX file of the expected kind.
 */
export async function loadMNIST( device, sampleCount = 1000, preloader ) {

	// IDX magic numbers: 0x0803 = unsigned-byte rank-3 tensor (images),
	// 0x0801 = unsigned-byte rank-1 tensor (labels).
	const IMAGE_MAGIC = 0x00000803;
	const LABEL_MAGIC = 0x00000801;
	const IMAGE_HEADER_BYTES = 16;
	const LABEL_HEADER_BYTES = 8;

	// Read a big-endian unsigned 32-bit integer (IDX headers are big-endian).
	function readUint32BE( bytes, offset ) {

		return new DataView( bytes.buffer, bytes.byteOffset, bytes.byteLength )
			.getUint32( offset, false );

	}

	// Let the page repaint between decode batches; fall back to a macrotask
	// where requestAnimationFrame does not exist (workers, Node).
	function yieldToEventLoop() {

		return new Promise( ( resolve ) => {

			if ( typeof requestAnimationFrame === "function" ) {

				requestAnimationFrame( () => resolve() );

			} else {

				setTimeout( resolve, 0 );

			}

		} );

	}

	// Download `path` completely into a Uint8Array, mapping download progress
	// onto the [progressStart, progressEnd] range of the preloader.
	async function loadFile( path, progressStart, progressEnd, label ) {

		const response = await fetch( path );

		// Without this check an error page would be decoded as MNIST bytes.
		if ( ! response.ok ) {

			throw new Error( `Failed to load ${path}: HTTP ${response.status} ${response.statusText}` );

		}

		// No streamable body: read it in one go (no incremental progress).
		if ( ! response.body ) {

			return new Uint8Array( await response.arrayBuffer() );

		}

		// Content-Length may be absent (-> 0) or unparsable (NaN -> 0); with a
		// Content-Encoding it is the compressed size, so treat it as a hint only.
		const contentLength = Number( response.headers.get( "Content-Length" ) ) || 0;
		const totalMB = contentLength / ( 1024 * 1024 );

		const reader = response.body.getReader();
		const chunks = [];
		let receivedLength = 0;

		for ( ;; ) {

			const { done, value } = await reader.read();

			if ( done ) break;

			chunks.push( value );
			receivedLength += value.length;

			if ( preloader && contentLength > 0 ) {

				const receivedMB = receivedLength / ( 1024 * 1024 );

				// Clamp: decoded bytes can exceed a compressed Content-Length,
				// which would otherwise push progress past progressEnd.
				const ratio = Math.min( 1, receivedLength / contentLength );

				preloader.setText(
					`${label} (${receivedMB.toFixed( 2 )} MB / ${totalMB.toFixed( 2 )} MB)`
				);
				preloader.setProgress( progressStart + ratio * ( progressEnd - progressStart ) );

			}

		}

		// Concatenate the streamed chunks into one contiguous buffer.
		const buffer = new Uint8Array( receivedLength );
		let offset = 0;

		for ( const chunk of chunks ) {

			buffer.set( chunk, offset );
			offset += chunk.length;

		}

		return buffer;

	}

	if ( preloader ) {

		preloader.setText( "Loading MNIST images…" );
		preloader.setProgress( 5 );

	}

	const imgRaw = await loadFile(
		"./models/train-images-idx3-ubyte",
		5,
		40,
		"Loading MNIST images"
	);

	if ( preloader ) {

		preloader.setText( "Loading MNIST labels…" );
		preloader.setProgress( 40 );

	}

	const lblRaw = await loadFile(
		"./models/train-labels-idx1-ubyte",
		40,
		55,
		"Loading MNIST labels"
	);

	if ( imgRaw.length < IMAGE_HEADER_BYTES || readUint32BE( imgRaw, 0 ) !== IMAGE_MAGIC ) {

		throw new Error( "MNIST image file is not a valid IDX3 unsigned-byte file" );

	}

	if ( lblRaw.length < LABEL_HEADER_BYTES || readUint32BE( lblRaw, 0 ) !== LABEL_MAGIC ) {

		throw new Error( "MNIST label file is not a valid IDX1 unsigned-byte file" );

	}

	const rows = readUint32BE( imgRaw, 8 );
	const cols = readUint32BE( imgRaw, 12 );
	const pixelsPerImage = rows * cols;

	// Never trust header counts beyond the bytes actually received: a
	// truncated download would otherwise yield NaN pixels and 0 labels.
	const imageCount = pixelsPerImage > 0
		? Math.min(
			readUint32BE( imgRaw, 4 ),
			Math.floor( ( imgRaw.length - IMAGE_HEADER_BYTES ) / pixelsPerImage )
		)
		: 0;
	const labelCount = Math.min( readUint32BE( lblRaw, 4 ), lblRaw.length - LABEL_HEADER_BYTES );
	const fullCount = Math.min( imageCount, labelCount );

	// Floor so typed-array lengths and the loop bound agree; clamp to [0, fullCount].
	const requested = Math.floor( Number( sampleCount ) );
	const count = Number.isNaN( requested ) ? 0 : Math.max( 0, Math.min( requested, fullCount ) );

	const images = new Float32Array( count * pixelsPerImage );
	const labels = new Uint32Array( count );

	if ( preloader ) {

		preloader.setText( "Decoding MNIST data…" );
		preloader.setProgress( 55 );

	}

	for ( let i = 0; i < count; i ++ ) {

		labels[ i ] = lblRaw[ LABEL_HEADER_BYTES + i ];

		const src = IMAGE_HEADER_BYTES + i * pixelsPerImage;
		const dst = i * pixelsPerImage;

		for ( let j = 0; j < pixelsPerImage; j ++ ) {

			images[ dst + j ] = imgRaw[ src + j ] / 255;

		}

		// Report progress and yield every 64 samples so the UI stays responsive.
		if ( preloader && i % 64 === 0 ) {

			preloader.setProgress( 55 + ( i / count ) * 45 );
			await yieldToEventLoop();

		}

	}

	if ( preloader ) {

		preloader.setProgress( 100 );

	}

	return { images, labels, count, rows, cols };

}