Files
WebGPU-Neural-Network-Trainer/mnist_loader.js

142 lines
2.4 KiB
JavaScript
Raw Normal View History

2026-01-02 11:44:05 +01:00
/**
 * Fetch `path` as raw bytes, streaming the body so download progress can be
 * reported through `preloader` while the transfer is in flight.
 *
 * @param {string} path - URL to fetch.
 * @param {object|undefined} preloader - Optional object with setText(string) / setProgress(number).
 * @param {number} progressStart - Progress value reported at 0% of this download.
 * @param {number} progressEnd - Progress value reported at 100% of this download.
 * @param {string} label - Prefix for the "(x MB / y MB)" status text.
 * @returns {Promise<Uint8Array>} The complete response body.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchWithProgress( path, preloader, progressStart, progressEnd, label ) {

	const response = await fetch( path );

	// Without this check an HTML error page would be decoded as MNIST bytes.
	if ( ! response.ok ) {
		throw new Error( `Failed to fetch ${ path }: HTTP ${ response.status } ${ response.statusText }` );
	}

	// A missing or malformed Content-Length disables byte-level progress reporting.
	const rawLength = Number( response.headers.get( "Content-Length" ) );
	const contentLength = Number.isFinite( rawLength ) && rawLength > 0 ? rawLength : 0;

	// Some environments expose no readable stream; fall back to a single read.
	if ( ! response.body ) {
		return new Uint8Array( await response.arrayBuffer() );
	}

	const totalMB = contentLength / ( 1024 * 1024 );
	const reader = response.body.getReader();
	const chunks = [];
	let receivedLength = 0;

	while ( true ) {
		const { done, value } = await reader.read();
		if ( done ) break;

		chunks.push( value );
		receivedLength += value.length;

		if ( preloader && contentLength > 0 ) {
			// With Content-Encoding (e.g. gzip) Content-Length is the encoded size, so
			// the decoded byte count can exceed it; clamp to keep progress in range.
			const ratio = Math.min( receivedLength / contentLength, 1 );
			const receivedMB = receivedLength / ( 1024 * 1024 );
			preloader.setText(
				`${ label } (${ receivedMB.toFixed( 2 ) } MB / ${ totalMB.toFixed( 2 ) } MB)`
			);
			preloader.setProgress( progressStart + ratio * ( progressEnd - progressStart ) );
		}
	}

	// Concatenate the streamed chunks into one contiguous buffer.
	const bytes = new Uint8Array( receivedLength );
	let offset = 0;
	for ( const chunk of chunks ) {
		bytes.set( chunk, offset );
		offset += chunk.length;
	}
	return bytes;

}

/**
 * Validate an IDX header and return its big-endian uint32 fields after the magic number.
 *
 * @param {Uint8Array} bytes - Whole IDX file.
 * @param {number} expectedMagic - Required value of the first big-endian uint32.
 * @param {number} headerBytes - Header size in bytes (magic + dimension fields).
 * @param {string} name - Human-readable file name for error messages.
 * @returns {number[]} The dimension fields following the magic number.
 * @throws {Error} If the file is shorter than the header or the magic number differs.
 */
function readIdxHeader( bytes, expectedMagic, headerBytes, name ) {

	if ( bytes.length < headerBytes ) {
		throw new Error( `${ name }: file too short (${ bytes.length } bytes) for a ${ headerBytes }-byte IDX header` );
	}

	const view = new DataView( bytes.buffer, bytes.byteOffset, bytes.byteLength );
	const magic = view.getUint32( 0, false );

	if ( magic !== expectedMagic ) {
		throw new Error(
			`${ name }: unexpected IDX magic number 0x${ magic.toString( 16 ) } (expected 0x${ expectedMagic.toString( 16 ) })`
		);
	}

	const fields = [];
	for ( let offset = 4; offset < headerBytes; offset += 4 ) {
		fields.push( view.getUint32( offset, false ) );
	}
	return fields;

}

/**
 * Resolve on the next animation frame so the page can repaint the preloader.
 * Falls back to a macrotask where requestAnimationFrame does not exist (workers, Node).
 *
 * @returns {Promise<void>}
 */
function nextFrame() {

	return new Promise( ( resolve ) => {
		if ( typeof requestAnimationFrame === "function" ) {
			requestAnimationFrame( () => resolve() );
		} else {
			setTimeout( resolve, 0 );
		}
	} );

}

/**
 * Download the MNIST training set (IDX format) and decode the first
 * `sampleCount` samples into typed arrays.
 *
 * Progress is reported on a 0–100 scale: images download 5→40, labels 40→55,
 * decoding 55→100.
 *
 * @param {*} device - Not used by this loader; kept for call-site compatibility.
 * @param {number} [sampleCount=1000] - Maximum number of samples to decode. It is clamped
 *   to the number available; negative values yield zero samples.
 * @param {object} [preloader] - Optional object with setText(string) / setProgress(number).
 * @returns {Promise<{images: Float32Array, labels: Uint32Array, count: number}>}
 *   `images` holds count*784 pixel values scaled from bytes to [0, 1]; `labels` holds count bytes.
 * @throws {Error} On HTTP failure, a wrong IDX magic number, or non-28×28 images.
 */
export async function loadMNIST( device, sampleCount = 1000, preloader ) {

	const IMAGE_MAGIC = 0x00000803; // IDX: unsigned byte, 3 dimensions
	const LABEL_MAGIC = 0x00000801; // IDX: unsigned byte, 1 dimension
	const IMAGE_HEADER_BYTES = 16; // magic, count, rows, cols
	const LABEL_HEADER_BYTES = 8; // magic, count
	const SIDE = 28;
	const PIXELS = SIDE * SIDE;

	if ( preloader ) {
		preloader.setText( "Loading MNIST images…" );
		preloader.setProgress( 5 );
	}

	const imgRaw = await fetchWithProgress(
		"./models/train-images-idx3-ubyte",
		preloader,
		5,
		40,
		"Loading MNIST images"
	);

	if ( preloader ) {
		preloader.setText( "Loading MNIST labels…" );
		preloader.setProgress( 40 );
	}

	const lblRaw = await fetchWithProgress(
		"./models/train-labels-idx1-ubyte",
		preloader,
		40,
		55,
		"Loading MNIST labels"
	);

	const [ imageCount, rows, cols ] = readIdxHeader( imgRaw, IMAGE_MAGIC, IMAGE_HEADER_BYTES, "MNIST images" );
	if ( rows !== SIDE || cols !== SIDE ) {
		throw new Error( `MNIST images: expected ${ SIDE }x${ SIDE } images, got ${ rows }x${ cols }` );
	}
	const [ labelCount ] = readIdxHeader( lblRaw, LABEL_MAGIC, LABEL_HEADER_BYTES, "MNIST labels" );

	// Trust neither the header counts nor the byte lengths alone: a truncated
	// download must not lead to out-of-range reads (which would yield NaN pixels).
	const availableImages = Math.min( imageCount, Math.floor( ( imgRaw.length - IMAGE_HEADER_BYTES ) / PIXELS ) );
	const availableLabels = Math.min( labelCount, lblRaw.length - LABEL_HEADER_BYTES );
	const fullCount = Math.min( availableImages, availableLabels );

	// Typed-array lengths must be non-negative integers.
	const requested = Math.floor( sampleCount );
	const count = Number.isNaN( requested ) ? 0 : Math.max( 0, Math.min( requested, fullCount ) );

	const images = new Float32Array( count * PIXELS );
	const labels = new Uint32Array( count );

	if ( preloader ) {
		preloader.setText( "Decoding MNIST data…" );
		preloader.setProgress( 55 );
	}

	for ( let i = 0; i < count; i++ ) {
		labels[ i ] = lblRaw[ LABEL_HEADER_BYTES + i ];

		const src = IMAGE_HEADER_BYTES + i * PIXELS;
		const dst = i * PIXELS;
		for ( let j = 0; j < PIXELS; j++ ) {
			images[ dst + j ] = imgRaw[ src + j ] / 255;
		}

		// Periodically yield so the preloader can repaint during long decodes.
		if ( preloader && i % 64 === 0 ) {
			preloader.setProgress( 55 + ( i / count ) * 45 );
			await nextFrame();
		}
	}

	if ( preloader ) {
		preloader.setProgress( 100 );
	}

	return { images, labels, count };

}