First commit.
This commit is contained in:
83
framework/Camera.js
Normal file
83
framework/Camera.js
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import Vector3 from "./Vector3.js";
|
||||||
|
import Matrix4 from "./Matrix4.js";
|
||||||
|
|
||||||
|
export default class Camera {
|
||||||
|
|
||||||
|
eye = new Vector3();
|
||||||
|
target = new Vector3();
|
||||||
|
up = new Vector3( 0, 1, 0 );
|
||||||
|
|
||||||
|
yaw = 0;
|
||||||
|
pitch = 0;
|
||||||
|
fovRadians = Math.PI / 4;
|
||||||
|
near = 0.1;
|
||||||
|
far = 3000.0;
|
||||||
|
distance = 10;
|
||||||
|
viewMatrix = new Float32Array( 16 );
|
||||||
|
|
||||||
|
constructor( eye = [0, 0, 5], target = [0, 0, 0], up = [0, 1, 0] ) {
|
||||||
|
|
||||||
|
this.eye = new Vector3( ...eye );
|
||||||
|
this.target = new Vector3( ...target );
|
||||||
|
this.up = new Vector3( ...up );
|
||||||
|
|
||||||
|
this.distance = Vector3.subtract( this.eye, this.target ).length();
|
||||||
|
|
||||||
|
this.viewMatrix = Matrix4.lookAt( this.eye, this.target, this.up );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
update() {
|
||||||
|
|
||||||
|
const x = this.distance * Math.cos( this.pitch ) * Math.sin( this.yaw );
|
||||||
|
const y = this.distance * Math.sin( this.pitch );
|
||||||
|
const z = this.distance * Math.cos( this.pitch ) * Math.cos( this.yaw );
|
||||||
|
|
||||||
|
this.eye = new Vector3(
|
||||||
|
x + this.target.x,
|
||||||
|
y + this.target.y,
|
||||||
|
z + this.target.z
|
||||||
|
);
|
||||||
|
|
||||||
|
this.viewMatrix = Matrix4.lookAt( this.eye, this.target, this.up );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
getViewMatrix() {
|
||||||
|
|
||||||
|
return this.viewMatrix;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
rotate( deltaYaw, deltaPitch ) {
|
||||||
|
|
||||||
|
this.yaw += deltaYaw;
|
||||||
|
this.pitch -= deltaPitch;
|
||||||
|
|
||||||
|
const maxPitch = Math.PI / 2 - 0.01;
|
||||||
|
|
||||||
|
if ( this.pitch > maxPitch ) this.pitch = maxPitch;
|
||||||
|
if ( this.pitch < -maxPitch ) this.pitch = -maxPitch;
|
||||||
|
|
||||||
|
this.update();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
zoom( delta ) {
|
||||||
|
|
||||||
|
this.distance += delta * 1;
|
||||||
|
|
||||||
|
if ( this.distance < 0.1 ) this.distance = 0.1;
|
||||||
|
|
||||||
|
this.update();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
setTarget( target ) {
|
||||||
|
|
||||||
|
this.target = new Vector3( ...target );
|
||||||
|
|
||||||
|
this.update();
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
125
framework/Matrix4.js
Normal file
125
framework/Matrix4.js
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import Vector3 from "./Vector3.js";
|
||||||
|
|
||||||
|
export default class Matrix4 {
|
||||||
|
|
||||||
|
static lookAt( eye, target, up ) {
|
||||||
|
|
||||||
|
const zAxis = Vector3.normalize( Vector3.subtract( eye, target ) );
|
||||||
|
|
||||||
|
const xAxis = Vector3.normalize( Vector3.cross( up, zAxis ) );
|
||||||
|
|
||||||
|
const yAxis = Vector3.cross( zAxis, xAxis );
|
||||||
|
|
||||||
|
return new Float32Array([
|
||||||
|
xAxis.x, yAxis.x, zAxis.x, 0,
|
||||||
|
xAxis.y, yAxis.y, zAxis.y, 0,
|
||||||
|
xAxis.z, yAxis.z, zAxis.z, 0,
|
||||||
|
-Vector3.dot( xAxis, eye ), -Vector3.dot( yAxis, eye ), -Vector3.dot( zAxis, eye ), 1,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static getColumn( matrix, index ) {
|
||||||
|
const i = index * 4;
|
||||||
|
|
||||||
|
return new Vector3(
|
||||||
|
matrix[ i + 0 ],
|
||||||
|
matrix[ i + 1 ],
|
||||||
|
matrix[ i + 2 ]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static createProjectionMatrix( camera, canvas ) {
|
||||||
|
return Matrix4.perspective(
|
||||||
|
camera.fovRadians,
|
||||||
|
canvas.width / canvas.height,
|
||||||
|
camera.near,
|
||||||
|
camera.far
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static invert( m ) {
|
||||||
|
const out = new Float32Array(16);
|
||||||
|
|
||||||
|
const m00 = m[0], m01 = m[1], m02 = m[2], m03 = m[3];
|
||||||
|
const m10 = m[4], m11 = m[5], m12 = m[6], m13 = m[7];
|
||||||
|
const m20 = m[8], m21 = m[9], m22 = m[10], m23 = m[11];
|
||||||
|
const m30 = m[12], m31 = m[13], m32 = m[14], m33 = m[15];
|
||||||
|
|
||||||
|
const a0 = m00 * m11 - m01 * m10;
|
||||||
|
const a1 = m00 * m12 - m02 * m10;
|
||||||
|
const a2 = m00 * m13 - m03 * m10;
|
||||||
|
const a3 = m01 * m12 - m02 * m11;
|
||||||
|
const a4 = m01 * m13 - m03 * m11;
|
||||||
|
const a5 = m02 * m13 - m03 * m12;
|
||||||
|
|
||||||
|
const b0 = m20 * m31 - m21 * m30;
|
||||||
|
const b1 = m20 * m32 - m22 * m30;
|
||||||
|
const b2 = m20 * m33 - m23 * m30;
|
||||||
|
const b3 = m21 * m32 - m22 * m31;
|
||||||
|
const b4 = m21 * m33 - m23 * m31;
|
||||||
|
const b5 = m22 * m33 - m23 * m32;
|
||||||
|
|
||||||
|
const det = a0 * b5 - a1 * b4 + a2 * b3 + a3 * b2 - a4 * b1 + a5 * b0;
|
||||||
|
|
||||||
|
if (det === 0) return null;
|
||||||
|
|
||||||
|
const invDet = 1 / det;
|
||||||
|
|
||||||
|
out[0] = ( m11 * b5 - m12 * b4 + m13 * b3) * invDet;
|
||||||
|
out[1] = (-m01 * b5 + m02 * b4 - m03 * b3) * invDet;
|
||||||
|
out[2] = ( m31 * a5 - m32 * a4 + m33 * a3) * invDet;
|
||||||
|
out[3] = (-m21 * a5 + m22 * a4 - m23 * a3) * invDet;
|
||||||
|
|
||||||
|
out[4] = (-m10 * b5 + m12 * b2 - m13 * b1) * invDet;
|
||||||
|
out[5] = ( m00 * b5 - m02 * b2 + m03 * b1) * invDet;
|
||||||
|
out[6] = (-m30 * a5 + m32 * a2 - m33 * a1) * invDet;
|
||||||
|
out[7] = ( m20 * a5 - m22 * a2 + m23 * a1) * invDet;
|
||||||
|
|
||||||
|
out[8] = ( m10 * b4 - m11 * b2 + m13 * b0) * invDet;
|
||||||
|
out[9] = (-m00 * b4 + m01 * b2 - m03 * b0) * invDet;
|
||||||
|
out[10] = ( m30 * a4 - m31 * a2 + m33 * a0) * invDet;
|
||||||
|
out[11] = (-m20 * a4 + m21 * a2 - m23 * a0) * invDet;
|
||||||
|
|
||||||
|
out[12] = (-m10 * b3 + m11 * b1 - m12 * b0) * invDet;
|
||||||
|
out[13] = ( m00 * b3 - m01 * b1 + m02 * b0) * invDet;
|
||||||
|
out[14] = (-m30 * a3 + m31 * a1 - m32 * a0) * invDet;
|
||||||
|
out[15] = ( m20 * a3 - m21 * a1 + m22 * a0) * invDet;
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
static perspective( fovRadians, aspect, near, far ) {
|
||||||
|
|
||||||
|
const f = 1.0 / Math.tan( fovRadians / 2 );
|
||||||
|
|
||||||
|
const nf = 1 / ( near - far );
|
||||||
|
|
||||||
|
return new Float32Array([
|
||||||
|
f / aspect, 0, 0, 0,
|
||||||
|
0, f, 0, 0,
|
||||||
|
0, 0, (far + near) * nf, -1,
|
||||||
|
0, 0, (2 * far * near) * nf, 0,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static multiply( a, b ) {
|
||||||
|
const out = new Float32Array(16);
|
||||||
|
|
||||||
|
for ( let col = 0; col < 4; col++ ) {
|
||||||
|
for ( let row = 0; row < 4; row++ ) {
|
||||||
|
let sum = 0;
|
||||||
|
|
||||||
|
for ( let k = 0; k < 4; k++ ) {
|
||||||
|
// a is column-major: element at col k, row row => a[k*4 + row]
|
||||||
|
// b is column-major: element at col col, row k => b[col*4 + k]
|
||||||
|
sum += a[k * 4 + row] * b[col * 4 + k];
|
||||||
|
}
|
||||||
|
|
||||||
|
out[col * 4 + row] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
49
framework/Measure.js
Normal file
49
framework/Measure.js
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
export default class Measure {
|
||||||
|
|
||||||
|
startTimes = {};
|
||||||
|
|
||||||
|
endTimes = {};
|
||||||
|
|
||||||
|
writeToPage = false;
|
||||||
|
|
||||||
|
element = false;
|
||||||
|
|
||||||
|
start ( label ) {
|
||||||
|
|
||||||
|
this.startTimes[ label ] = performance.now();
|
||||||
|
}
|
||||||
|
|
||||||
|
end ( label ) {
|
||||||
|
|
||||||
|
this.endTimes[ label ] = performance.now();
|
||||||
|
|
||||||
|
this.log( label );
|
||||||
|
}
|
||||||
|
|
||||||
|
getElapsed ( label ) {
|
||||||
|
|
||||||
|
if ( this.startTimes[ label ] === undefined || this.endTimes[ label ] === undefined ) {
|
||||||
|
|
||||||
|
throw new Error( "Start or end time missing for label: " + label );
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.endTimes[ label ] - this.startTimes[ label ];
|
||||||
|
}
|
||||||
|
|
||||||
|
log ( label ) {
|
||||||
|
|
||||||
|
const elapsed = this.getElapsed( label );
|
||||||
|
|
||||||
|
if( this.writeToPage ) {
|
||||||
|
|
||||||
|
var p = document.createElement("p")
|
||||||
|
|
||||||
|
p.innerText = label + " took " + elapsed.toFixed(3) + " ms";
|
||||||
|
|
||||||
|
this.element.appendChild( p );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log( label + " took " + elapsed.toFixed(3) + " ms" );
|
||||||
|
}
|
||||||
|
}
|
||||||
11
framework/Request.js
Normal file
11
framework/Request.js
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
export class Request {
|
||||||
|
|
||||||
|
constructor( method, payload = {} ) {
|
||||||
|
|
||||||
|
this.method = method; // method name to call on Controller, e.g. "Ping"
|
||||||
|
|
||||||
|
this.payload = payload; // any data for the method
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
67
framework/ShaderInpector.js
Normal file
67
framework/ShaderInpector.js
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
|
||||||
|
class shaderDebugger{
|
||||||
|
|
||||||
|
setup() {
|
||||||
|
|
||||||
|
var shaders = document.shaders;
|
||||||
|
|
||||||
|
var select = document.querySelector(".selectDebugShader");
|
||||||
|
|
||||||
|
for (var i = 0; i < shaders.length; i++) {
|
||||||
|
|
||||||
|
var currentShader = shaders[i];
|
||||||
|
|
||||||
|
var option = document.createElement("option");
|
||||||
|
|
||||||
|
option.innerText = currentShader.path;
|
||||||
|
|
||||||
|
option.id = i;
|
||||||
|
|
||||||
|
select.appendChild( option );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
document.querySelector( "#showBuffers" ).addEventListener( "click", async function() {
|
||||||
|
|
||||||
|
var select = document.querySelector(".selectDebugShader");
|
||||||
|
|
||||||
|
var selectedIndex = select.selectedIndex;
|
||||||
|
|
||||||
|
var selectedShader = document.shaders[ selectedIndex ]
|
||||||
|
|
||||||
|
const keysArray = Array.from( selectedShader.buffers );
|
||||||
|
|
||||||
|
console.log("\n\n\n\n -------------------- Debugging Shader --------------- \n\n\n\n");
|
||||||
|
|
||||||
|
console.log( "Shader Path: ", selectedShader.path );
|
||||||
|
|
||||||
|
console.log( selectedShader );
|
||||||
|
|
||||||
|
for (var i = 0; i < keysArray.length; i++) {
|
||||||
|
|
||||||
|
const bindingInfo = selectedShader.bindings.find( b => b.varName === keysArray[i][0] );
|
||||||
|
|
||||||
|
|
||||||
|
if( bindingInfo ) {
|
||||||
|
|
||||||
|
if( bindingInfo.type == "storage" ) {
|
||||||
|
|
||||||
|
await selectedShader.debugBuffer( keysArray[i][0] );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
console.log("this is a Uniform", keysArray, selectedShader.bindings);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
export default shaderDebugger;
|
||||||
60
framework/Vector3.js
Normal file
60
framework/Vector3.js
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
export default class Vector3 {
|
||||||
|
|
||||||
|
x = 0;
|
||||||
|
y = 0;
|
||||||
|
z = 0;
|
||||||
|
|
||||||
|
constructor( x = 0, y = 0, z = 0 ) {
|
||||||
|
|
||||||
|
this.x = x;
|
||||||
|
this.y = y;
|
||||||
|
this.z = z;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static subtract( a, b ) {
|
||||||
|
|
||||||
|
return new Vector3( a.x - b.x, a.y - b.y, a.z - b.z );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
length() {
|
||||||
|
|
||||||
|
return Math.sqrt( this.x * this.x + this.y * this.y + this.z * this.z );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static cross( a, b ) {
|
||||||
|
|
||||||
|
return new Vector3(
|
||||||
|
|
||||||
|
a.y * b.z - a.z * b.y,
|
||||||
|
a.z * b.x - a.x * b.z,
|
||||||
|
a.x * b.y - a.y * b.x
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static dot( a, b ) {
|
||||||
|
|
||||||
|
return a.x * b.x + a.y * b.y + a.z * b.z;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static normalize( v ) {
|
||||||
|
|
||||||
|
const length = Math.sqrt( v.x * v.x + v.y * v.y + v.z * v.z );
|
||||||
|
|
||||||
|
if ( length > 0.00001 ) {
|
||||||
|
|
||||||
|
return new Vector3( v.x / length, v.y / length, v.z / length );
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
return new Vector3( 0, 0, 0 );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
1895
framework/WebGpu.js
Normal file
1895
framework/WebGpu.js
Normal file
File diff suppressed because it is too large
Load Diff
1960
framework/WebGpu_node.js
Normal file
1960
framework/WebGpu_node.js
Normal file
File diff suppressed because it is too large
Load Diff
113
framework/eventManager.js
Normal file
113
framework/eventManager.js
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// eventManager.js
|
||||||
|
|
||||||
|
export default class EventManager {
|
||||||
|
|
||||||
|
isDragging = false;
|
||||||
|
|
||||||
|
lastX = 0;
|
||||||
|
|
||||||
|
lastY = 0;
|
||||||
|
|
||||||
|
camera;
|
||||||
|
|
||||||
|
canvas;
|
||||||
|
|
||||||
|
setCanvas( canvas ) {
|
||||||
|
|
||||||
|
this.canvas = canvas;
|
||||||
|
|
||||||
|
|
||||||
|
//this.registerEventListeners();
|
||||||
|
|
||||||
|
//this.handleResize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
setup( canvas, camera ) {
|
||||||
|
|
||||||
|
this.canvas = canvas;
|
||||||
|
|
||||||
|
this.camera = camera;
|
||||||
|
|
||||||
|
//this.registerEventListeners();
|
||||||
|
|
||||||
|
//this.handleResize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
registerEventListeners() {
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mousedown", this.onMouseDown.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mouseup", this.onMouseUp.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mouseleave", this.onMouseLeave.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mousemove", this.onMouseMove.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "wheel", this.onWheel.bind(this), { passive: false } );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
resize( event ) {
|
||||||
|
|
||||||
|
this.canvas.width = event.width;
|
||||||
|
|
||||||
|
this.canvas.height = event.height;
|
||||||
|
|
||||||
|
//this.canvas.width = window.innerWidth;
|
||||||
|
|
||||||
|
//this.canvas.height = window.innerHeight;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mousedown( event ) {
|
||||||
|
|
||||||
|
console.log("mouseDownHandler");
|
||||||
|
|
||||||
|
this.isDragging = true;
|
||||||
|
|
||||||
|
this.lastX = event.clientX;
|
||||||
|
|
||||||
|
this.lastY = event.clientY;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mouseup( event ) {
|
||||||
|
|
||||||
|
this.isDragging = false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mouseleave( event ) {
|
||||||
|
|
||||||
|
this.isDragging = false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mousemove( event ) {
|
||||||
|
|
||||||
|
if ( !this.isDragging ) return;
|
||||||
|
|
||||||
|
const deltaX = ( event.clientX - this.lastX ) * 0.005;
|
||||||
|
|
||||||
|
const deltaY = ( event.clientY - this.lastY ) * 0.005;
|
||||||
|
|
||||||
|
this.camera.rotate( deltaX, -deltaY );
|
||||||
|
|
||||||
|
this.lastX = event.clientX;
|
||||||
|
|
||||||
|
this.lastY = event.clientY;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
wheel( event ) {
|
||||||
|
|
||||||
|
|
||||||
|
const delta = event.deltaY * 0.01;
|
||||||
|
|
||||||
|
this.camera.zoom( delta );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
214
framework/eventManager_node.js
Normal file
214
framework/eventManager_node.js
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
// eventManager.js
|
||||||
|
|
||||||
|
import sdl from '@kmamal/sdl'
|
||||||
|
|
||||||
|
var isNode = true;
|
||||||
|
|
||||||
|
if ( typeof window === 'undefined' ) {
|
||||||
|
|
||||||
|
//isNode = false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("isNode", isNode);
|
||||||
|
|
||||||
|
export default class EventManager {
|
||||||
|
|
||||||
|
isDragging = false;
|
||||||
|
|
||||||
|
lastX = 0;
|
||||||
|
|
||||||
|
lastY = 0;
|
||||||
|
|
||||||
|
camera;
|
||||||
|
|
||||||
|
canvas;
|
||||||
|
|
||||||
|
setCanvas( canvas ) {
|
||||||
|
|
||||||
|
this.canvas = canvas;
|
||||||
|
|
||||||
|
|
||||||
|
//this.registerEventListeners();
|
||||||
|
|
||||||
|
//this.handleResize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
setup( canvas, camera ) {
|
||||||
|
|
||||||
|
this.canvas = canvas;
|
||||||
|
|
||||||
|
this.camera = camera;
|
||||||
|
|
||||||
|
//this.registerEventListeners();
|
||||||
|
|
||||||
|
//this.handleResize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
registerEventListeners() {
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mousedown", this.onMouseDown.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mouseup", this.onMouseUp.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mouseleave", this.onMouseLeave.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "mousemove", this.onMouseMove.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.addEventListener( "wheel", this.onWheel.bind(this), { passive: false } );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
registerEventListenersNode() {
|
||||||
|
|
||||||
|
var that = this;
|
||||||
|
|
||||||
|
this.canvas.on('mouseMove', function( event ) {
|
||||||
|
|
||||||
|
that.mousemove( event )
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
this.canvas.on('mouseButtonDown', function( event ) {
|
||||||
|
|
||||||
|
that.mousedown( event )
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
this.canvas.on('mouseButtonUp', function( event ) {
|
||||||
|
|
||||||
|
that.mouseup( event )
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
this.canvas.on( "mouseButtonDown", this.onMouseDown.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.on( "mouseButtonUp", this.onMouseUp.bind(this) );
|
||||||
|
|
||||||
|
//this.canvas.on( "mouseleave", this.onMouseLeave.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.on( "mouseMove", this.onMouseMove.bind(this) );
|
||||||
|
|
||||||
|
this.canvas.on( "mouseWheel", this.onWheel.bind(this), { passive: false } );
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
resize( event ) {
|
||||||
|
|
||||||
|
this.canvas.width = event.width;
|
||||||
|
|
||||||
|
this.canvas.height = event.height;
|
||||||
|
|
||||||
|
//this.canvas.width = window.innerWidth;
|
||||||
|
|
||||||
|
//this.canvas.height = window.innerHeight;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mousedown( event ) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
this.isDragging = true;
|
||||||
|
|
||||||
|
if( isNode ) {
|
||||||
|
|
||||||
|
var mouseX = event.x;
|
||||||
|
|
||||||
|
var mouseY = event.y;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
var mouseX = event.clientX;
|
||||||
|
|
||||||
|
var mouseY = event.clientY;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
//console.log("mouseDownHandler", mouseX, mouseY);
|
||||||
|
|
||||||
|
if( isNode ) {
|
||||||
|
|
||||||
|
this.lastX = mouseX;
|
||||||
|
|
||||||
|
this.lastY = mouseY;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
this.lastX = mouseX;
|
||||||
|
|
||||||
|
this.lastY = mouseY;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mouseup( event ) {
|
||||||
|
|
||||||
|
this.isDragging = false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mouseleave( event ) {
|
||||||
|
|
||||||
|
this.isDragging = false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
mousemove( event ) {
|
||||||
|
|
||||||
|
if( isNode ) {
|
||||||
|
|
||||||
|
var mouseX = event.x;
|
||||||
|
|
||||||
|
var mouseY = event.y;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
var mouseX = event.clientX;
|
||||||
|
|
||||||
|
var mouseY = event.clientY;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if ( !this.isDragging ) return;
|
||||||
|
|
||||||
|
const deltaX = ( mouseX - this.lastX ) * 0.005;
|
||||||
|
|
||||||
|
const deltaY = ( mouseY - this.lastY ) * 0.005;
|
||||||
|
|
||||||
|
//console.log("mousemove", mouseX, mouseY);
|
||||||
|
|
||||||
|
this.camera.rotate( deltaX, -deltaY );
|
||||||
|
|
||||||
|
this.lastX = mouseX;
|
||||||
|
|
||||||
|
this.lastY = mouseY;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
wheel( event ) {
|
||||||
|
|
||||||
|
|
||||||
|
const delta = event.deltaY * 0.01;
|
||||||
|
|
||||||
|
this.camera.zoom( delta );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
6
framework/package.json
Normal file
6
framework/package.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"type": "module",
|
||||||
|
"dependencies": {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
22
index.html
Normal file
22
index.html
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>WebGPU MNIST Training (Batch 64)</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: sans-serif; background:#111; color:#eee; }
|
||||||
|
#preview { display:flex; gap:12px; margin-top:1rem; flex-wrap:wrap; }
|
||||||
|
.item { text-align:center; font-size:14px; }
|
||||||
|
canvas { image-rendering: pixelated; background:#000; border:1px solid #555; }
|
||||||
|
button { font-size:16px; padding:8px 14px; cursor:pointer; }
|
||||||
|
#log { white-space:pre; font-size:14px; max-height:240px; overflow-y:auto; background:#222; padding:8px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h2>WebGPU MNIST Training — Batch 64</h2>
|
||||||
|
<button id="trainBtn">Train</button>
|
||||||
|
<div id="log"></div>
|
||||||
|
<div id="preview"></div>
|
||||||
|
<script type="module" src="./training.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
27
mnist_loader.js
Normal file
27
mnist_loader.js
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
export async function loadMNIST(device, sampleCount = 1000) {
|
||||||
|
async function loadFile(path) {
|
||||||
|
const res = await fetch(path);
|
||||||
|
return new Uint8Array(await res.arrayBuffer());
|
||||||
|
}
|
||||||
|
|
||||||
|
const imgRaw = await loadFile("../../models/train-images-idx3-ubyte");
|
||||||
|
const lblRaw = await loadFile("../../models/train-labels-idx1-ubyte");
|
||||||
|
|
||||||
|
const header = 16;
|
||||||
|
const fullCount = lblRaw.length - 8;
|
||||||
|
const count = Math.min(sampleCount, fullCount);
|
||||||
|
|
||||||
|
const images = new Float32Array(count * 28 * 28);
|
||||||
|
const labels = new Uint32Array(count);
|
||||||
|
|
||||||
|
for (let i = 0; i < count; i++) {
|
||||||
|
labels[i] = lblRaw[8 + i];
|
||||||
|
const src = header + i * 784;
|
||||||
|
const dst = i * 784;
|
||||||
|
for (let j = 0; j < 784; j++) {
|
||||||
|
images[dst + j] = imgRaw[src + j] / 255;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { images, labels, count };
|
||||||
|
}
|
||||||
BIN
models/t10k-images-idx3-ubyte
Normal file
BIN
models/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
models/t10k-labels-idx1-ubyte
Normal file
BIN
models/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
models/train-images-idx3-ubyte
Normal file
BIN
models/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
models/train-labels-idx1-ubyte
Normal file
BIN
models/train-labels-idx1-ubyte
Normal file
Binary file not shown.
3
package.json
Normal file
3
package.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"type":"module"
|
||||||
|
}
|
||||||
168
server.js
Normal file
168
server.js
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import http from "http";
|
||||||
|
|
||||||
|
import { readdir } from "fs/promises";
|
||||||
|
import { stat } from "fs/promises";
|
||||||
|
import { readFile } from "fs/promises";
|
||||||
|
import { join } from "path";
|
||||||
|
import { dirname } from "path";
|
||||||
|
import { fileURLToPath } from "url";
|
||||||
|
|
||||||
|
|
||||||
|
class App
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
constructor( )
|
||||||
|
{
|
||||||
|
|
||||||
|
const selfPath = fileURLToPath( import.meta.url );
|
||||||
|
|
||||||
|
this.rootPath = dirname( selfPath );
|
||||||
|
this.httpServer = null;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async start( )
|
||||||
|
{
|
||||||
|
|
||||||
|
this.httpServer = http.createServer( this.handleRequest.bind( this ) );
|
||||||
|
|
||||||
|
this.httpServer.listen( 2020 );
|
||||||
|
|
||||||
|
console.log("Server started on port http://localhost:2020/");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async handleRequest( req, res )
|
||||||
|
{
|
||||||
|
|
||||||
|
const requestedPath = decodeURI( req.url );
|
||||||
|
const fullPath = join( this.rootPath, requestedPath );
|
||||||
|
|
||||||
|
const exists = await this.checkFileExists( fullPath );
|
||||||
|
|
||||||
|
if ( !exists )
|
||||||
|
{
|
||||||
|
|
||||||
|
res.statusCode = 404;
|
||||||
|
res.end( "Not Found" );
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
const stats = await stat( fullPath );
|
||||||
|
|
||||||
|
if ( stats.isDirectory( ) )
|
||||||
|
{
|
||||||
|
|
||||||
|
const indexPath = join( fullPath, "index.html" );
|
||||||
|
const indexExists = await this.checkFileExists( indexPath );
|
||||||
|
|
||||||
|
if ( indexExists )
|
||||||
|
{
|
||||||
|
|
||||||
|
await this.sendFile( indexPath, res );
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.sendDirectoryListing( fullPath, requestedPath, res );
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.sendFile( fullPath, res );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async sendFile( path, res )
|
||||||
|
{
|
||||||
|
|
||||||
|
const contentType = this.getContentType( path );
|
||||||
|
const fileData = await readFile( path );
|
||||||
|
|
||||||
|
res.setHeader( "Content-Type", contentType );
|
||||||
|
res.statusCode = 200;
|
||||||
|
|
||||||
|
res.end( fileData );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async sendDirectoryListing( dirPath, urlPath, res )
|
||||||
|
{
|
||||||
|
|
||||||
|
const entries = await readdir( dirPath, { withFileTypes : true } );
|
||||||
|
|
||||||
|
let html = "<html><body><h1>Index of " + urlPath + "</h1><ul>";
|
||||||
|
|
||||||
|
let i = 0;
|
||||||
|
|
||||||
|
while ( i < entries.length )
|
||||||
|
{
|
||||||
|
|
||||||
|
const e = entries[ i ].name;
|
||||||
|
const link = urlPath.endsWith( "/" )
|
||||||
|
? urlPath + e
|
||||||
|
: urlPath + "/" + e;
|
||||||
|
|
||||||
|
html = html + "<li><a href=\"" + link + "\">" + e + "</a></li>";
|
||||||
|
|
||||||
|
i = i + 1;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
html = html + "</ul></body></html>";
|
||||||
|
|
||||||
|
res.setHeader( "Content-Type", "text/html" );
|
||||||
|
res.statusCode = 200;
|
||||||
|
|
||||||
|
res.end( html );
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async checkFileExists( path )
|
||||||
|
{
|
||||||
|
|
||||||
|
const exists = await stat( path )
|
||||||
|
.then( function( ) { return true; } )
|
||||||
|
.catch( function( ) { return false; } );
|
||||||
|
|
||||||
|
return exists;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
getContentType( path )
|
||||||
|
{
|
||||||
|
|
||||||
|
const lower = path.toLowerCase( );
|
||||||
|
|
||||||
|
if ( lower.endsWith( ".html" ) ) return "text/html";
|
||||||
|
if ( lower.endsWith( ".css" ) ) return "text/css";
|
||||||
|
if ( lower.endsWith( ".js" ) ) return "text/javascript";
|
||||||
|
if ( lower.endsWith( ".json" ) ) return "application/json";
|
||||||
|
if ( lower.endsWith( ".wasm" ) ) return "application/wasm";
|
||||||
|
if ( lower.endsWith( ".png" ) ) return "image/png";
|
||||||
|
if ( lower.endsWith( ".jpg" ) ) return "image/jpeg";
|
||||||
|
if ( lower.endsWith( ".jpeg" ) ) return "image/jpeg";
|
||||||
|
if ( lower.endsWith( ".gif" ) ) return "image/gif";
|
||||||
|
if ( lower.endsWith( ".svg" ) ) return "image/svg+xml";
|
||||||
|
if ( lower.endsWith( ".wgsl" ) ) return "text/plain";
|
||||||
|
if ( lower.endsWith( ".txt" ) ) return "text/plain";
|
||||||
|
|
||||||
|
return "application/octet-stream";
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const app = new App( );
|
||||||
|
|
||||||
|
await app.start( );
|
||||||
33
shaders/neural/forward_dense_layer1.wgsl
Normal file
33
shaders/neural/forward_dense_layer1.wgsl
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
@group(0) @binding(0) var<storage, read> inputImages : array<f32>;
|
||||||
|
@group(0) @binding(1) var<storage, read_write> weightLayer1 : array<f32>;
|
||||||
|
@group(0) @binding(2) var<storage, read_write> biasLayer1 : array<f32>;
|
||||||
|
@group(0) @binding(3) var<storage, read_write> hiddenActivations : array<f32>;
|
||||||
|
@group(0) @binding(4) var<uniform> layer1Sizes : vec3<u32>;
|
||||||
|
// x = inputDim, y = hiddenDim, z = batchSize
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
|
||||||
|
let index = global_id.x;
|
||||||
|
|
||||||
|
let inputDim = layer1Sizes.x;
|
||||||
|
let hiddenDim = layer1Sizes.y;
|
||||||
|
let batchSize = layer1Sizes.z;
|
||||||
|
|
||||||
|
let totalElements = batchSize * hiddenDim;
|
||||||
|
if (index >= totalElements) { return; }
|
||||||
|
|
||||||
|
let sampleIndex = index / hiddenDim;
|
||||||
|
let hiddenIndex = index % hiddenDim;
|
||||||
|
|
||||||
|
var sum : f32 = biasLayer1[hiddenIndex];
|
||||||
|
|
||||||
|
let baseInputIndex = sampleIndex * inputDim;
|
||||||
|
let baseWeightIndex = hiddenIndex * inputDim;
|
||||||
|
|
||||||
|
for (var i : u32 = 0u; i < inputDim; i++) {
|
||||||
|
sum += inputImages[baseInputIndex + i] *
|
||||||
|
weightLayer1[baseWeightIndex + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenActivations[index] = max(sum, 0.0);
|
||||||
|
}
|
||||||
33
shaders/neural/forward_dense_layer2.wgsl
Normal file
33
shaders/neural/forward_dense_layer2.wgsl
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
@group(0) @binding(0) var<storage, read> hiddenActivations : array<f32>;
|
||||||
|
@group(0) @binding(1) var<storage, read_write> weightLayer2 : array<f32>;
|
||||||
|
@group(0) @binding(2) var<storage, read_write> biasLayer2 : array<f32>;
|
||||||
|
@group(0) @binding(3) var<storage, read_write> outputLogits : array<f32>;
|
||||||
|
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
|
||||||
|
// x = hiddenDim, y = numClasses, z = batchSize
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
|
||||||
|
let index = id.x;
|
||||||
|
|
||||||
|
let hiddenDim = layer2Sizes.x;
|
||||||
|
let numClasses = layer2Sizes.y;
|
||||||
|
let batchSize = layer2Sizes.z;
|
||||||
|
|
||||||
|
let total = batchSize * numClasses;
|
||||||
|
if (index >= total) { return; }
|
||||||
|
|
||||||
|
let sampleIndex = index / numClasses;
|
||||||
|
let classIndex = index % numClasses;
|
||||||
|
|
||||||
|
var sum : f32 = biasLayer2[classIndex];
|
||||||
|
|
||||||
|
let baseHiddenIndex = sampleIndex * hiddenDim;
|
||||||
|
let baseWeightIndex = classIndex * hiddenDim;
|
||||||
|
|
||||||
|
for (var i : u32 = 0u; i < hiddenDim; i++) {
|
||||||
|
sum += hiddenActivations[baseHiddenIndex + i] *
|
||||||
|
weightLayer2[baseWeightIndex + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
outputLogits[index] = sum;
|
||||||
|
}
|
||||||
85
shaders/neural/gradients_dense_layer1.wgsl
Normal file
85
shaders/neural/gradients_dense_layer1.wgsl
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read> inputImages : array<f32>;
|
||||||
|
// [batchSize * inputDimension]
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read> gradientHidden : array<f32>;
|
||||||
|
// [batchSize * hiddenDimension]
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> gradientWeightLayer1 : array<f32>;
|
||||||
|
// [inputDimension * hiddenDimension]
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<storage, read_write> gradientBiasLayer1 : array<f32>;
|
||||||
|
// [hiddenDimension]
|
||||||
|
|
||||||
|
@group(0) @binding(4)
|
||||||
|
var<uniform> layer1Sizes : vec3<u32>;
|
||||||
|
// x = inputDimension, y = hiddenDimension, z = batchSize
|
||||||
|
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
|
||||||
|
|
||||||
|
let index = id.x;
|
||||||
|
|
||||||
|
let inputDimension = layer1Sizes.x;
|
||||||
|
let hiddenDimension = layer1Sizes.y;
|
||||||
|
let batchSize = layer1Sizes.z;
|
||||||
|
|
||||||
|
// We launch:
|
||||||
|
// inputDimension * hiddenDimension threads for weight gradients
|
||||||
|
// + hiddenDimension threads for bias gradients
|
||||||
|
let totalWeightElements = inputDimension * hiddenDimension;
|
||||||
|
let totalBiasElements = hiddenDimension;
|
||||||
|
let totalElements = totalWeightElements + totalBiasElements;
|
||||||
|
|
||||||
|
if (index >= totalElements) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Region 1: Weight gradients dW1 (0 .. totalWeightElements-1) ----
|
||||||
|
if (index < totalWeightElements) {
|
||||||
|
|
||||||
|
// Map flat index into (inputIndex, hiddenIndex)
|
||||||
|
let hiddenIndex : u32 = index / inputDimension;
|
||||||
|
let inputIndex : u32 = index % inputDimension;
|
||||||
|
|
||||||
|
var sum : f32 = 0.0;
|
||||||
|
|
||||||
|
// dW1[inputIndex, hiddenIndex] = sum over batch of
|
||||||
|
// inputImages[b, inputIndex] * gradientHidden[b, hiddenIndex]
|
||||||
|
for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
|
||||||
|
|
||||||
|
let inputValue = inputImages[ batchIndex * inputDimension + inputIndex ];
|
||||||
|
|
||||||
|
let hiddenGradient = gradientHidden[ batchIndex * hiddenDimension + hiddenIndex ];
|
||||||
|
|
||||||
|
sum = sum + inputValue * hiddenGradient;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
gradientWeightLayer1[ index ] = sum;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Region 2: Bias gradients db1 (totalWeightElements .. end) ----
|
||||||
|
let biasIndex : u32 = index - totalWeightElements;
|
||||||
|
|
||||||
|
if (biasIndex < hiddenDimension) {
|
||||||
|
|
||||||
|
var sumBias : f32 = 0.0;
|
||||||
|
|
||||||
|
// db1[hiddenIndex] = sum over batch of gradientHidden[b, hiddenIndex]
|
||||||
|
for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) {
|
||||||
|
|
||||||
|
let hiddenGradient = gradientHidden[ batchIndex * hiddenDimension + biasIndex ];
|
||||||
|
|
||||||
|
sumBias = sumBias + hiddenGradient;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
gradientBiasLayer1[ biasIndex ] = sumBias;
|
||||||
|
}
|
||||||
|
}
|
||||||
81
shaders/neural/gradients_dense_layer2.wgsl
Normal file
81
shaders/neural/gradients_dense_layer2.wgsl
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
// gradients_dense_layer2.wgsl
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read> hiddenActivations : array<f32>; // [batch * hiddenDim]
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read> gradientLogits : array<f32>; // [batch * numClasses]
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read> weightLayer2 : array<f32>; // [hiddenDim * numClasses] <-- Missing before
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<storage, read_write> gradientWeightLayer2 : array<f32>; // [hiddenDim * numClasses]
|
||||||
|
|
||||||
|
@group(0) @binding(4)
|
||||||
|
var<storage, read_write> gradientBiasLayer2 : array<f32>; // [numClasses]
|
||||||
|
|
||||||
|
@group(0) @binding(5)
|
||||||
|
var<storage, read_write> gradientHidden : array<f32>; // [batch * hiddenDim]
|
||||||
|
|
||||||
|
@group(0) @binding(6)
|
||||||
|
var<uniform> layer2Sizes : vec3<u32>;
|
||||||
|
// x = hiddenDim, y = numClasses, z = batchSize
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
|
||||||
|
|
||||||
|
let globalIndex = id.x;
|
||||||
|
|
||||||
|
let hiddenDim = layer2Sizes.x;
|
||||||
|
let numClasses = layer2Sizes.y;
|
||||||
|
let batchSize = layer2Sizes.z;
|
||||||
|
|
||||||
|
//
|
||||||
|
// (1) Compute gradients for W2 and b2
|
||||||
|
//
|
||||||
|
let totalWeights = hiddenDim * numClasses;
|
||||||
|
|
||||||
|
if (globalIndex < totalWeights) {
|
||||||
|
let classIndex = globalIndex / hiddenDim;
|
||||||
|
let hiddenIndex = globalIndex % hiddenDim;
|
||||||
|
|
||||||
|
var sum : f32 = 0.0;
|
||||||
|
|
||||||
|
for (var b : u32 = 0u; b < batchSize; b++) {
|
||||||
|
let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex];
|
||||||
|
let gradLogitValue = gradientLogits[b * numClasses + classIndex];
|
||||||
|
sum += activationValue * gradLogitValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
gradientWeightLayer2[globalIndex] = sum;
|
||||||
|
|
||||||
|
if (hiddenIndex == 0u) {
|
||||||
|
var biasSum : f32 = 0.0;
|
||||||
|
for (var b2 : u32 = 0u; b2 < batchSize; b2++) {
|
||||||
|
biasSum += gradientLogits[b2 * numClasses + classIndex];
|
||||||
|
}
|
||||||
|
gradientBiasLayer2[classIndex] = biasSum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// (2) Compute gradientHidden for layer1
|
||||||
|
//
|
||||||
|
let dhIndex = globalIndex - totalWeights;
|
||||||
|
|
||||||
|
if (dhIndex < batchSize * hiddenDim) {
|
||||||
|
let b = dhIndex / hiddenDim;
|
||||||
|
let hiddenIndex = dhIndex % hiddenDim;
|
||||||
|
|
||||||
|
var sum2 : f32 = 0.0;
|
||||||
|
|
||||||
|
for (var c : u32 = 0u; c < numClasses; c++) {
|
||||||
|
let gradLogitValue = gradientLogits[b * numClasses + c];
|
||||||
|
let weightValue = weightLayer2[c * hiddenDim + hiddenIndex];
|
||||||
|
sum2 += gradLogitValue * weightValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
gradientHidden[dhIndex] = sum2;
|
||||||
|
}
|
||||||
|
}
|
||||||
39
shaders/neural/sgd_optimizer_update.wgsl
Normal file
39
shaders/neural/sgd_optimizer_update.wgsl
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
@group(0) @binding(0) var<storage, read_write> weightLayer1 : array<f32>;
|
||||||
|
@group(0) @binding(1) var<storage, read_write> biasLayer1 : array<f32>;
|
||||||
|
@group(0) @binding(2) var<storage, read> gradientWeightLayer1 : array<f32>;
|
||||||
|
@group(0) @binding(3) var<storage, read> gradientBiasLayer1 : array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(4) var<storage, read_write> weightLayer2 : array<f32>;
|
||||||
|
@group(0) @binding(5) var<storage, read_write> biasLayer2 : array<f32>;
|
||||||
|
@group(0) @binding(6) var<storage, read> gradientWeightLayer2 : array<f32>;
|
||||||
|
@group(0) @binding(7) var<storage, read> gradientBiasLayer2 : array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(8) var<uniform> learningRate : f32;
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
|
||||||
|
let index = id.x;
|
||||||
|
|
||||||
|
if (index < arrayLength(&weightLayer1)) {
|
||||||
|
weightLayer1[index] -= gradientWeightLayer1[index] * learningRate;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let b1Index = index - arrayLength(&weightLayer1);
|
||||||
|
if (b1Index < arrayLength(&biasLayer1)) {
|
||||||
|
biasLayer1[b1Index] -= gradientBiasLayer1[b1Index] * learningRate;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let w2Index = index - arrayLength(&weightLayer1) - arrayLength(&biasLayer1);
|
||||||
|
if (w2Index < arrayLength(&weightLayer2)) {
|
||||||
|
weightLayer2[w2Index] -= gradientWeightLayer2[w2Index] * learningRate;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let b2Index = w2Index - arrayLength(&weightLayer2);
|
||||||
|
if (b2Index < arrayLength(&biasLayer2)) {
|
||||||
|
biasLayer2[b2Index] -= gradientBiasLayer2[b2Index] * learningRate;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
39
shaders/neural/softmax_cross_entropy_backward.wgsl
Normal file
39
shaders/neural/softmax_cross_entropy_backward.wgsl
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
@group(0) @binding(0) var<storage, read> outputLogits : array<f32>;
|
||||||
|
@group(0) @binding(1) var<storage, read> targetLabels : array<u32>;
|
||||||
|
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;
|
||||||
|
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>;
|
||||||
|
@group(0) @binding(4) var<uniform> layer2Sizes : vec3<u32>;
|
||||||
|
// x = hiddenDim, y = numClasses, z = batchSize
|
||||||
|
|
||||||
|
@compute @workgroup_size(64)
|
||||||
|
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
|
||||||
|
let sampleIndex = id.x;
|
||||||
|
let batchSize = layer2Sizes.z;
|
||||||
|
let numClasses = layer2Sizes.y;
|
||||||
|
|
||||||
|
if (sampleIndex >= batchSize) { return; }
|
||||||
|
|
||||||
|
let base = sampleIndex * numClasses;
|
||||||
|
let label = targetLabels[sampleIndex];
|
||||||
|
|
||||||
|
var maxLogit : f32 = -1e9;
|
||||||
|
for (var c : u32 = 0u; c < numClasses; c++) {
|
||||||
|
let v = outputLogits[base + c];
|
||||||
|
if (v > maxLogit) { maxLogit = v; }
|
||||||
|
}
|
||||||
|
|
||||||
|
var sumExp : f32 = 0.0;
|
||||||
|
for (var c : u32 = 0u; c < numClasses; c++) {
|
||||||
|
sumExp += exp(outputLogits[base + c] - maxLogit);
|
||||||
|
}
|
||||||
|
|
||||||
|
var loss : f32 = 0.0;
|
||||||
|
for (var c : u32 = 0u; c < numClasses; c++) {
|
||||||
|
let sm = exp(outputLogits[base + c] - maxLogit) / sumExp;
|
||||||
|
let oneHot = f32(c == label);
|
||||||
|
gradientLogits[base + c] = sm - oneHot;
|
||||||
|
if (oneHot == 1.0) { loss = -log(sm); }
|
||||||
|
}
|
||||||
|
|
||||||
|
crossEntropyLoss[sampleIndex] = loss;
|
||||||
|
}
|
||||||
14
source/index_old.html
Normal file
14
source/index_old.html
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>WebGPU Training Demo</title>
|
||||||
|
<base href="Classification/">
|
||||||
|
<script type="module" src="./training.js"></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>WebGPU Linear Regression Training</h1>
|
||||||
|
<button id="trainBtn">Train Model</button>
|
||||||
|
<pre id="log"></pre>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
40
source/loadMNIST.js
Normal file
40
source/loadMNIST.js
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
export async function loadBinary(url) {
|
||||||
|
const response = await fetch(url);
|
||||||
|
if (!response.ok) throw new Error("Failed to load " + url);
|
||||||
|
const buffer = await response.arrayBuffer();
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loads MNIST 28x28 grayscale images + labels
|
||||||
|
export async function loadMNIST() {
|
||||||
|
const imgBuf = await loadBinary("/models/MNIST-2k-images.bin");
|
||||||
|
const lblBuf = await loadBinary("/models/MNIST-2k-labels.bin");
|
||||||
|
|
||||||
|
const images = new Float32Array(imgBuf); // already normalized [0..1]
|
||||||
|
const labels = new Uint8Array(lblBuf);
|
||||||
|
|
||||||
|
const sampleCount = labels.length;
|
||||||
|
console.log("MNIST Loaded:", sampleCount, "samples");
|
||||||
|
|
||||||
|
return { images, labels, sampleCount };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility visualization helper for debugging
|
||||||
|
export function showImage(img784, canvas) {
|
||||||
|
const size = 28;
|
||||||
|
canvas.width = size;
|
||||||
|
canvas.height = size;
|
||||||
|
|
||||||
|
const ctx = canvas.getContext("2d");
|
||||||
|
const imgData = ctx.createImageData(size, size);
|
||||||
|
|
||||||
|
for (let i = 0; i < size * size; i++) {
|
||||||
|
const v = Math.floor(img784[i] * 255);
|
||||||
|
imgData.data[i * 4 + 0] = v;
|
||||||
|
imgData.data[i * 4 + 1] = v;
|
||||||
|
imgData.data[i * 4 + 2] = v;
|
||||||
|
imgData.data[i * 4 + 3] = 255;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.putImageData(imgData, 0, 0);
|
||||||
|
}
|
||||||
27
source/mnist_loader.js
Normal file
27
source/mnist_loader.js
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
export async function loadMNIST(device, sampleCount = 1000) {
|
||||||
|
async function loadFile(path) {
|
||||||
|
const res = await fetch(path);
|
||||||
|
return new Uint8Array(await res.arrayBuffer());
|
||||||
|
}
|
||||||
|
|
||||||
|
const imgRaw = await loadFile("../../models/train-images-idx3-ubyte");
|
||||||
|
const lblRaw = await loadFile("../../models/train-labels-idx1-ubyte");
|
||||||
|
|
||||||
|
const header = 16;
|
||||||
|
const fullCount = lblRaw.length - 8;
|
||||||
|
const count = Math.min(sampleCount, fullCount);
|
||||||
|
|
||||||
|
const images = new Float32Array(count * 28 * 28);
|
||||||
|
const labels = new Uint32Array(count);
|
||||||
|
|
||||||
|
for (let i = 0; i < count; i++) {
|
||||||
|
labels[i] = lblRaw[8 + i];
|
||||||
|
const src = header + i * 784;
|
||||||
|
const dst = i * 784;
|
||||||
|
for (let j = 0; j < 784; j++) {
|
||||||
|
images[dst + j] = imgRaw[src + j] / 255;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { images, labels, count };
|
||||||
|
}
|
||||||
22
source/train_images.html
Normal file
22
source/train_images.html
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>WebGPU MNIST Training (Batch 64)</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: sans-serif; background:#111; color:#eee; }
|
||||||
|
#preview { display:flex; gap:12px; margin-top:1rem; flex-wrap:wrap; }
|
||||||
|
.item { text-align:center; font-size:14px; }
|
||||||
|
canvas { image-rendering: pixelated; background:#000; border:1px solid #555; }
|
||||||
|
button { font-size:16px; padding:8px 14px; cursor:pointer; }
|
||||||
|
#log { white-space:pre; font-size:14px; max-height:240px; overflow-y:auto; background:#222; padding:8px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h2>WebGPU MNIST Training — Batch 64</h2>
|
||||||
|
<button id="trainBtn">Train</button>
|
||||||
|
<div id="log"></div>
|
||||||
|
<div id="preview"></div>
|
||||||
|
<script type="module" src="./training2.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
111
source/training.js
Normal file
111
source/training.js
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import Shader from "../../framework/WebGpu.js";
|
||||||
|
|
||||||
|
const N = 32; // training samples
|
||||||
|
const inputSize = 3;
|
||||||
|
const outputSize = 1;
|
||||||
|
const lr = 0.01;
|
||||||
|
const epochs = 50;
|
||||||
|
|
||||||
|
function generateTrainingData() {
|
||||||
|
const X = new Float32Array(N * inputSize);
|
||||||
|
const y = new Float32Array(N);
|
||||||
|
|
||||||
|
for (let i = 0; i < N; i++) {
|
||||||
|
const x1 = Math.random() * 2 - 1;
|
||||||
|
const x2 = Math.random() * 2 - 1;
|
||||||
|
const x3 = Math.random() * 2 - 1;
|
||||||
|
|
||||||
|
const label = 3 * x1 + 2 * x2 - 1 * x3 + 0.5;
|
||||||
|
|
||||||
|
X[i * 3 + 0] = x1;
|
||||||
|
X[i * 3 + 1] = x2;
|
||||||
|
X[i * 3 + 2] = x3;
|
||||||
|
|
||||||
|
y[i] = label;
|
||||||
|
}
|
||||||
|
return { X, y };
|
||||||
|
}
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
if (!navigator.gpu) {
|
||||||
|
alert("WebGPU not supported!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const adapter = await navigator.gpu.requestAdapter();
|
||||||
|
const device = await adapter.requestDevice();
|
||||||
|
|
||||||
|
const { X, y } = generateTrainingData();
|
||||||
|
|
||||||
|
// --- 1: Create shaders ---
|
||||||
|
const forward = new Shader(device);
|
||||||
|
await forward.setup("../../shaders/neural/forward.wgsl");
|
||||||
|
|
||||||
|
const lossBackward = new Shader(device);
|
||||||
|
await lossBackward.setup("../../shaders/neural/loss_backward.wgsl");
|
||||||
|
|
||||||
|
const weightGrad = new Shader(device);
|
||||||
|
await weightGrad.setup("../../shaders/neural/weight_grad.wgsl");
|
||||||
|
|
||||||
|
const sgdUpdate = new Shader(device);
|
||||||
|
await sgdUpdate.setup("../../shaders/neural/sgd_update.wgsl");
|
||||||
|
|
||||||
|
// --- 2: Initialize model parameters ---
|
||||||
|
let W = new Float32Array(inputSize).map(() => Math.random() * 0.1);
|
||||||
|
let b = new Float32Array([0]);
|
||||||
|
|
||||||
|
forward.setVariable("input", X);
|
||||||
|
forward.setVariable("W", W);
|
||||||
|
forward.setVariable("b", b);
|
||||||
|
forward.setVariable("N", N);
|
||||||
|
forward.setVariable("y", y);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// link GPU buffers between shaders
|
||||||
|
lossBackward.setBuffer("y", forward.getBuffer("y"));
|
||||||
|
lossBackward.setBuffer("pred", forward.getBuffer("pred"));
|
||||||
|
lossBackward.setVariable("N", N);
|
||||||
|
|
||||||
|
weightGrad.setBuffer("input", forward.getBuffer("input"));
|
||||||
|
weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred"));
|
||||||
|
weightGrad.setVariable("N", N);
|
||||||
|
|
||||||
|
sgdUpdate.setBuffer("W", forward.getBuffer("W"));
|
||||||
|
sgdUpdate.setBuffer("b", forward.getBuffer("b"));
|
||||||
|
sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW"));
|
||||||
|
sgdUpdate.setBuffer("db", weightGrad.getBuffer("db"));
|
||||||
|
sgdUpdate.setVariable("lr", lr);
|
||||||
|
|
||||||
|
const log = document.getElementById("log");
|
||||||
|
|
||||||
|
document.getElementById("trainBtn").onclick = async () => {
|
||||||
|
for (let epoch = 0; epoch < epochs; epoch++) {
|
||||||
|
await forward.execute(N);
|
||||||
|
await lossBackward.execute(N);
|
||||||
|
await weightGrad.execute(N);
|
||||||
|
await sgdUpdate.execute(inputSize);
|
||||||
|
|
||||||
|
const lossArr = await lossBackward.debugBuffer("loss");
|
||||||
|
const loss = lossArr.reduce((a, b) => a + b, 0) / N;
|
||||||
|
|
||||||
|
log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`;
|
||||||
|
|
||||||
|
// Debug learned weights every 5 epochs
|
||||||
|
if (epoch % 5 === 0) {
|
||||||
|
const w = await forward.debugBuffer("W");
|
||||||
|
console.log(
|
||||||
|
`Epoch ${epoch}: W = ${w.map(v => v.toFixed(2)).join(", ")}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const learnedW = await forward.debugBuffer("W");
|
||||||
|
const learnedB = await forward.debugBuffer("b");
|
||||||
|
|
||||||
|
log.textContent += `\nLearned W: ${learnedW.map(v=>v.toFixed(3)).join(", ")}\n`;
|
||||||
|
log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
223
source/training2.js
Normal file
223
source/training2.js
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import Shader from "../../framework/WebGpu.js";
|
||||||
|
import { loadMNIST } from "./mnist_loader.js";
|
||||||
|
|
||||||
|
const batchSize = 1000;
|
||||||
|
const inputDimension = 28 * 28;
|
||||||
|
const hiddenDimension = 128;
|
||||||
|
const numberOfClasses = 10;
|
||||||
|
const learningRateValue = 0.05;
|
||||||
|
const numberOfEpochs = 20;
|
||||||
|
|
||||||
|
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
|
||||||
|
const context = canvas.getContext("2d");
|
||||||
|
const image = context.createImageData(28, 28);
|
||||||
|
|
||||||
|
for (let i = 0; i < inputDimension; i++) {
|
||||||
|
const grayscale = pixelData[i] * 255;
|
||||||
|
const pixelIndex = i * 4;
|
||||||
|
image.data[pixelIndex + 0] = grayscale;
|
||||||
|
image.data[pixelIndex + 1] = grayscale;
|
||||||
|
image.data[pixelIndex + 2] = grayscale;
|
||||||
|
image.data[pixelIndex + 3] = 255;
|
||||||
|
}
|
||||||
|
|
||||||
|
context.putImageData(image, 0, 0);
|
||||||
|
canvas.nextLabel.textContent =
|
||||||
|
`GT: ${trueLabel} → Pred: ${predictedLabel}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
if (!navigator.gpu) {
|
||||||
|
alert("WebGPU not supported");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const adapter = await navigator.gpu.requestAdapter();
|
||||||
|
const device = await adapter.requestDevice();
|
||||||
|
|
||||||
|
// Load MNIST subset → inputImages & targetLabels are GPU buffers
|
||||||
|
const mnist = await loadMNIST(device, batchSize);
|
||||||
|
const inputImages = mnist.images;
|
||||||
|
const targetLabels = mnist.labels;
|
||||||
|
|
||||||
|
const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
|
||||||
|
const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Create Shader Instances ===
|
||||||
|
//
|
||||||
|
const forwardDenseLayer1 = new Shader(device);
|
||||||
|
await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const forwardDenseLayer2 = new Shader(device);
|
||||||
|
await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const softmaxCrossEntropyBackward = new Shader(device);
|
||||||
|
await softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer2 = new Shader(device);
|
||||||
|
await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer1 = new Shader(device);
|
||||||
|
await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const sgdOptimizerUpdate = new Shader(device);
|
||||||
|
await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Model Parameters ===
|
||||||
|
//
|
||||||
|
const weightLayer1 = new Float32Array(inputDimension * hiddenDimension)
|
||||||
|
.map(() => (Math.random() * Math.sqrt(2.0 / inputDimension)) - (Math.sqrt(2.0 / inputDimension) / 2));
|
||||||
|
|
||||||
|
|
||||||
|
const biasLayer1 = new Float32Array(hiddenDimension);
|
||||||
|
|
||||||
|
const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses)
|
||||||
|
.map(() => (Math.random() * Math.sqrt(2.0 / hiddenDimension)) - (Math.sqrt(2.0 / hiddenDimension) / 2));
|
||||||
|
|
||||||
|
|
||||||
|
const biasLayer2 = new Float32Array(numberOfClasses);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Bind Buffers to Shaders ===
|
||||||
|
//
|
||||||
|
forwardDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
|
||||||
|
softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
|
||||||
|
softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null);
|
||||||
|
gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null);
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientHidden", null);
|
||||||
|
gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null);
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null);
|
||||||
|
gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === UI ===
|
||||||
|
//
|
||||||
|
const logOutput = document.getElementById("log");
|
||||||
|
const previewContainer = document.getElementById("preview");
|
||||||
|
const numberOfPreviewImages = 6;
|
||||||
|
|
||||||
|
const previewIndices = Array.from({ length: numberOfPreviewImages },
|
||||||
|
() => Math.floor(Math.random() * batchSize));
|
||||||
|
|
||||||
|
const previewCanvases = previewIndices.map(() => {
|
||||||
|
const container = document.createElement("div");
|
||||||
|
container.className = "previewItem";
|
||||||
|
const canvas = document.createElement("canvas");
|
||||||
|
canvas.width = canvas.height = 28;
|
||||||
|
const label = document.createElement("div");
|
||||||
|
canvas.nextLabel = label;
|
||||||
|
container.appendChild(canvas);
|
||||||
|
container.appendChild(label);
|
||||||
|
previewContainer.appendChild(container);
|
||||||
|
return canvas;
|
||||||
|
});
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Training ===
|
||||||
|
//
|
||||||
|
document.getElementById("trainBtn").onclick = async () => {
|
||||||
|
for (let epoch = 0; epoch < numberOfEpochs; epoch++) {
|
||||||
|
|
||||||
|
await forwardDenseLayer1.execute(batchSize);
|
||||||
|
await forwardDenseLayer2.execute(batchSize);
|
||||||
|
await softmaxCrossEntropyBackward.execute(batchSize);
|
||||||
|
|
||||||
|
const wgSize = 64;
|
||||||
|
|
||||||
|
// Grad Layer 2
|
||||||
|
await gradientsDenseLayer2.execute(
|
||||||
|
Math.ceil((batchSize * numberOfClasses + hiddenDimension * numberOfClasses) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Grad Layer 1
|
||||||
|
await gradientsDenseLayer1.execute(
|
||||||
|
Math.ceil((batchSize * hiddenDimension + inputDimension * hiddenDimension) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// SGD
|
||||||
|
await sgdOptimizerUpdate.execute(
|
||||||
|
Math.ceil((weightLayer1.length + biasLayer1.length +
|
||||||
|
weightLayer2.length + biasLayer2.length) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Loss
|
||||||
|
const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
|
||||||
|
const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize;
|
||||||
|
logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;
|
||||||
|
|
||||||
|
// Accuracy
|
||||||
|
const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
|
||||||
|
let correctCount = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < batchSize; i++) {
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[i*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
if (bestClass === targetLabels[i]) correctCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const accuracy = correctCount / batchSize;
|
||||||
|
logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`;
|
||||||
|
|
||||||
|
// Update preview images
|
||||||
|
for (let k = 0; k < numberOfPreviewImages; k++) {
|
||||||
|
const index = previewIndices[k];
|
||||||
|
const image = inputImages.subarray(
|
||||||
|
index*inputDimension, (index+1)*inputDimension
|
||||||
|
);
|
||||||
|
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[index*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
|
||||||
|
drawMNISTPreview(
|
||||||
|
previewCanvases[k],
|
||||||
|
image,
|
||||||
|
targetLabels[index],
|
||||||
|
bestClass
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
212
source/training2_working.js
Normal file
212
source/training2_working.js
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import Shader from "../../framework/WebGpu.js";
|
||||||
|
import { loadMNIST } from "./mnist_loader.js";
|
||||||
|
|
||||||
|
const batchSize = 1000;
|
||||||
|
const inputDimension = 28 * 28;
|
||||||
|
const hiddenDimension = 128;
|
||||||
|
const numberOfClasses = 10;
|
||||||
|
const learningRateValue = 0.05;
|
||||||
|
const numberOfEpochs = 20;
|
||||||
|
|
||||||
|
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
|
||||||
|
const context = canvas.getContext("2d");
|
||||||
|
const image = context.createImageData(28, 28);
|
||||||
|
|
||||||
|
for (let i = 0; i < inputDimension; i++) {
|
||||||
|
const grayscale = pixelData[i] * 255;
|
||||||
|
const pixelIndex = i * 4;
|
||||||
|
image.data[pixelIndex + 0] = grayscale;
|
||||||
|
image.data[pixelIndex + 1] = grayscale;
|
||||||
|
image.data[pixelIndex + 2] = grayscale;
|
||||||
|
image.data[pixelIndex + 3] = 255;
|
||||||
|
}
|
||||||
|
|
||||||
|
context.putImageData(image, 0, 0);
|
||||||
|
canvas.nextLabel.textContent =
|
||||||
|
`GT: ${trueLabel} → Pred: ${predictedLabel}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
if (!navigator.gpu) {
|
||||||
|
alert("WebGPU not supported");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const adapter = await navigator.gpu.requestAdapter();
|
||||||
|
const device = await adapter.requestDevice();
|
||||||
|
|
||||||
|
// Load MNIST subset → inputImages & targetLabels are GPU buffers
|
||||||
|
const mnist = await loadMNIST(device, batchSize);
|
||||||
|
const inputImages = mnist.images;
|
||||||
|
const targetLabels = mnist.labels;
|
||||||
|
|
||||||
|
const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
|
||||||
|
const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Create Shader Instances ===
|
||||||
|
//
|
||||||
|
const forwardDenseLayer1 = new Shader(device);
|
||||||
|
await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const forwardDenseLayer2 = new Shader(device);
|
||||||
|
await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const softmaxCrossEntropyBackward = new Shader(device);
|
||||||
|
await softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer2 = new Shader(device);
|
||||||
|
await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer1 = new Shader(device);
|
||||||
|
await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const sgdOptimizerUpdate = new Shader(device);
|
||||||
|
await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Model Parameters ===
|
||||||
|
//
|
||||||
|
const weightLayer1 = new Float32Array(inputDimension * hiddenDimension).map(() => (Math.random() - 0.5) * 0.05);
|
||||||
|
const biasLayer1 = new Float32Array(hiddenDimension);
|
||||||
|
|
||||||
|
const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses).map(() => (Math.random() - 0.5) * 0.05);
|
||||||
|
const biasLayer2 = new Float32Array(numberOfClasses);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Bind Buffers to Shaders ===
|
||||||
|
//
|
||||||
|
forwardDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
|
||||||
|
softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
|
||||||
|
softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null);
|
||||||
|
gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null);
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientHidden", null);
|
||||||
|
gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null);
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null);
|
||||||
|
gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === UI ===
|
||||||
|
//
|
||||||
|
const logOutput = document.getElementById("log");
|
||||||
|
const previewContainer = document.getElementById("preview");
|
||||||
|
const numberOfPreviewImages = 6;
|
||||||
|
|
||||||
|
const previewIndices = Array.from({ length: numberOfPreviewImages },
|
||||||
|
() => Math.floor(Math.random() * batchSize));
|
||||||
|
|
||||||
|
const previewCanvases = previewIndices.map(() => {
|
||||||
|
const container = document.createElement("div");
|
||||||
|
container.className = "previewItem";
|
||||||
|
const canvas = document.createElement("canvas");
|
||||||
|
canvas.width = canvas.height = 28;
|
||||||
|
const label = document.createElement("div");
|
||||||
|
canvas.nextLabel = label;
|
||||||
|
container.appendChild(canvas);
|
||||||
|
container.appendChild(label);
|
||||||
|
previewContainer.appendChild(container);
|
||||||
|
return canvas;
|
||||||
|
});
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Training ===
|
||||||
|
//
|
||||||
|
document.getElementById("trainBtn").onclick = async () => {
|
||||||
|
for (let epoch = 0; epoch < numberOfEpochs; epoch++) {
|
||||||
|
|
||||||
|
await forwardDenseLayer1.execute(batchSize);
|
||||||
|
await forwardDenseLayer2.execute(batchSize);
|
||||||
|
await softmaxCrossEntropyBackward.execute(batchSize);
|
||||||
|
|
||||||
|
await gradientsDenseLayer2.execute(
|
||||||
|
hiddenDimension * numberOfClasses + numberOfClasses + batchSize * hiddenDimension
|
||||||
|
);
|
||||||
|
|
||||||
|
await gradientsDenseLayer1.execute(
|
||||||
|
inputDimension * hiddenDimension + hiddenDimension
|
||||||
|
);
|
||||||
|
|
||||||
|
await sgdOptimizerUpdate.execute(
|
||||||
|
weightLayer1.length + biasLayer1.length +
|
||||||
|
weightLayer2.length + biasLayer2.length
|
||||||
|
);
|
||||||
|
|
||||||
|
// Loss
|
||||||
|
const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
|
||||||
|
const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize;
|
||||||
|
logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;
|
||||||
|
|
||||||
|
// Accuracy
|
||||||
|
const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
|
||||||
|
let correctCount = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < batchSize; i++) {
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[i*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
if (bestClass === targetLabels[i]) correctCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const accuracy = correctCount / batchSize;
|
||||||
|
logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`;
|
||||||
|
|
||||||
|
// Update preview images
|
||||||
|
for (let k = 0; k < numberOfPreviewImages; k++) {
|
||||||
|
const index = previewIndices[k];
|
||||||
|
const image = inputImages.subarray(
|
||||||
|
index*inputDimension, (index+1)*inputDimension
|
||||||
|
);
|
||||||
|
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[index*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
|
||||||
|
drawMNISTPreview(
|
||||||
|
previewCanvases[k],
|
||||||
|
image,
|
||||||
|
targetLabels[index],
|
||||||
|
bestClass
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
223
training.js
Normal file
223
training.js
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import Shader from "./framework/WebGpu.js";
|
||||||
|
import { loadMNIST } from "./mnist_loader.js";
|
||||||
|
|
||||||
|
const batchSize = 1000;
|
||||||
|
const inputDimension = 28 * 28;
|
||||||
|
const hiddenDimension = 128;
|
||||||
|
const numberOfClasses = 10;
|
||||||
|
const learningRateValue = 0.05;
|
||||||
|
const numberOfEpochs = 20;
|
||||||
|
|
||||||
|
function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) {
|
||||||
|
const context = canvas.getContext("2d");
|
||||||
|
const image = context.createImageData(28, 28);
|
||||||
|
|
||||||
|
for (let i = 0; i < inputDimension; i++) {
|
||||||
|
const grayscale = pixelData[i] * 255;
|
||||||
|
const pixelIndex = i * 4;
|
||||||
|
image.data[pixelIndex + 0] = grayscale;
|
||||||
|
image.data[pixelIndex + 1] = grayscale;
|
||||||
|
image.data[pixelIndex + 2] = grayscale;
|
||||||
|
image.data[pixelIndex + 3] = 255;
|
||||||
|
}
|
||||||
|
|
||||||
|
context.putImageData(image, 0, 0);
|
||||||
|
canvas.nextLabel.textContent =
|
||||||
|
`GT: ${trueLabel} → Pred: ${predictedLabel}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
if (!navigator.gpu) {
|
||||||
|
alert("WebGPU not supported");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const adapter = await navigator.gpu.requestAdapter();
|
||||||
|
const device = await adapter.requestDevice();
|
||||||
|
|
||||||
|
// Load MNIST subset → inputImages & targetLabels are GPU buffers
|
||||||
|
const mnist = await loadMNIST(device, batchSize);
|
||||||
|
const inputImages = mnist.images;
|
||||||
|
const targetLabels = mnist.labels;
|
||||||
|
|
||||||
|
const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]);
|
||||||
|
const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Create Shader Instances ===
|
||||||
|
//
|
||||||
|
const forwardDenseLayer1 = new Shader(device);
|
||||||
|
await forwardDenseLayer1.setup("./shaders/neural/forward_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const forwardDenseLayer2 = new Shader(device);
|
||||||
|
await forwardDenseLayer2.setup("./shaders/neural/forward_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const softmaxCrossEntropyBackward = new Shader(device);
|
||||||
|
await softmaxCrossEntropyBackward.setup("./shaders/neural/softmax_cross_entropy_backward.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer2 = new Shader(device);
|
||||||
|
await gradientsDenseLayer2.setup("./shaders/neural/gradients_dense_layer2.wgsl");
|
||||||
|
|
||||||
|
const gradientsDenseLayer1 = new Shader(device);
|
||||||
|
await gradientsDenseLayer1.setup("./shaders/neural/gradients_dense_layer1.wgsl");
|
||||||
|
|
||||||
|
const sgdOptimizerUpdate = new Shader(device);
|
||||||
|
await sgdOptimizerUpdate.setup("./shaders/neural/sgd_optimizer_update.wgsl");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Model Parameters ===
|
||||||
|
//
|
||||||
|
const weightLayer1 = new Float32Array(inputDimension * hiddenDimension)
|
||||||
|
.map(() => (Math.random() * Math.sqrt(2.0 / inputDimension)) - (Math.sqrt(2.0 / inputDimension) / 2));
|
||||||
|
|
||||||
|
|
||||||
|
const biasLayer1 = new Float32Array(hiddenDimension);
|
||||||
|
|
||||||
|
const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses)
|
||||||
|
.map(() => (Math.random() * Math.sqrt(2.0 / hiddenDimension)) - (Math.sqrt(2.0 / hiddenDimension) / 2));
|
||||||
|
|
||||||
|
|
||||||
|
const biasLayer2 = new Float32Array(numberOfClasses);
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Bind Buffers to Shaders ===
|
||||||
|
//
|
||||||
|
forwardDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
forwardDenseLayer1.setVariable("weightLayer1", weightLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("biasLayer1", biasLayer1);
|
||||||
|
forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
forwardDenseLayer2.setVariable("weightLayer2", weightLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("biasLayer2", biasLayer2);
|
||||||
|
forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits"));
|
||||||
|
softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels);
|
||||||
|
softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations"));
|
||||||
|
gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null);
|
||||||
|
gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null);
|
||||||
|
//gradientsDenseLayer2.setBuffer("gradientHidden", null);
|
||||||
|
gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform");
|
||||||
|
|
||||||
|
gradientsDenseLayer1.setVariable("inputImages", inputImages);
|
||||||
|
gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden"));
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null);
|
||||||
|
//gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null);
|
||||||
|
gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform");
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2"));
|
||||||
|
sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2"));
|
||||||
|
|
||||||
|
sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform");
|
||||||
|
|
||||||
|
//
|
||||||
|
// === UI ===
|
||||||
|
//
|
||||||
|
const logOutput = document.getElementById("log");
|
||||||
|
const previewContainer = document.getElementById("preview");
|
||||||
|
const numberOfPreviewImages = 6;
|
||||||
|
|
||||||
|
const previewIndices = Array.from({ length: numberOfPreviewImages },
|
||||||
|
() => Math.floor(Math.random() * batchSize));
|
||||||
|
|
||||||
|
const previewCanvases = previewIndices.map(() => {
|
||||||
|
const container = document.createElement("div");
|
||||||
|
container.className = "previewItem";
|
||||||
|
const canvas = document.createElement("canvas");
|
||||||
|
canvas.width = canvas.height = 28;
|
||||||
|
const label = document.createElement("div");
|
||||||
|
canvas.nextLabel = label;
|
||||||
|
container.appendChild(canvas);
|
||||||
|
container.appendChild(label);
|
||||||
|
previewContainer.appendChild(container);
|
||||||
|
return canvas;
|
||||||
|
});
|
||||||
|
|
||||||
|
//
|
||||||
|
// === Training ===
|
||||||
|
//
|
||||||
|
document.getElementById("trainBtn").onclick = async () => {
|
||||||
|
for (let epoch = 0; epoch < numberOfEpochs; epoch++) {
|
||||||
|
|
||||||
|
await forwardDenseLayer1.execute(batchSize);
|
||||||
|
await forwardDenseLayer2.execute(batchSize);
|
||||||
|
await softmaxCrossEntropyBackward.execute(batchSize);
|
||||||
|
|
||||||
|
const wgSize = 64;
|
||||||
|
|
||||||
|
// Grad Layer 2
|
||||||
|
await gradientsDenseLayer2.execute(
|
||||||
|
Math.ceil((batchSize * numberOfClasses + hiddenDimension * numberOfClasses) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Grad Layer 1
|
||||||
|
await gradientsDenseLayer1.execute(
|
||||||
|
Math.ceil((batchSize * hiddenDimension + inputDimension * hiddenDimension) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// SGD
|
||||||
|
await sgdOptimizerUpdate.execute(
|
||||||
|
Math.ceil((weightLayer1.length + biasLayer1.length +
|
||||||
|
weightLayer2.length + biasLayer2.length) / wgSize)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Loss
|
||||||
|
const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss");
|
||||||
|
const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize;
|
||||||
|
logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`;
|
||||||
|
|
||||||
|
// Accuracy
|
||||||
|
const logits = await forwardDenseLayer2.debugBuffer("outputLogits");
|
||||||
|
let correctCount = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < batchSize; i++) {
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[i*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
if (bestClass === targetLabels[i]) correctCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const accuracy = correctCount / batchSize;
|
||||||
|
logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`;
|
||||||
|
|
||||||
|
// Update preview images
|
||||||
|
for (let k = 0; k < numberOfPreviewImages; k++) {
|
||||||
|
const index = previewIndices[k];
|
||||||
|
const image = inputImages.subarray(
|
||||||
|
index*inputDimension, (index+1)*inputDimension
|
||||||
|
);
|
||||||
|
|
||||||
|
let bestValue = -1e9;
|
||||||
|
let bestClass = 0;
|
||||||
|
for (let c = 0; c < numberOfClasses; c++) {
|
||||||
|
const value = logits[index*numberOfClasses+c];
|
||||||
|
if (value > bestValue) { bestValue = value; bestClass = c; }
|
||||||
|
}
|
||||||
|
|
||||||
|
drawMNISTPreview(
|
||||||
|
previewCanvases[k],
|
||||||
|
image,
|
||||||
|
targetLabels[index],
|
||||||
|
bestClass
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
Reference in New Issue
Block a user