28 lines
856 B
JavaScript
28 lines
856 B
JavaScript
|
|
export async function loadMNIST(device, sampleCount = 1000) {
|
||
|
|
async function loadFile(path) {
|
||
|
|
const res = await fetch(path);
|
||
|
|
return new Uint8Array(await res.arrayBuffer());
|
||
|
|
}
|
||
|
|
|
||
|
|
const imgRaw = await loadFile("../../models/train-images-idx3-ubyte");
|
||
|
|
const lblRaw = await loadFile("../../models/train-labels-idx1-ubyte");
|
||
|
|
|
||
|
|
const header = 16;
|
||
|
|
const fullCount = lblRaw.length - 8;
|
||
|
|
const count = Math.min(sampleCount, fullCount);
|
||
|
|
|
||
|
|
const images = new Float32Array(count * 28 * 28);
|
||
|
|
const labels = new Uint32Array(count);
|
||
|
|
|
||
|
|
for (let i = 0; i < count; i++) {
|
||
|
|
labels[i] = lblRaw[8 + i];
|
||
|
|
const src = header + i * 784;
|
||
|
|
const dst = i * 784;
|
||
|
|
for (let j = 0; j < 784; j++) {
|
||
|
|
images[dst + j] = imgRaw[src + j] / 255;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return { images, labels, count };
|
||
|
|
}
|