// commit 22c29a893995554b11a75b9cb79353de42564673
// Author: kaj dijkstra
// Date: Wed Nov 19 10:42:46 2025 +0100
// First commit.
// diff --git a/framework/Camera.js b/framework/Camera.js new file mode 100644 index 0000000..f59d717
// --- /dev/null +++ b/framework/Camera.js @@ -0,0 +1,83 @@
import Vector3 from "./Vector3.js";
import Matrix4 from "./Matrix4.js";

/**
 * Orbit camera: the eye sits on a sphere of radius `distance` around `target`,
 * positioned by `yaw` (rotation about the Y axis) and `pitch` (elevation).
 * `viewMatrix` is rebuilt with Matrix4.lookAt whenever one of them changes.
 */
export default class Camera {

    eye = new Vector3();
    target = new Vector3();
    up = new Vector3( 0, 1, 0 );

    yaw = 0;
    pitch = 0;
    fovRadians = Math.PI / 4;
    near = 0.1;
    far = 3000.0;
    distance = 10;
    viewMatrix = new Float32Array( 16 );

    /**
     * @param {number[]} eye    Initial eye position as [x, y, z].
     * @param {number[]} target Point the camera looks at and orbits around, as [x, y, z].
     * @param {number[]} up     Up direction as [x, y, z].
     */
    constructor( eye = [0, 0, 5], target = [0, 0, 0], up = [0, 1, 0] ) {

        this.eye = new Vector3( ...eye );
        this.target = new Vector3( ...target );
        this.up = new Vector3( ...up );

        const offset = Vector3.subtract( this.eye, this.target );

        this.distance = offset.length();

        // Derive the orbit angles from the initial offset so that the first
        // update() keeps the eye where the caller put it. This is the inverse
        // of the formulas in update(); the default eye [0, 0, 5] still gives
        // yaw = 0 and pitch = 0. When eye and target coincide the angles are
        // undefined, so the defaults are kept.
        if ( this.distance > 0 ) {

            this.yaw = Math.atan2( offset.x, offset.z );

            // Clamp the ratio so rounding error can never push asin out of its domain.
            const sinPitch = Math.max( -1, Math.min( 1, offset.y / this.distance ) );

            this.pitch = Math.asin( sinPitch );

        }

        this.viewMatrix = Matrix4.lookAt( this.eye, this.target, this.up );

    }

    /**
     * Recomputes `eye` from target, distance, yaw and pitch, then rebuilds `viewMatrix`.
     */
    update() {

        const horizontal = this.distance * Math.cos( this.pitch );

        this.eye = new Vector3(
            horizontal * Math.sin( this.yaw ) + this.target.x,
            this.distance * Math.sin( this.pitch ) + this.target.y,
            horizontal * Math.cos( this.yaw ) + this.target.z
        );

        this.viewMatrix = Matrix4.lookAt( this.eye, this.target, this.up );

    }

    /**
     * @returns {Float32Array} The view matrix produced by the last lookAt call.
     */
    getViewMatrix() {

        return this.viewMatrix;

    }

    /**
     * Adds deltaYaw to yaw and subtracts deltaPitch from pitch, then updates.
     * Pitch is kept 0.01 rad short of +/- PI/2 so the view direction never
     * becomes parallel to the default up vector.
     *
     * @param {number} deltaYaw   Change in yaw, in radians.
     * @param {number} deltaPitch Change in pitch, in radians (subtracted).
     */
    rotate( deltaYaw, deltaPitch ) {

        const maxPitch = Math.PI / 2 - 0.01;

        this.yaw += deltaYaw;
        this.pitch -= deltaPitch;

        if ( this.pitch > maxPitch ) this.pitch = maxPitch;
        if ( this.pitch < -maxPitch ) this.pitch = -maxPitch;

        this.update();

    }

    /**
     * Changes the orbit distance by delta (never below 0.1) and updates.
     *
     * @param {number} delta Amount added to the current distance.
     */
    zoom( delta ) {

        this.distance += delta;

        if ( this.distance < 0.1 ) this.distance = 0.1;

        this.update();

    }

    /**
     * Moves the orbit centre and updates; the eye keeps its distance and angles.
     *
     * @param {number[]} target New target as [x, y, z].
     */
    setTarget( target ) {

        this.target = new Vector3( ...target );

        this.update();

    }
}
// diff --git a/framework/Matrix4.js b/framework/Matrix4.js new file mode 100644 index 0000000..83936e7 ---
/dev/null +++ b/framework/Matrix4.js @@ -0,0 +1,125 @@ +import Vector3 from "./Vector3.js"; + +/* Static helpers for 4x4 matrices stored as 16-element Float32Arrays. */export default class Matrix4 { + + /* Builds a view matrix: zAxis is the normalized ( eye minus target ), xAxis the normalized cross( up, zAxis ), yAxis is cross( zAxis, xAxis ). The last four entries are the negated dot products of each axis with eye, followed by 1. */static lookAt( eye, target, up ) { + + const zAxis = Vector3.normalize( Vector3.subtract( eye, target ) ); + + const xAxis = Vector3.normalize( Vector3.cross( up, zAxis ) ); + + const yAxis = Vector3.cross( zAxis, xAxis ); + + return new Float32Array([ + xAxis.x, yAxis.x, zAxis.x, 0, + xAxis.y, yAxis.y, zAxis.y, 0, + xAxis.z, yAxis.z, zAxis.z, 0, + -Vector3.dot( xAxis, eye ), -Vector3.dot( yAxis, eye ), -Vector3.dot( zAxis, eye ), 1, + ]); + } + + /* Returns matrix[ index * 4 ], matrix[ index * 4 + 1 ] and matrix[ index * 4 + 2 ] as a Vector3. */static getColumn( matrix, index ) { + const i = index * 4; + + return new Vector3( + matrix[ i + 0 ], + matrix[ i + 1 ], + matrix[ i + 2 ] + ); + } + + /* Calls perspective() with camera.fovRadians, the canvas width-to-height ratio, camera.near and camera.far. */static createProjectionMatrix( camera, canvas ) { + return Matrix4.perspective( + camera.fovRadians, + canvas.width / canvas.height, + camera.near, + camera.far + ); + } + + /* Returns the inverse of the 16-element matrix m in a new Float32Array, or null when the determinant is exactly 0. a0..a5 are 2x2 determinants built from elements 0-7 and b0..b5 from elements 8-15; det and every output entry are combinations of them scaled by 1 / det. */static invert( m ) { + const out = new Float32Array(16); + + const m00 = m[0], m01 = m[1], m02 = m[2], m03 = m[3]; + const m10 = m[4], m11 = m[5], m12 = m[6], m13 = m[7]; + const m20 = m[8], m21 = m[9], m22 = m[10], m23 = m[11]; + const m30 = m[12], m31 = m[13], m32 = m[14], m33 = m[15]; + + const a0 = m00 * m11 - m01 * m10; + const a1 = m00 * m12 - m02 * m10; + const a2 = m00 * m13 - m03 * m10; + const a3 = m01 * m12 - m02 * m11; + const a4 = m01 * m13 - m03 * m11; + const a5 = m02 * m13 - m03 * m12; + + const b0 = m20 * m31 - m21 * m30; + const b1 = m20 * m32 - m22 * m30; + const b2 = m20 * m33 - m23 * m30; + const b3 = m21 * m32 - m22 * m31; + const b4 = m21 * m33 - m23 * m31; + const b5 = m22 * m33 - m23 * m32; + + const det = a0 * b5 - a1 * b4 + a2 * b3 + a3 * b2 - a4 * b1 + a5 * b0; + + /* Only an exact zero is rejected; no tolerance is applied. */if (det === 0) return null; + + const invDet = 1 / det; + + out[0] = ( m11 * b5 - m12 * b4 + m13 * b3) * invDet; + out[1] = (-m01 * b5 + m02 * b4 - m03 * b3) * invDet; + out[2] = ( m31 * a5 - m32 * a4 + m33 * a3) * invDet; + out[3] = (-m21 * a5 + m22 * a4 - m23 * a3) * invDet; + + 
out[4] = (-m10 * b5 + m12 * b2 - m13 * b1) * invDet; + out[5] = ( m00 * b5 - m02 * b2 + m03 * b1) * invDet; + out[6] = (-m30 * a5 + m32 * a2 - m33 * a1) * invDet; + out[7] = ( m20 * a5 - m22 * a2 + m23 * a1) * invDet; + + out[8] = ( m10 * b4 - m11 * b2 + m13 * b0) * invDet; + out[9] = (-m00 * b4 + m01 * b2 - m03 * b0) * invDet; + out[10] = ( m30 * a4 - m31 * a2 + m33 * a0) * invDet; + out[11] = (-m20 * a4 + m21 * a2 - m23 * a0) * invDet; + + out[12] = (-m10 * b3 + m11 * b1 - m12 * b0) * invDet; + out[13] = ( m00 * b3 - m01 * b1 + m02 * b0) * invDet; + out[14] = (-m30 * a3 + m31 * a1 - m32 * a0) * invDet; + out[15] = ( m20 * a3 - m21 * a1 + m22 * a0) * invDet; + + return out; + } + + /* Perspective projection matrix built from f = 1 / tan( fovRadians / 2 ) and nf = 1 / ( near - far ); returns a new 16-element Float32Array. */static perspective( fovRadians, aspect, near, far ) { + + const f = 1.0 / Math.tan( fovRadians / 2 ); + + const nf = 1 / ( near - far ); + + return new Float32Array([ + f / aspect, 0, 0, 0, + 0, f, 0, 0, + 0, 0, (far + near) * nf, -1, + 0, 0, (2 * far * near) * nf, 0, + ]); + } + + /* Matrix product of a and b: out[ col * 4 + row ] is the sum over k of a[ k * 4 + row ] * b[ col * 4 + k ]. Returns a new Float32Array; neither input is modified. */static multiply( a, b ) { + const out = new Float32Array(16); + + for ( let col = 0; col < 4; col++ ) { + for ( let row = 0; row < 4; row++ ) { + let sum = 0; + + for ( let k = 0; k < 4; k++ ) { + // a is column-major: element at col k, row row => a[k*4 + row] + // b is column-major: element at col col, row k => b[col*4 + k] + sum += a[k * 4 + row] * b[col * 4 + k]; + } + + out[col * 4 + row] = sum; + } + } + + return out; + } + +} diff --git a/framework/Measure.js b/framework/Measure.js new file mode 100644 index 0000000..e19f2e9 --- /dev/null +++ b/framework/Measure.js @@ -0,0 +1,49 @@ +/* Stopwatch keyed by label: stores performance.now() timestamps and logs the difference. */export default class Measure { + + startTimes = {}; + + endTimes = {}; + + /* When true, log() also appends its message to this.element. */writeToPage = false; + + /* NOTE(review): defaults to false, so log() would fail on appendChild if writeToPage is enabled before an element is assigned - confirm callers set it. */element = false; + + /* Records the current time as the start time for label. */start ( label ) { + + this.startTimes[ label ] = performance.now(); + } + + /* Records the current time as the end time for label, then logs the elapsed time. */end ( label ) { + + this.endTimes[ label ] = performance.now(); + + this.log( label ); + } + + /* Returns the end time minus the start time recorded for label; throws an Error when either one is missing. */getElapsed ( label ) { + + if ( this.startTimes[ label ] === undefined || this.endTimes[ label ] === undefined ) { + + throw new Error( 
"Start or end time missing for label: " + label ); + } + + return this.endTimes[ label ] - this.startTimes[ label ]; + } + + /* Writes the elapsed time for label, formatted with three decimals, to the console; when writeToPage is true it is also appended to this.element as a new p element. */log ( label ) { + + const elapsed = this.getElapsed( label ); + + if( this.writeToPage ) { + + var p = document.createElement("p") + + p.innerText = label + " took " + elapsed.toFixed(3) + " ms"; + + this.element.appendChild( p ); + + } + + console.log( label + " took " + elapsed.toFixed(3) + " ms" ); + } +} \ No newline at end of file diff --git a/framework/Request.js b/framework/Request.js new file mode 100644 index 0000000..d172a64 --- /dev/null +++ b/framework/Request.js @@ -0,0 +1,11 @@ +/* Plain message object that pairs a method name with its payload. */export class Request { + + constructor( method, payload = {} ) { + + this.method = method; // method name to call on Controller, e.g. "Ping" + + this.payload = payload; // any data for the method + + } + +} \ No newline at end of file diff --git a/framework/ShaderInpector.js b/framework/ShaderInpector.js new file mode 100644 index 0000000..a2aba82 --- /dev/null +++ b/framework/ShaderInpector.js @@ -0,0 +1,67 @@ + +/* Debug helper that lists the entries of document.shaders in a select element and dumps the buffers of the chosen one. */class shaderDebugger{ + + /* Adds one option per entry of document.shaders (text is the shader path, id is its index) to the .selectDebugShader element, then registers a click handler on #showBuffers. The handler logs the selected shader and awaits debugBuffer for every buffer whose binding type is storage. */setup() { + + var shaders = document.shaders; + + var select = document.querySelector(".selectDebugShader"); + + for (var i = 0; i < shaders.length; i++) { + + var currentShader = shaders[i]; + + var option = document.createElement("option"); + + option.innerText = currentShader.path; + + option.id = i; + + select.appendChild( option ); + + } + + document.querySelector( "#showBuffers" ).addEventListener( "click", async function() { + + var select = document.querySelector(".selectDebugShader"); + + var selectedIndex = select.selectedIndex; + + var selectedShader = document.shaders[ selectedIndex ] + + /* Array.from on selectedShader.buffers; each element is then indexed with [0] as the buffer name. */const keysArray = Array.from( selectedShader.buffers ); + + console.log("\n\n\n\n -------------------- Debugging Shader --------------- \n\n\n\n"); + + console.log( "Shader Path: ", selectedShader.path ); + + console.log( selectedShader ); + + for (var i = 0; i < keysArray.length; i++) { + + const bindingInfo = 
selectedShader.bindings.find( b => b.varName === keysArray[i][0] ); + + + if( bindingInfo ) { + + /* Only storage bindings are dumped; other binding types are skipped silently. */if( bindingInfo.type == "storage" ) { + + await selectedShader.debugBuffer( keysArray[i][0] ); + + } + + } else { + + /* Reached when no parsed binding carries this buffer name. */console.log("this is a Uniform", keysArray, selectedShader.bindings); + + } + + } + + }); + } + +} + + +export default shaderDebugger; \ No newline at end of file diff --git a/framework/Vector3.js b/framework/Vector3.js new file mode 100644 index 0000000..3ccd3f9 --- /dev/null +++ b/framework/Vector3.js @@ -0,0 +1,60 @@ +/* Three-component vector with x, y and z fields; the static helpers return new instances and never modify their arguments. */export default class Vector3 { + + x = 0; + y = 0; + z = 0; + + constructor( x = 0, y = 0, z = 0 ) { + + this.x = x; + this.y = y; + this.z = z; + + } + + /* Component-wise a minus b as a new Vector3. */static subtract( a, b ) { + + return new Vector3( a.x - b.x, a.y - b.y, a.z - b.z ); + + } + + + /* Euclidean length of this vector. */length() { + + return Math.sqrt( this.x * this.x + this.y * this.y + this.z * this.z ); + + } + + /* Cross product a x b as a new Vector3. */static cross( a, b ) { + + return new Vector3( + + a.y * b.z - a.z * b.y, + a.z * b.x - a.x * b.z, + a.x * b.y - a.y * b.x + ); + + } + + /* Dot product of a and b as a number. */static dot( a, b ) { + + return a.x * b.x + a.y * b.y + a.z * b.z; + + } + + /* Returns v divided by its length as a new Vector3, or the zero vector when the length is not greater than 0.00001. */static normalize( v ) { + + const length = Math.sqrt( v.x * v.x + v.y * v.y + v.z * v.z ); + + if ( length > 0.00001 ) { + + return new Vector3( v.x / length, v.y / length, v.z / length ); + + } else { + + return new Vector3( 0, 0, 0 ); + + } + + } +} diff --git a/framework/WebGpu.js b/framework/WebGpu.js new file mode 100644 index 0000000..52717b8 --- /dev/null +++ b/framework/WebGpu.js @@ -0,0 +1,1895 @@ + + + +/* When a document global exists, create the shared shader list and buffer map on it. */if( typeof document != "undefined" ) { + + document.shaders = new Array(); + + document.bufferMap = {}; + +} + + +/* Holds one WGSL source together with the bindings parsed from it and the GPU buffers, textures, samplers, layouts and bind groups created for those bindings. */export default class Shader { + + sampleCount = 1; + + device; + + path; + + wgslSource; + + bindGroupLayout; + + /* Keyed by group number (see createBindGroups). */bindGroupLayouts = new Map(); + + vertexBuffersDescriptors = new Map(); + + /* Keyed by binding variable name; values are texture views. */textures = new Map(); + + pipeline; + + /* Keyed by group number. */bindGroups = new Map(); + + entryPoint = "main"; + + /* Entries are { group, binding, type, varName, varType } objects filled in by _parseBindings. */bindings = new Array(); + + /* Keyed by binding variable name (or by the name passed to setVertexBuffer). */buffers = new Map(); + + attributeBuffers ; + + 
bufferUsages = new Map(); // store usage per buffer to recreate correctly + + _externalBindGroupLayout = null; + + hasLoaded = false; + + stage = GPUShaderStage.COMPUTE; + + /* True when no document global exists. */isInWorker = typeof document === "undefined"; + + topologyValue = "triangle-strip"; + + binded = false; + + /* NOTE(review): textures is already declared earlier in this class body; this second declaration is redundant. */textures = new Map(); + + samplers = new Map(); + + + + /* Outside a worker, registers this instance in document.shaders; stores the device and, when path is truthy, the path. */constructor( device, path = false ) { + + if( !this.isInWorker ) { + + document.shaders.push( this ); + + } + + console.log("\n\n \n Creating New Shader", path, "\n\n"); + + console.log("device", typeof device); + + this.device = device; + + if( path ) { + + this.path = path; + + } + + } + + /* Stores value in topologyValue. If binded is already true an error is logged, but the value is still stored. */set topology(value) { + + if (this.binded) { + + console.error("Define topology before setup", this); + } + + this.topologyValue = value; + + } + + get topology() { + + return this.topologyValue; + + } + + /* Creates a GPU buffer for every parsed binding that has no entry in this.buffers yet and records the usage flags in this.bufferUsages. Uniform bindings get 16 bytes and UNIFORM | COPY_DST; read-only-storage gets STORAGE | COPY_DST; every other type gets STORAGE | COPY_DST | COPY_SRC. Non-uniform buffers use the fixed size 240000 * 3 * 4 bytes. NOTE(review): texture and sampler bindings are not filtered out in this loop, so they would also receive a storage buffer - confirm intended. */_createAllBuffers() { + + console.log("_createAllBuffers", this.bindings, this.buffers); + + for ( const b of this.bindings ) { + + //console.log("has buffer", b.varName, this.buffers.has( b.varName )); + + if ( this.buffers.has( b.varName ) ) continue; + + let size = 240000 * 3 * 4; + + if ( b.type === "uniform" ) { + + size = 4 * 4; + + } + + let usage; + + if ( b.type === "uniform" ) { + + usage = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST" ); + + } else if ( b.type === "read-only-storage" ) { + + usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST" ); + + + } else { + + usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC" ); + + } + + const buffer = this.device.createBuffer( { + + size, + + usage, + + label: b.varName + + } ); + + if( !this.isInWorker ) { + + 
//document.bufferOwners.set( buffer, this ); + + } + + this.buffers.set( b.varName, buffer ); + + this.bufferUsages.set( b.varName, usage ); + + //console.log( `Created buffer for '${b.varName}' size=${size} usage=0x${usage.toString(16)}` ); + + } + + } + +/* Gives every texture or sampler binding that has no resource yet a placeholder. Textures get the view of a 1x1 rgba8unorm texture (two array layers and a 2d-array view when the WGSL type starts with texture_2d_array, otherwise one layer and a 2d view); samplers get a sampler with linear mag and min filters. Entries already present in this.textures or this.samplers are left untouched. */_createDefaultTexturesAndSamplers() { + + for ( const b of this.bindings ) { + + if ( b.type === "texture" && !this.textures.has( b.varName ) ) { + + // Detect if the type is a texture array + const isTextureArray = b.varType.startsWith("texture_2d_array"); + + // Create 1x1 texture with 1 layer (depthOrArrayLayers = 1) + const defaultTexture = this.device.createTexture({ + + size: isTextureArray ? [1, 1, 2] : [1, 1, 1], + + format: "rgba8unorm", + + usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST, + + label: b.varName + "_defaultTexture" + + }); + + + console.log("create texture ", isTextureArray, isTextureArray ? "2d-array" : "2d"); + // Create texture view with correct dimension + const defaultTextureView = defaultTexture.createView({ + + dimension: isTextureArray ? "2d-array" : "2d" + + }); + + console.log("create default texture?", b, defaultTextureView, isTextureArray ? 
"2d-array" : "2d"); + + this.textures.set(b.varName, defaultTextureView); + + } else if ( b.type === "sampler" && !this.samplers.has( b.varName ) ) { + + // Create default sampler + + const defaultSampler = this.device.createSampler( { + + magFilter: "linear", + + minFilter: "linear", + + label: b.varName + "_defaultSampler" + + } ); + + this.samplers.set( b.varName, defaultSampler ); + + } + + } + +} + + +/* Replaces this.bindGroups with a fresh Map holding one bind group per entry of this.bindGroupLayouts. For each group the matching bindings are resolved to a buffer, texture view or sampler by variable name. Throws when a buffer is missing or the binding type is unknown; a missing texture only logs a warning and a missing sampler is passed on as undefined. */createBindGroups() { + + this.bindGroups = new Map(); + + for ( const [ group, layout ] of this.bindGroupLayouts.entries() ) { + + const bindings = this.bindings.filter( b => b.group === group ); + + const entries = new Array( bindings.length ); + + for ( let i = 0; i < bindings.length; i++ ) { + + const b = bindings[ i ]; + + let resource; + + if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) { + + const gpuBuffer = this.buffers.get( b.varName ); + + if ( !gpuBuffer ) { + + throw new Error( `Buffer missing for ${b.varName} when creating bind group.` ); + + } + + resource = { buffer: gpuBuffer }; + + } else if ( b.type === "texture" ) { + + const gpuTextureView = this.textures.get( b.varName ); + + console.log("this.textures", this.textures); + + if ( !gpuTextureView ) { + + console.warn( `Texture missing for ${b.varName} when creating bind group.` ); + + } else { + + console.log( `Binding texture ${b.varName} with dimension:`, gpuTextureView.dimension ); + + } + + resource = gpuTextureView; + + } else if ( b.type === "sampler" ) { + + const gpuSampler = this.samplers.get( b.varName ); + + if ( !gpuSampler ) { + + //throw new Error( `Sampler missing for ${b.varName} when creating bind group.` ); + + } + + resource = gpuSampler; + + } else { + + throw new Error( `Unknown binding type '${b.type}' for ${b.varName}` ); + + } + + entries[ i ] = { + + binding: b.binding, + + resource: resource + + }; + + } + + const bindGroup = this.device.createBindGroup( { + + layout: layout, + + entries: entries + + } ); + + this.bindGroups.set( 
group, bindGroup ); + + } + +} + /* Scans wgslCode for fn declarations preceded by one or more attributes and returns an array of { name, type } objects. The first attribute name selects type: the GPUShaderStage flag for vertex, fragment or compute, and the string Unknown for anything else. */extractEntryPoints( wgslCode ) { + const entryPoints = []; + + // Regex to match lines like: @vertex fn functionName(...) { ... } + const regex = /@(\w+)(?:\s+@\w+(?:\([^)]*\))?)*\s+fn\s+(\w+)\s*\(/g; + + let match; + while ((match = regex.exec(wgslCode)) !== null) { + const stage = match[1]; + const name = match[2]; + + let type; + + switch (stage) { + case "vertex": + type = GPUShaderStage.VERTEX; + break; + case "fragment": + type = GPUShaderStage.FRAGMENT; + break; + case "compute": + type = GPUShaderStage.COMPUTE; + break; + default: + type = "Unknown"; + } + + entryPoints.push({ name: name, type: type }); + } + + return entryPoints; + + } + + /* Stores path, loads the WGSL text from it and awaits this.addStage( name, type ) for each extracted entry point, one after another. addStage is defined outside the visible part of this file. */async setup( path ) { + + + + this.path = path; + + this.wgslSource = await this._loadShaderSource( this.path ); + //console.log(this.wgslSource); + var entryPoints = this.extractEntryPoints( this.wgslSource ); + + + + + for (var i = 0; i < entryPoints.length; i++) { + + var entryPoint = entryPoints[i]; + + await this.addStage( entryPoint.name, entryPoint.type ); + + } + + //await this.addStage("computeMain", 4 ); + + //console.log( "load shader", this.wgslSource, entryPoints ) + + } + + /* Fetches pathName and resolves to the response body as text; throws an Error when response.ok is false. */async _loadShaderSource( pathName ) { + + const response = await fetch( pathName ); + + if ( !response.ok ) throw new Error( `Failed to load shader: ${ pathName }` ); + + return await response.text(); + + } + + /* Creates a VERTEX | COPY_DST buffer of dataArray.byteLength bytes mapped at creation, copies dataArray into it through a Float32Array view of the mapped range, unmaps it and stores it in this.buffers under name. NOTE(review): the copy always goes through Float32Array, so dataArray is presumably expected to hold float data - confirm against callers. */setVertexBuffer( name, dataArray ) { + + const buffer = this.device.createBuffer({ + size: dataArray.byteLength, + usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, + mappedAtCreation: true + }); + + const mapping = new Float32Array( buffer.getMappedRange() ); + + mapping.set( dataArray ); + + buffer.unmap(); + + + + this.buffers.set( name, buffer ); + + } + + /* Returns the trimmed text between the parentheses that follow the first fn after the first @vertex marker in wgsl. Nested parentheses are tracked with a counter. Returns an empty string when the marker, the fn keyword or the opening parenthesis is not found, or when the parentheses never balance. */extractVertexFunctionParameters(wgsl) { + + let vertexFnIndex = wgsl.indexOf('@vertex'); + + if (vertexFnIndex === -1) return ''; + + let fnIndex = wgsl.indexOf('fn', vertexFnIndex); + + if (fnIndex === -1) return ''; + + let parenStart 
= wgsl.indexOf('(', fnIndex); + + if (parenStart === -1) return ''; + + let parenCount = 1; + let i = parenStart + 1; + + while (i < wgsl.length && parenCount > 0) { + const c = wgsl[i]; + + if (c === '(') parenCount++; + + else if (c === ')') parenCount--; + + i++; + } + + if (parenCount !== 0) { + // unbalanced parentheses + return ''; + } + + // Extract parameter substring between parenStart+1 and i-1 + let paramsStr = wgsl.substring(parenStart + 1, i - 1); + + return paramsStr.trim(); + } + + /* Returns one { location, varName, type } object for every parameter of the vertex entry point that is written as @location(n) name : type; location is converted to a number and type is trimmed. */extractVertexParamsLocations(wgsl) { + + const paramsStr = this.extractVertexFunctionParameters(wgsl) + + const paramRegex = /@location\s*\(\s*(\d+)\s*\)\s+(\w+)\s*:\s*([\w<>\d]+(\s*<[^>]+>)?)/g; + + const results = new Array(); + + let paramMatch; + + while ((paramMatch = paramRegex.exec(paramsStr)) !== null) { + + const location = Number(paramMatch[1]); + + const varName = paramMatch[2]; + + const type = paramMatch[3].trim(); + + results.push({ location, varName, type }); + } + + return results; + } + + /* Size in bytes of the listed vertex format strings; throws an Error for any other format. */getFormatSize( format ) { + + switch ( format ) { + + case "float32": return 4; + + case "float32x2": return 8; + + case "float32x3": return 12; + + case "float32x4": return 16; + + case "uint32": return 4; + + case "sint32": return 4; + + default: throw new Error("Unsupported format size for: " + format); + + } + + } + + /* Maps a WGSL type name to a vertex format string and throws an Error for unlisted names. NOTE(review): only the bare names f32, vec2, vec3, vec4, u32 and i32 are listed, while the parameter regex in extractVertexParamsLocations can capture forms such as vec3<f32>, which would reach the throw - confirm which spellings the shaders use. */mapTypeToVertexFormat( type ) { + + // Map WGSL types to GPUVertexFormat strings (example subset) + + switch ( type ) { + + case "f32": return "float32"; + + case "vec2": return "float32x2"; + + case "vec3": return "float32x3"; + + case "vec4": return "float32x4"; + + case "u32": return "uint32"; + + case "i32": return "sint32"; + + default: throw new Error("Unsupported vertex attribute type: " + type); + + } + + } + + /* Builds a single vertex buffer descriptor from the @location parameters of the vertex entry point and stores it in this.vertexBuffersDescriptors under the key vertexBuffer0. Attribute offsets accumulate in parameter order and the final offset becomes arrayStride. Returns without storing anything when there are no such parameters. */_parseVertexAttributes( wgsl ) { + + const locations = this.extractVertexParamsLocations( wgsl ); + + + if (locations.length === 0) { + // No vertex attributes = no vertex buffer descriptors needed + return; + } + + let offset = 0; + + 
const attributes = new Array(); + + for ( let attr of locations ) { + + const format = this.mapTypeToVertexFormat( attr.type ); + + attributes.push({ + shaderLocation : attr.location, + offset : offset, + format : format + }); + + offset += this.getFormatSize( format ); + + } + + const vertexBufferDescriptor = { + arrayStride: offset, + attributes: attributes, + stepMode: "vertex" + }; + + console.log("vertexBufferDescriptor", vertexBufferDescriptor); + + + this.vertexBuffersDescriptors.set("vertexBuffer0", vertexBufferDescriptor); + + console.log("vertex buffer descriptors updated:", this.vertexBuffersDescriptors); + + } + + _parseBindings( wgsl ) { + + console.log("parse bindings"); + + this.bindings = []; + + // Make the `<...>` part optional by wrapping it in a non-capturing group with `?` + const regex = /@group\((\d+)\)\s+@binding\((\d+)\)\s+var(?:<(\w+)(?:,\s*(\w+))?>)?\s+(\w+)\s*:\s*([^;]+);/g; + + let match; + + while ((match = regex.exec(wgsl)) !== null) { + + const group = Number(match[1]); + + const binding = Number(match[2]); + + const storageClass = match[3]; // e.g. "uniform", "storage", or undefined for textures/samplers + + const access = match[4]; // e.g. "read", "write", or undefined + + const varName = match[5]; + + const varType = match[6]; // WGSL type string like "mat4x4", "texture_2d", "sampler", etc. + + let type; + + if (storageClass === "storage") { + + type = access === "read" ? 
"read-only-storage" : "storage"; + + } else if (storageClass === "uniform") { + + type = "uniform"; + + } else if (varType.startsWith("texture_")) { + + type = "texture"; + + } else if (varType === "sampler" || varType === "sampler_comparison") { + + type = "sampler"; + + } else { + + type = "uniform"; // fallback + + } + + this.bindings.push({ group, binding, type, varName, varType }); + + } + + console.log("bindings", this.bindings); + + this.bindings.sort((a, b) => a.binding - b.binding); + +} + +_createBindGroupLayoutsFromBindings() { + + const groups = new Map(); + + for ( let i = 0; i < this.bindings.length; i++ ) { + + const b = this.bindings[ i ]; + + if ( !groups.has( b.group ) ) groups.set( b.group, [] ); + + groups.get( b.group ).push( b ); + + } + + const layouts = new Map(); + + for ( const groupEntry of groups.entries() ) { + + const group = groupEntry[ 0 ]; + + const bindings = groupEntry[ 1 ]; + + let visibility; + + if ( this.stage === GPUShaderStage.COMPUTE ) { + + visibility = GPUShaderStage.COMPUTE; + + } else { + + visibility = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT; + + } + + const entries = new Array( bindings.length ); + + for ( let j = 0; j < bindings.length; j++ ) { + + const b = bindings[ j ]; + + let entry; + + if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) { + + entry = { + + binding: b.binding, + + visibility: visibility, + + buffer: { type: b.type } + + }; + + } else if ( b.type === "texture" ) { + + + let viewDimension = "2d"; // default + + if ( b.varType.startsWith("texture_2d_array") ) { + + viewDimension = "2d-array"; + + } else if ( b.varType.startsWith("texture_cube") ) { + + viewDimension = "cube"; + + } else if ( b.varType.startsWith("texture_3d") ) { + + viewDimension = "3d"; + + } + + console.log("viewDimension", viewDimension); + + entry = { + binding: b.binding, + visibility: visibility, + texture: { + sampleType: "float", + viewDimension: viewDimension + } + }; + + } else 
if ( b.type === "sampler" ) { + + entry = { + + binding: b.binding, + + visibility: visibility, + + sampler: { type: "filtering" } + + }; + + } else { + + entry = { + + binding: b.binding, + + visibility: visibility, + + buffer: { type: "uniform" } + + }; + + } + + entries[ j ] = entry; + + } + + console.log("createBindGroupLayout ", entries); + + const layout = this.device.createBindGroupLayout( { entries } ); + + layouts.set( group, layout ); + + } + + return layouts; + +} + + async _createPipeline() { + + const layoutsArray = this.bindGroupLayouts && this.bindGroupLayouts.size > 0 + + ? Array.from( this.bindGroupLayouts.values() ) + + : [ this.bindGroupLayout ]; + + const shaderModule = this.device.createShaderModule( { + code: this.wgslSource, + + label: this.path, + + } ); + + const info = await shaderModule.compilationInfo(); + + if (info.messages.length > 0) { + for (const msg of info.messages) { + console[msg.type === "error" ? "error" : "warn"]( + `[${msg.lineNum}:${msg.linePos}] ${msg.message}` + ); + } + } + + console.log("this.bindGroupLayouts.values()", this.bindGroupLayouts.values()); + + this.pipeline = this.device.createComputePipeline({ + + layout: this.device.createPipelineLayout({ + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + }), + + compute: { + + module: shaderModule, + + entryPoint: this.entryPoint, + + }, + + } ); + + } + + setBindingGroupLayout( bindGroupLayout ) { + + this._externalBindGroupLayout = bindGroupLayout; + + } + + getBindingGroupLayout( bindingGroupIndex = 0 ) { + + if ( this.bindGroupLayouts ) { + + return this.bindGroupLayouts.get( bindingGroupIndex ); + + } + + return this.bindGroupLayout; + + } + + getTypedArrayByVariableName( name, flatData ) { + + let typedArray; + + const bindingInfo = this.bindings.find( b => b.varName === name ); + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + switch( wgslType ) { + + case "vec3": + + // flatData.push( 0 ); + + break; + + } + + switch ( 
wgslType ) { + + case "f32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Float32Array( flatData ); + + break; + + case "u32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Uint32Array( flatData ); + + break; + + case "i32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Int32Array( flatData ); + + break; + + default: + + // find struct and then the datatype + typedArray = new Float32Array( flatData ); + //console.error( "Unknown WGSL type in setVariable: " + wgslType ); + } + + //console.log("wgslType", wgslType); + + } + + return typedArray; + + } + + + setVariable( name, dataObject, customBufferUsage ) { + + + if ( dataObject instanceof GPUTexture ) { + + // Handle texture: create and store texture view + const textureView = dataObject.createView(); + + this.textures.set( name, textureView ); + + // Recreate bind groups since texture changed + this.createBindGroups(); + + return; + + } + + if ( dataObject instanceof GPUSampler ) { + + // Handle sampler + this.samplers.set( name, dataObject ); + + // Recreate bind groups since sampler changed + this.createBindGroups(); + + return; + + } + + + const flatData = this.flattenStructure( dataObject ); + + //console.log( "flattenStructure", name, flatData ); + + const sizeBytes = flatData.length * 4; + + let buffer = this.buffers.get( name ); + + if ( !buffer ) { + + throw new Error( "Buffer not found for variable: " + name + ". Buffers must be created in setup. 
>>" + this.path ); + + } + + if ( sizeBytes > buffer.size ) { + + console.log( `Resizing buffer '${name}' from ${buffer.size} to ${sizeBytes} bytes.` ); + + var usage = this.bufferUsages.get( name ) || ( GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ); + + if ( customBufferUsage ) { + + usage = customBufferUsage; + + } + + const newBuffer = this.device.createBuffer( { + + size : sizeBytes, + + usage : usage, + + label : name + + } ); + + if( !this.isInWorker ) { + + //document.bufferOwners.set( newBuffer, this ); + } + + + this.buffers.set( name, newBuffer ); + + buffer = newBuffer; + + // Recreate bind groups to bind new buffer + + //console.log( "setVariable()" ); + + this.createBindGroups(); + + } + + let typedArray; + + if ( typeof dataObject === "number" ) { + + typedArray = this.getTypedArrayByVariableName( name, flatData ); + + } else if( Array.isArray( dataObject ) ) { + + typedArray = this.getTypedArrayByVariableName( name, flatData ); + + } else if ( dataObject instanceof Uint32Array ) { + + typedArray = dataObject; + + } else if ( typeof dataObject.numTokens === "number" && Number.isInteger( dataObject.numTokens ) ) { + + typedArray = new Uint32Array( [ dataObject.numTokens ] ); + + } else if ( dataObject instanceof Int32Array || dataObject instanceof Float32Array ) { + + typedArray = dataObject; + + } else { + + typedArray = new Float32Array( flatData ); + + } + if( name == "instancePositions") { + console.log("instancePositions", flatData, dataObject, typedArray.byteOffset, typedArray.byteLength, typedArray.buffer); + + } + this.device.queue.writeBuffer( + buffer, + 0, + typedArray.buffer, + typedArray.byteOffset, + typedArray.byteLength + ); + + //console.log( `Buffer data updated: ${name}, byteLength=${typedArray.byteLength}` ); + + } + + flattenStructure( input, output = [] ) { + + if ( Array.isArray( input ) ) { + + for ( const element of input ) { + + this.flattenStructure( element, output ); + + } + + } else if ( typeof input === "object" && 
input !== null ) { + + for ( const key of Object.keys( input ) ) { + + this.flattenStructure( input[ key ], output ); + + } + + } else if ( typeof input === "number" ) { + + output.push( input ); + + } else if ( typeof input === "boolean" ) { + + output.push( input ? 1 : 0 ); + + } else { + + throw new Error( "Unsupported data type in shader variable: " + typeof input ); + + } + + return output; + + } + + copyBuffersFromShader( sourceShader ) { + + for ( const [ name, buffer ] of sourceShader.buffers.entries() ) { + + this.buffers.set( name, buffer ); + + } + + } + + updateUniforms( data ) { + const arrayBuffer = new ArrayBuffer(12); // 3 x u32 = 3 * 4 bytes + const dataView = new DataView(arrayBuffer); + + dataView.setUint32(0, data.k, true); // offset 0 + dataView.setUint32(4, data.j, true); // offset 4 + dataView.setUint32(8, data.totalCount, true); // offset 8 + + this.device.queue.writeBuffer( + this.uniformBuffer, // your GPUBuffer holding uniforms + 0, // offset in buffer + arrayBuffer + ); + } + + createCommandEncoder( x, y, z ) { + + const commandEncoder = this.device.createCommandEncoder({ + label: this.path + }); + + const passEncoder = commandEncoder.beginComputePass({ + label: this.path + }); + + + if( !this.pipeline ) { + + console.log("Pipeline missing, Maybe you forget to call addStage( ) ", this); + + } + + passEncoder.setPipeline( this.pipeline ); + + + + for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) { + + passEncoder.setBindGroup( bindingGroupIndex, bindGroup ); + + } + + passEncoder.dispatchWorkgroups( x, y, z ); + + passEncoder.end(); + + return commandEncoder; + + } + + async execute( x = 1, y = 1, z = 1 ) { + + const commandEncoder = this.createCommandEncoder( x, y, z ); + + const commandBuffer = commandEncoder.finish(); + + this.device.queue.submit( [ commandBuffer ] ); + + await this.device.queue.onSubmittedWorkDone(); // critical! 
+ + } + + registerBuffer( buffer, name, type ) { + + if( !this.isInWorker ) { + + if( !document.bufferMap[name] ) { + + document.bufferMap[name] = new Array(); + + } + + var bufferHistoryPoint = { buffer: buffer, shader: this, label : buffer.label, type: type }; + + document.bufferMap[name].push( bufferHistoryPoint ); + + } + + } + + getBuffer( name ) { + + var buffer = this.buffers.get( name ); + + this.registerBuffer( buffer, name, "get" ); + + return buffer; + + } + + setBuffer( name, buffer ) { + + this.registerBuffer( buffer, name, "set" ); + + if( !this.isInWorker ) { + + //var bufferOwner = document.bufferOwners.get( buffer ); + + //console.log("setBuffer Origin:", bufferOwner, " this shader", this); + + + } + + + + const bindingInfo = this.bindings.find( b => b.varName === name ); + + if( !bindingInfo ) { + + console.error("No binding to buffer ", name, " in shader ", this ); + + } + + this.buffers.set(name, buffer); + + this.createBindGroups(); // always refresh bindings + + } + + + + async addStage( entryPoint, stage ) { + + if ( !this.path ) { + + throw new Error("Shader path is not set"); + + } + this.entryPoint = entryPoint; + + if( stage != GPUShaderStage.COMPUTE ) { + + this.stage = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT; + + } + + + if ( !this.binded ) { + + if( stage != GPUShaderStage.COMPUTE ) { + + if( this.sampleCount == 1 ) { + + this.depthTexture = this.device.createTexture({ + size: { + width: this.canvas.width, + height: this.canvas.height, + depthOrArrayLayers: 1 + }, + sampleCount: 1, // Must be 1 for non-multisampled rendering + format: "depth24plus", // Or "depth24plus-stencil8" if stencil is needed + usage: GPUTextureUsage.RENDER_ATTACHMENT + }); + + } else { + + this.multisampledColorTexture = this.device.createTexture({ + size: [this.canvas.width, this.canvas.height], + sampleCount: this.sampleCount, + format: "bgra8unorm", + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + this.multisampledDepthTexture = 
this.device.createTexture({ + size: [this.canvas.width, this.canvas.height], + sampleCount: this.sampleCount, // must match multisample count of color texture and pipeline + format: "depth24plus", // or "depth24plus-stencil8" if you need stencil + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + } + } + + this.wgslSource = await this._loadShaderSource( this.path ); + + this._parseVertexAttributes( this.wgslSource ); + + this._parseBindings( this.wgslSource ); + + if ( this._externalBindGroupLayout ) { + + this.bindGroupLayouts.set( 0, this._externalBindGroupLayout ); + + } else { + + this.bindGroupLayouts = this._createBindGroupLayoutsFromBindings(); + + } + + this._createAllBuffers(); + + this._createDefaultTexturesAndSamplers(); + + await this.createBindGroups(); + + this.binded = true; + + } + + const shaderModule = this.device.createShaderModule( { + + code: this.wgslSource, + + label: this.path + + } ); + + if ( stage === GPUShaderStage.COMPUTE ) { + + console.log("this.bindGroupLayouts", this.bindGroupLayouts); + + console.log("create compute shader pipeline."); + + this.computeStage = { module: shaderModule, entryPoint: this.entryPoint }; + + this.pipeline = this.device.createComputePipeline( { + + layout: this.device.createPipelineLayout( { + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + } ), + + compute: this.computeStage + + } ); + + await this.finalizeShader(); + + } else if ( stage === GPUShaderStage.VERTEX ) { + + //this.vertexStage = { module: shaderModule, entryPoint: this.entryPoint }; +console.log("Vertex buffers descriptors:", Array.from(this.vertexBuffersDescriptors.values())); + + const vertexBuffers = this.vertexBuffersDescriptors.size > 0 + ? 
[...this.vertexBuffersDescriptors.values()] + : []; + + this.vertexStage = { + module: shaderModule, + entryPoint: this.entryPoint, + buffers: vertexBuffers + }; + + } else if ( stage === GPUShaderStage.FRAGMENT ) { + + this.fragmentStage = { module: shaderModule, entryPoint: this.entryPoint }; + + } + + // If both vertex and fragment stages exist, create render pipeline + if ( this.vertexStage && this.fragmentStage ) { + + this._createAllBuffers(); + + console.log( this.device ); + + console.log("create fragment and vertex shader pipeline."); + + console.log("setup createRenderPipeline", this.vertexStage); + + console.log("this.device.createPipelineLayout", this.bindGroupLayouts.get(0) ); + + + + + this.pipeline = this.device.createRenderPipeline( { + + layout: this.device.createPipelineLayout( { + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + } ), + + vertex: this.vertexStage, + + fragment: { + ...this.fragmentStage, + + targets: [ + { + format: "bgra8unorm", + // No blending here for opaque rendering + blend: undefined, + writeMask: GPUColorWrite.ALL + } + ] // Specify the color target format here + + + }, + multisample: { + + count: this.sampleCount, // number of samples per pixel, typical values: 4 or 8 + + }, + primitive: { + + topology: this.topologyValue, + + frontFace: 'ccw' , + + cullMode: 'none', + + }, + + depthStencil: { + + format: "depth24plus", // <-- Add this to match render pass depthStencil + + depthWriteEnabled: true, + + depthCompare: "less", + + }, + + + + } ); + + this._createAllBuffers(); + + await this.finalizeShader(); + + } + + } + + convertToTypedArrayWithPadding(array, preferredType = null) { + + const isTypedArray = ( + array instanceof Uint16Array || + array instanceof Uint32Array || + array instanceof Float32Array + ); + + let typedArray; + let arrayType = preferredType; + + if (isTypedArray) { + if (preferredType === null) { + if (array instanceof Uint16Array) arrayType = "uint16"; + else if (array instanceof 
Uint32Array) arrayType = "uint32"; + else if (array instanceof Float32Array) arrayType = "float32"; + } + + const remainder = array.byteLength % 4; + + if (remainder !== 0) { + const elementSize = (arrayType === "uint16") ? 2 : 4; + const paddedByteLength = array.byteLength + (4 - remainder); + const paddedElementCount = paddedByteLength / elementSize; + + if (arrayType === "uint16") { + const paddedArray = new Uint16Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else if (arrayType === "uint32") { + const paddedArray = new Uint32Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else if (arrayType === "float32") { + const paddedArray = new Float32Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else { + throw new Error("Unsupported typed array type for padding"); + } + } else { + typedArray = array; + } + } else if (Array.isArray(array)) { + if (preferredType === null) { + // Default to Float32Array if no preferred type + arrayType = "float32"; + } else { + arrayType = preferredType; + } + + if (arrayType === "uint16") { + typedArray = new Uint16Array(array); + } else if (arrayType === "uint32") { + typedArray = new Uint32Array(array); + } else if (arrayType === "float32") { + typedArray = new Float32Array(array); + } else { + throw new Error("Unsupported preferred type"); + } + + const remainder = typedArray.byteLength % 4; + + if (remainder !== 0) { + const elementSize = (arrayType === "uint16") ? 
2 : 4; + const paddedByteLength = typedArray.byteLength + (4 - remainder); + const paddedElementCount = paddedByteLength / elementSize; + + if (arrayType === "uint16") { + const paddedArray = new Uint16Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else if (arrayType === "uint32") { + const paddedArray = new Uint32Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else if (arrayType === "float32") { + const paddedArray = new Float32Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else { + throw new Error("Unsupported typed array type for padding"); + } + } + } else { + throw new Error("Input must be a TypedArray or Array of numbers"); + } + + return { typedArray, arrayType }; + } + + setIndices( indices ) { +console.log("indices", indices); + const isTypedArray = indices instanceof Uint16Array || indices instanceof Uint32Array; + + let typedArray; + let arrayType; + + console.log("isTypedArray", isTypedArray); + + if ( isTypedArray ) { + + arrayType = (indices instanceof Uint16Array) ? "uint16" : "uint32"; + + ({ typedArray } = this.convertToTypedArrayWithPadding( indices, arrayType ) ); + + } else if (Array.isArray(indices)) { + + let maxIndex = 0; + + for (let i = 0; i < indices.length; i++) { + if (indices[i] > maxIndex) maxIndex = indices[i]; + } + + arrayType = (maxIndex < 65536) ? 
"uint16" : "uint32"; + + + ({ typedArray } = this.convertToTypedArrayWithPadding(indices, arrayType)); + + + } else { + + throw new Error("Indices must be Uint16Array, Uint32Array, or Array of numbers"); + + } + + this.indexFormat = arrayType; + + this.indexBufferData = typedArray; + + if (this.indexBuffer) { + + this.indexBuffer.destroy(); + + } + + this.indexBuffer = this.device.createBuffer({ + size: this.indexBufferData.byteLength, + usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST, + }); + + this.device.queue.writeBuffer( + this.indexBuffer, + 0, + this.indexBufferData.buffer, + this.indexBufferData.byteOffset, + this.indexBufferData.byteLength + ); + + this.indexCount = this.indexBufferData.length; + + + } + + setAttribute( name, array ) { + + if ( this.attributeArrays === undefined ) { + + this.attributeArrays = new Map(); + + } + + const { typedArray, arrayType } = this.convertToTypedArrayWithPadding( array ); + + let attributeSize = 3; // Default size (vec3) + + // this is not proper + if ( name === "uv" ) { + + attributeSize = 2; + + } else if ( name === "position" || name === "normal" ) { + + attributeSize = 3; + + } else { + + // You can add more attribute names and sizes here + + } + + this.attributeArrays.set( name, { typedArray: typedArray, size: attributeSize } ); + + } + + getAllAttributeBuffers() { + + if ( !this.attributeBuffers ) { + + return new Map(); + + } + + return this.attributeBuffers; + + } + + vertexBuffers( passEncoder ) { + + if ( this.attributeArrays === undefined ) { + + console.warn("No attribute arrays to bind."); + + return; + + } + + if ( this.mergedAttributeBuffer === undefined ) { + + const attributeNames = Array.from( this.attributeArrays.keys() ); + + const arrays = attributeNames.map( name => this.attributeArrays.get( name ).typedArray ); + + const attributeSizes = attributeNames.map( name => (name === "uv") ? 
2 : 3 ); + + const vertexCount = arrays[0].length / attributeSizes[0]; + + const strideFloats = attributeSizes.reduce( (sum, size) => sum + size, 0 ); + + const mergedArray = new Float32Array( vertexCount * strideFloats ); + + + for ( let i = 0; i < vertexCount; i++ ) { + + let offset = i * strideFloats; + + for ( let attrIndex = 0; attrIndex < arrays.length; attrIndex++ ) { + + const arr = arrays[attrIndex]; + + const size = attributeSizes[attrIndex]; + + const baseIndex = i * size; + + for ( let c = 0; c < size; c++ ) { + + mergedArray[ offset++ ] = arr[ baseIndex + c ]; + + } + + } + + } + + const byteLength = mergedArray.byteLength; + + const buffer = this.device.createBuffer({ + + size : byteLength, + + usage : GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, + + mappedAtCreation : true, + + label : "mergedAttributes" + + }); + + const mapping = new Float32Array( buffer.getMappedRange() ); + + mapping.set( mergedArray ); + + buffer.unmap(); + + this.device.queue.writeBuffer( + + buffer, + + 0, + + mergedArray.buffer, + + mergedArray.byteOffset, + + mergedArray.byteLength + + ); + + this.mergedAttributeBuffer = buffer; + + // Store for pipeline setup + this.attributeStrideFloats = strideFloats; + + this.attributeSizes = attributeSizes; + + this.attributeNames = attributeNames; + + } + + // Bind the merged buffer at slot 0 + + passEncoder.setVertexBuffer( 0, this.mergedAttributeBuffer ); + + } + + setCanvas( canvas ) { + + this.canvas = canvas; + + } + + async renderToCanvas( x, y = 0, z = 0 ) { + + const canvas = this.canvas; + + const context = canvas.getContext("webgpu"); + + const format = navigator.gpu.getPreferredCanvasFormat(); + + const depthTexture = this.device.createTexture({ + size: [this.canvas.width, this.canvas.height, 1], + sampleCount: this.sampleCount, + format: "depth24plus", + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + const commandEncoder = this.device.createCommandEncoder(); + + var renderPassDescriptor; + + if (this.sampleCount 
=== 1) { + + renderPassDescriptor = { + colorAttachments: [{ + view: context.getCurrentTexture().createView(), + loadOp: "clear", + storeOp: "store", + clearValue: { r: 0.3, g: 0.3, b: 0.3, a: 1 } + }], + depthStencilAttachment: { + view: this.depthTexture.createView(), + depthLoadOp: "clear", + depthStoreOp: "store", + depthClearValue: 1.0, + } + }; + + } else { + + renderPassDescriptor = { + colorAttachments: [{ + view: this.multisampledColorTexture.createView(), + resolveTarget: context.getCurrentTexture().createView(), + loadOp: "clear", + storeOp: "store", + clearValue: { r: 0, g: 0, b: 0, a: 1 } + }], + depthStencilAttachment: { + view: this.multisampledDepthTexture.createView(), + depthLoadOp: "clear", + depthStoreOp: "store", + depthClearValue: 1.0, + } + }; + } + + const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor); + + passEncoder.setPipeline(this.pipeline); + + this.vertexBuffers( passEncoder ); + + + for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) { + + passEncoder.setBindGroup( bindingGroupIndex, bindGroup ); + + } + + if( this.indexBuffer ) { + + passEncoder.setIndexBuffer( this.indexBuffer, this.indexFormat ); + + passEncoder.drawIndexed( this.indexCount, y, z ); + + } else { + //console.log(x, y, z); + passEncoder.draw( x, y, z ); + + } + + + + passEncoder.end(); + + this.device.queue.submit([commandEncoder.finish()]); + + } + + async finalizeShader() { + + //await this._createPipeline(); + + await this.createBindGroups(); + + } + + async debugBuffer( bufferName, bufferType ) { + + + + const buffer = this.getBuffer( bufferName ); + + const dataSize = buffer.size; + + const debugReadBuffer = this.device.createBuffer( { + size: dataSize, + usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ + } ); + + const debugEncoder = this.device.createCommandEncoder(); + + console.log( "print buffer", buffer ); + + debugEncoder.copyBufferToBuffer( + buffer, + 0, + debugReadBuffer, + 0, + dataSize + ); + + 
this.device.queue.submit( [ debugEncoder.finish() ] ); + + await this.device.queue.onSubmittedWorkDone(); + + await debugReadBuffer.mapAsync( GPUMapMode.READ ); + + const copyArrayBuffer = debugReadBuffer.getMappedRange(); + + let result; + + const bindingInfo = this.bindings.find( b => b.varName === bufferName ); + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + //console.log("wgslType", bindingInfo, wgslType, bufferName); + + switch ( wgslType ) { + + case "array": + result = new Float32Array( copyArrayBuffer ); + break; + case "array": + result = new Uint32Array( copyArrayBuffer ); + break; + case "vec3": + + break; + case "vec4": + result = new Float32Array( copyArrayBuffer ); + break; + case "u32": + case "vec2": + case "vec3": + case "vec4": + result = new Uint32Array( copyArrayBuffer ); + break; + case "Tensor": + result = new Float32Array(copyArrayBuffer); + break; + + case "array>": + case "array": + case "array>": + case "array>": + result = new Float32Array( copyArrayBuffer ); + break; + case "vec3": + case "vec4": + result = new Int32Array( copyArrayBuffer ); + break; + + default: + console.error( "Unknown WGSL type in setVariable: " + wgslType, "Using Float32Array as fallback." 
); + + result = new Float32Array( copyArrayBuffer ); + } + + // use wgslType here to determine appropriate TypedArray + } + const length = result.length; + var regularArray = new Array(); + + for (var i = 0; i < length; i++) { + + regularArray.push(result[i]); + + } + + + console.log("Debugging array"); + console.log(""); + console.log("Shader: ", this.path); + console.log("Variable Name: ", bufferName); + + console.log( "Buffer Size: ", this.getBuffer(bufferName).size ); + console.log( "TypedArray Size: ", length ); + + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + console.log( "type: ", wgslType ); + + } + + console.log(regularArray); + + debugReadBuffer.unmap(); + + return regularArray; + + } + + +} + + diff --git a/framework/WebGpu_node.js b/framework/WebGpu_node.js new file mode 100644 index 0000000..c7e55bb --- /dev/null +++ b/framework/WebGpu_node.js @@ -0,0 +1,1960 @@ +import { readFileSync } from "node:fs"; + +import gpu from '@kmamal/gpu' + + +if( typeof document != "undefined" ) { + + document.shaders = new Array(); + + document.bufferMap = {}; + +} + +var GPUShaderStage = gpu.GPUShaderStage; +var GPUTextureUsage = gpu.GPUTextureUsage; +var GPUBufferUsage = gpu.GPUBufferUsage; +var GPUColorWrite = gpu.GPUColorWrite; +var GPUTexture = gpu.GPUTexture; +var GPUSampler = gpu.GPUSampler; +var GPUMapMode = gpu.GPUMapMode; + + + +export default class Shader { + + sampleCount = 1; + + device; + + path; + + wgslSource; + + bindGroupLayout; + + bindGroupLayouts = new Map(); + + vertexBuffersDescriptors = new Map(); + + textures = new Map(); + + pipeline; + + bindGroups = new Map(); + + entryPoint = "main"; + + bindings = new Array(); + + buffers = new Map(); + + attributeBuffers ; + + bufferUsages = new Map(); // store usage per buffer to recreate correctly + + _externalBindGroupLayout = null; + + hasLoaded = false; + + stage = GPUShaderStage.COMPUTE; + + isInWorker = typeof document === "undefined"; + + topologyValue = "triangle-strip"; + + 
binded = false; + + textures = new Map(); + + samplers = new Map(); + + + + constructor( device, path = false ) { + + if( !this.isInWorker ) { + + document.shaders.push( this ); + + } + + console.log("\n\n \n Creating New Shader", path, "\n\n"); + + console.log("device", typeof device); + + this.device = device; + + if( path ) { + + this.path = path; + + } + + } + + set topology(value) { + + if (this.binded) { + + console.error("Define topology before setup", this); + } + + this.topologyValue = value; + + } + + get topology() { + + return this.topologyValue; + + } + + _createAllBuffers() { + + console.log("_createAllBuffers", this.bindings, this.buffers); + + for ( const b of this.bindings ) { + + //console.log("has buffer", b.varName, this.buffers.has( b.varName )); + + if ( this.buffers.has( b.varName ) ) continue; + + let size = 240000 * 3 * 4; + + if ( b.type === "uniform" ) { + + size = 4 * 4; + + } + + let usage; + + if ( b.type === "uniform" ) { + + usage = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST" ); + + } else if ( b.type === "read-only-storage" ) { + + usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST" ); + + + } else { + + usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; + + console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC" ); + + } + + const buffer = this.device.createBuffer( { + + size, + + usage, + + label: b.varName + + } ); + + if( !this.isInWorker ) { + + //document.bufferOwners.set( buffer, this ); + + } + + this.buffers.set( b.varName, buffer ); + + this.bufferUsages.set( b.varName, usage ); + + //console.log( `Created buffer for '${b.varName}' size=${size} usage=0x${usage.toString(16)}` ); + + } + + } + 
+_createDefaultTexturesAndSamplers() { + + for ( const b of this.bindings ) { + + if ( b.type === "texture" && !this.textures.has( b.varName ) ) { + + // Detect if the type is a texture array + const isTextureArray = b.varType.startsWith("texture_2d_array"); + + // Create 1x1 texture with 1 layer (depthOrArrayLayers = 1) + const defaultTexture = this.device.createTexture({ + + size: isTextureArray ? [1, 1, 2] : [1, 1, 1], + + format: "rgba8unorm", + + usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST, + + label: b.varName + "_defaultTexture" + + }); + + + console.log("create texture ", isTextureArray, isTextureArray ? "2d-array" : "2d"); + // Create texture view with correct dimension + const defaultTextureView = defaultTexture.createView({ + + dimension: isTextureArray ? "2d-array" : "2d" + + }); + + console.log("create default texture?", b, defaultTextureView, isTextureArray ? "2d-array" : "2d"); + + this.textures.set(b.varName, defaultTextureView); + + } else if ( b.type === "sampler" && !this.samplers.has( b.varName ) ) { + + // Create default sampler + + const defaultSampler = this.device.createSampler( { + + magFilter: "linear", + + minFilter: "linear", + + label: b.varName + "_defaultSampler" + + } ); + + this.samplers.set( b.varName, defaultSampler ); + + } + + } + +} + + +createBindGroups() { + + this.bindGroups = new Map(); + + for ( const [ group, layout ] of this.bindGroupLayouts.entries() ) { + + const bindings = this.bindings.filter( b => b.group === group ); + + const entries = new Array( bindings.length ); + + for ( let i = 0; i < bindings.length; i++ ) { + + const b = bindings[ i ]; + + let resource; + + if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) { + + const gpuBuffer = this.buffers.get( b.varName ); + + if ( !gpuBuffer ) { + + throw new Error( `Buffer missing for ${b.varName} when creating bind group.` ); + + } + + resource = { buffer: gpuBuffer }; + + } else if ( b.type === "texture" ) { 
+ + const gpuTextureView = this.textures.get( b.varName ); + + console.log("this.textures", this.textures); + + if ( !gpuTextureView ) { + + console.warn( `Texture missing for ${b.varName} when creating bind group.` ); + + } else { + + console.log( `Binding texture ${b.varName} with dimension:`, gpuTextureView.dimension ); + + } + + resource = gpuTextureView; + + } else if ( b.type === "sampler" ) { + + const gpuSampler = this.samplers.get( b.varName ); + + if ( !gpuSampler ) { + + //throw new Error( `Sampler missing for ${b.varName} when creating bind group.` ); + + } + + resource = gpuSampler; + + } else { + + throw new Error( `Unknown binding type '${b.type}' for ${b.varName}` ); + + } + + entries[ i ] = { + + binding: b.binding, + + resource: resource + + }; + + } + + const bindGroup = this.device.createBindGroup( { + + layout: layout, + + entries: entries + + } ); + + this.bindGroups.set( group, bindGroup ); + + } + +} + extractEntryPoints( wgslCode ) { + const entryPoints = []; + + // Regex to match lines like: @vertex fn functionName(...) { ... 
} + const regex = /@(\w+)(?:\s+@\w+(?:\([^)]*\))?)*\s+fn\s+(\w+)\s*\(/g; + + let match; + while ((match = regex.exec(wgslCode)) !== null) { + const stage = match[1]; + const name = match[2]; + + let type; + + switch (stage) { + case "vertex": + type = GPUShaderStage.VERTEX; + break; + case "fragment": + type = GPUShaderStage.FRAGMENT; + break; + case "compute": + type = GPUShaderStage.COMPUTE; + break; + default: + type = "Unknown"; + } + + entryPoints.push({ name: name, type: type }); + } + + return entryPoints; + + } + + async setup( path ) { + + + + this.path = path; + + this.wgslSource = await this._loadShaderSource( this.path ); + //console.log(this.wgslSource); + var entryPoints = this.extractEntryPoints( this.wgslSource ); + + + + + for (var i = 0; i < entryPoints.length; i++) { + + var entryPoint = entryPoints[i]; + + await this.addStage( entryPoint.name, entryPoint.type ); + + } + + //await this.addStage("computeMain", 4 ); + + //console.log( "load shader", this.wgslSource, entryPoints ) + + } + + async _loadShaderSource( pathName ) { + + //const response = await fetch( pathName ); + if ( typeof window === 'undefined' ) { + + const vertexShaderCode = await readFileSync( pathName, 'utf8' ) + + return vertexShaderCode; + + } else { + + const vertexShaderFile = path.join(__dirname, 'vertex.wgsl') + + if ( !response.ok ) throw new Error( `Failed to load shader: ${ pathName }` ); + + return await response.text(); + } + + + + + + + } + + setVertexBuffer( name, dataArray ) { + + const buffer = this.device.createBuffer({ + size: dataArray.byteLength, + usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, + mappedAtCreation: true + }); + + const mapping = new Float32Array( buffer.getMappedRange() ); + + mapping.set( dataArray ); + + buffer.unmap(); + + + + this.buffers.set( name, buffer ); + + } + + extractVertexFunctionParameters(wgsl) { + + let vertexFnIndex = wgsl.indexOf('@vertex'); + + if (vertexFnIndex === -1) return ''; + + let fnIndex = 
wgsl.indexOf('fn', vertexFnIndex); + + if (fnIndex === -1) return ''; + + let parenStart = wgsl.indexOf('(', fnIndex); + + if (parenStart === -1) return ''; + + let parenCount = 1; + let i = parenStart + 1; + + while (i < wgsl.length && parenCount > 0) { + const c = wgsl[i]; + + if (c === '(') parenCount++; + + else if (c === ')') parenCount--; + + i++; + } + + if (parenCount !== 0) { + // unbalanced parentheses + return ''; + } + + // Extract parameter substring between parenStart+1 and i-1 + let paramsStr = wgsl.substring(parenStart + 1, i - 1); + + return paramsStr.trim(); + } + + extractVertexParamsLocations(wgsl) { + + const paramsStr = this.extractVertexFunctionParameters(wgsl) + + const paramRegex = /@location\s*\(\s*(\d+)\s*\)\s+(\w+)\s*:\s*([\w<>\d]+(\s*<[^>]+>)?)/g; + + const results = new Array(); + + let paramMatch; + + while ((paramMatch = paramRegex.exec(paramsStr)) !== null) { + + const location = Number(paramMatch[1]); + + const varName = paramMatch[2]; + + const type = paramMatch[3].trim(); + + results.push({ location, varName, type }); + } + + return results; + } + + getFormatSize( format ) { + + switch ( format ) { + + case "float32": return 4; + + case "float32x2": return 8; + + case "float32x3": return 12; + + case "float32x4": return 16; + + case "uint32": return 4; + + case "sint32": return 4; + + default: throw new Error("Unsupported format size for: " + format); + + } + + } + + mapTypeToVertexFormat( type ) { + + // Map WGSL types to GPUVertexFormat strings (example subset) + + switch ( type ) { + + case "f32": return "float32"; + + case "vec2": return "float32x2"; + + case "vec3": return "float32x3"; + + case "vec4": return "float32x4"; + + case "u32": return "uint32"; + + case "i32": return "sint32"; + + default: throw new Error("Unsupported vertex attribute type: " + type); + + } + + } + + _parseVertexAttributes( wgsl ) { + + const locations = this.extractVertexParamsLocations( wgsl ); + + let offset = 0; + + const attributes = new 
Array(); + + for ( let attr of locations ) { + + const format = this.mapTypeToVertexFormat( attr.type ); + + attributes.push({ + shaderLocation : attr.location, + offset : offset, + format : format + }); + + offset += this.getFormatSize( format ); + + } + + const vertexBufferDescriptor = { + arrayStride: offset, + attributes: attributes, + stepMode: "vertex" + }; + + console.log("vertexBufferDescriptor", vertexBufferDescriptor); + + + this.vertexBuffersDescriptors.set("vertexBuffer0", vertexBufferDescriptor); + + console.log("vertex buffer descriptors updated:", this.vertexBuffersDescriptors); + + } + + _parseBindings( wgsl ) { + + console.log("parse bindings"); + + this.bindings = []; + + // Make the `<...>` part optional by wrapping it in a non-capturing group with `?` + const regex = /@group\((\d+)\)\s+@binding\((\d+)\)\s+var(?:<(\w+)(?:,\s*(\w+))?>)?\s+(\w+)\s*:\s*([^;]+);/g; + + let match; + + while ((match = regex.exec(wgsl)) !== null) { + + const group = Number(match[1]); + + const binding = Number(match[2]); + + const storageClass = match[3]; // e.g. "uniform", "storage", or undefined for textures/samplers + + const access = match[4]; // e.g. "read", "write", or undefined + + const varName = match[5]; + + const varType = match[6]; // WGSL type string like "mat4x4", "texture_2d", "sampler", etc. + + let type; + + if (storageClass === "storage") { + + type = access === "read" ? 
"read-only-storage" : "storage"; + + } else if (storageClass === "uniform") { + + type = "uniform"; + + } else if (varType.startsWith("texture_")) { + + type = "texture"; + + } else if (varType === "sampler" || varType === "sampler_comparison") { + + type = "sampler"; + + } else { + + type = "uniform"; // fallback + + } + + this.bindings.push({ group, binding, type, varName, varType }); + + } + + console.log("bindings", this.bindings); + + this.bindings.sort((a, b) => a.binding - b.binding); + +} + +_createBindGroupLayoutsFromBindings() { + + const groups = new Map(); + + for ( let i = 0; i < this.bindings.length; i++ ) { + + const b = this.bindings[ i ]; + + if ( !groups.has( b.group ) ) groups.set( b.group, [] ); + + groups.get( b.group ).push( b ); + + } + + const layouts = new Map(); + + for ( const groupEntry of groups.entries() ) { + + const group = groupEntry[ 0 ]; + + const bindings = groupEntry[ 1 ]; + + let visibility; + + if ( this.stage === GPUShaderStage.COMPUTE ) { + + visibility = GPUShaderStage.COMPUTE; + + } else { + + visibility = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT; + + } + + const entries = new Array( bindings.length ); + + for ( let j = 0; j < bindings.length; j++ ) { + + const b = bindings[ j ]; + + let entry; + + if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) { + + entry = { + + binding: b.binding, + + visibility: visibility, + + buffer: { type: b.type } + + }; + + } else if ( b.type === "texture" ) { + + + let viewDimension = "2d"; // default + + if ( b.varType.startsWith("texture_2d_array") ) { + + viewDimension = "2d-array"; + + } else if ( b.varType.startsWith("texture_cube") ) { + + viewDimension = "cube"; + + } else if ( b.varType.startsWith("texture_3d") ) { + + viewDimension = "3d"; + + } + + console.log("viewDimension", viewDimension); + + entry = { + binding: b.binding, + visibility: visibility, + texture: { + sampleType: "float", + viewDimension: viewDimension + } + }; + + } else 
if ( b.type === "sampler" ) { + + entry = { + + binding: b.binding, + + visibility: visibility, + + sampler: { type: "filtering" } + + }; + + } else { + + entry = { + + binding: b.binding, + + visibility: visibility, + + buffer: { type: "uniform" } + + }; + + } + + entries[ j ] = entry; + + } + + console.log("createBindGroupLayout ", entries); + + const layout = this.device.createBindGroupLayout( { entries } ); + + layouts.set( group, layout ); + + } + + return layouts; + +} + + async _createPipeline() { + + const layoutsArray = this.bindGroupLayouts && this.bindGroupLayouts.size > 0 + + ? Array.from( this.bindGroupLayouts.values() ) + + : [ this.bindGroupLayout ]; + + const shaderModule = this.device.createShaderModule( { + code: this.wgslSource, + + label: this.path, + + } ); + + const info = await shaderModule.compilationInfo(); + + if (info.messages.length > 0) { + for (const msg of info.messages) { + console[msg.type === "error" ? "error" : "warn"]( + `[${msg.lineNum}:${msg.linePos}] ${msg.message}` + ); + } + } + + console.log("this.bindGroupLayouts.values()", this.bindGroupLayouts.values()); + + this.pipeline = this.device.createComputePipeline({ + + layout: this.device.createPipelineLayout({ + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + }), + + compute: { + + module: shaderModule, + + entryPoint: this.entryPoint, + + }, + + } ); + + } + + setBindingGroupLayout( bindGroupLayout ) { + + this._externalBindGroupLayout = bindGroupLayout; + + } + + getBindingGroupLayout( bindingGroupIndex = 0 ) { + + if ( this.bindGroupLayouts ) { + + return this.bindGroupLayouts.get( bindingGroupIndex ); + + } + + return this.bindGroupLayout; + + } + + getTypedArrayByVariableName( name, flatData ) { + + let typedArray; + + const bindingInfo = this.bindings.find( b => b.varName === name ); + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + switch( wgslType ) { + + case "vec3": + + // flatData.push( 0 ); + + break; + + } + + switch ( 
wgslType ) { + + case "f32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Float32Array( flatData ); + + break; + + case "u32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Uint32Array( flatData ); + + break; + + case "i32": + case "vec2": + case "vec3": + case "vec4": + case "array": + + typedArray = new Int32Array( flatData ); + + break; + + default: + + // find struct and then the datatype + typedArray = new Float32Array( flatData ); + //console.error( "Unknown WGSL type in setVariable: " + wgslType ); + } + + //console.log("wgslType", wgslType); + + } + + return typedArray; + + } + + + setVariable( name, dataObject, customBufferUsage ) { + + + if ( dataObject instanceof GPUTexture ) { + + // Handle texture: create and store texture view + const textureView = dataObject.createView(); + + this.textures.set( name, textureView ); + + // Recreate bind groups since texture changed + this.createBindGroups(); + + return; + + } + + if ( dataObject instanceof GPUSampler ) { + + // Handle sampler + this.samplers.set( name, dataObject ); + + // Recreate bind groups since sampler changed + this.createBindGroups(); + + return; + + } + + + const flatData = this.flattenStructure( dataObject ); + + //console.log( "flattenStructure", name, flatData ); + + const sizeBytes = flatData.length * 4; + + let buffer = this.buffers.get( name ); + + if ( !buffer ) { + + throw new Error( "Buffer not found for variable: " + name + ". Buffers must be created in setup. 
>>" + this.path ); + + } + + if ( sizeBytes > buffer.size ) { + + console.log( `Resizing buffer '${name}' from ${buffer.size} to ${sizeBytes} bytes.` ); + + var usage = this.bufferUsages.get( name ) || ( GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ); + + if ( customBufferUsage ) { + + usage = customBufferUsage; + + } + + const newBuffer = this.device.createBuffer( { + + size : sizeBytes, + + usage : usage, + + label : name + + } ); + + if( !this.isInWorker ) { + + //document.bufferOwners.set( newBuffer, this ); + } + + + this.buffers.set( name, newBuffer ); + + buffer = newBuffer; + + // Recreate bind groups to bind new buffer + + //console.log( "setVariable()" ); + + this.createBindGroups(); + + } + + let typedArray; + + if ( typeof dataObject === "number" ) { + + typedArray = this.getTypedArrayByVariableName( name, flatData ); + + } else if( Array.isArray( dataObject ) ) { + + typedArray = this.getTypedArrayByVariableName( name, flatData ); + + } else if ( dataObject instanceof Uint32Array ) { + + typedArray = dataObject; + + } else if ( typeof dataObject.numTokens === "number" && Number.isInteger( dataObject.numTokens ) ) { + + typedArray = new Uint32Array( [ dataObject.numTokens ] ); + + } else if ( dataObject instanceof Int32Array || dataObject instanceof Float32Array ) { + + typedArray = dataObject; + + } else { + + typedArray = new Float32Array( flatData ); + + } + if( name == "instancePositions") { + console.log("instancePositions", flatData, dataObject, typedArray.byteOffset, typedArray.byteLength, typedArray.buffer); + + } + this.device.queue.writeBuffer( + buffer, + 0, + typedArray.buffer, + typedArray.byteOffset, + typedArray.byteLength + ); + + //console.log( `Buffer data updated: ${name}, byteLength=${typedArray.byteLength}` ); + + } + + flattenStructure( input, output = [] ) { + + if ( Array.isArray( input ) ) { + + for ( const element of input ) { + + this.flattenStructure( element, output ); + + } + + } else if ( typeof input === "object" && 
input !== null ) { + + for ( const key of Object.keys( input ) ) { + + this.flattenStructure( input[ key ], output ); + + } + + } else if ( typeof input === "number" ) { + + output.push( input ); + + } else if ( typeof input === "boolean" ) { + + output.push( input ? 1 : 0 ); + + } else { + + throw new Error( "Unsupported data type in shader variable: " + typeof input ); + + } + + return output; + + } + + copyBuffersFromShader( sourceShader ) { + + for ( const [ name, buffer ] of sourceShader.buffers.entries() ) { + + this.buffers.set( name, buffer ); + + } + + } + + updateUniforms( data ) { + const arrayBuffer = new ArrayBuffer(12); // 3 x u32 = 3 * 4 bytes + const dataView = new DataView(arrayBuffer); + + dataView.setUint32(0, data.k, true); // offset 0 + dataView.setUint32(4, data.j, true); // offset 4 + dataView.setUint32(8, data.totalCount, true); // offset 8 + + this.device.queue.writeBuffer( + this.uniformBuffer, // your GPUBuffer holding uniforms + 0, // offset in buffer + arrayBuffer + ); + } + + createCommandEncoder( x, y, z ) { + + const commandEncoder = this.device.createCommandEncoder({ + label: this.path + }); + + const passEncoder = commandEncoder.beginComputePass({ + label: this.path + }); + + + if( !this.pipeline ) { + + console.log("Pipeline missing, Maybe you forget to call addStage( ) ", this); + + } + + passEncoder.setPipeline( this.pipeline ); + + + + for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) { + + passEncoder.setBindGroup( bindingGroupIndex, bindGroup ); + + } + + passEncoder.dispatchWorkgroups( x, y, z ); + + passEncoder.end(); + + return commandEncoder; + + } + + async execute( x = 1, y = 1, z = 1 ) { + + const commandEncoder = this.createCommandEncoder( x, y, z ); + + const commandBuffer = commandEncoder.finish(); + + this.device.queue.submit( [ commandBuffer ] ); + + await this.device.queue.onSubmittedWorkDone(); // critical! 
+ + } + + registerBuffer( buffer, name, type ) { + + if( !this.isInWorker ) { + + if( !document.bufferMap[name] ) { + + document.bufferMap[name] = new Array(); + + } + + var bufferHistoryPoint = { buffer: buffer, shader: this, label : buffer.label, type: type }; + + document.bufferMap[name].push( bufferHistoryPoint ); + + } + + } + + getBuffer( name ) { + + var buffer = this.buffers.get( name ); + + this.registerBuffer( buffer, name, "get" ); + + return buffer; + + } + + setBuffer( name, buffer ) { + + this.registerBuffer( buffer, name, "set" ); + + if( !this.isInWorker ) { + + //var bufferOwner = document.bufferOwners.get( buffer ); + + //console.log("setBuffer Origin:", bufferOwner, " this shader", this); + + + } + + + + const bindingInfo = this.bindings.find( b => b.varName === name ); + + if( !bindingInfo ) { + + console.error("No binding to buffer ", name, " in shader ", this ); + + } + + this.buffers.set(name, buffer); + + this.createBindGroups(); // always refresh bindings + + } + + + + async addStage( entryPoint, stage ) { + + if ( !this.path ) { + + throw new Error("Shader path is not set"); + + } + this.entryPoint = entryPoint; + + if( stage != GPUShaderStage.COMPUTE ) { + + this.stage = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT; + + } + + if ( !this.binded ) { + + if( stage != GPUShaderStage.COMPUTE ) { + + if( this.sampleCount == 1 ) { + + this.depthTexture = this.device.createTexture({ + size: { + width: this.canvas.width, + height: this.canvas.height, + depthOrArrayLayers: 1 + }, + sampleCount: 1, // Must be 1 for non-multisampled rendering + format: "depth24plus", // Or "depth24plus-stencil8" if stencil is needed + usage: GPUTextureUsage.RENDER_ATTACHMENT + }); + + } else { + + this.multisampledColorTexture = this.device.createTexture({ + size: [this.canvas.width, this.canvas.height], + sampleCount: this.sampleCount, + format: "bgra8unorm", + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + this.multisampledDepthTexture = 
this.device.createTexture({ + size: [this.canvas.width, this.canvas.height], + sampleCount: this.sampleCount, // must match multisample count of color texture and pipeline + format: "depth24plus", // or "depth24plus-stencil8" if you need stencil + usage: GPUTextureUsage.RENDER_ATTACHMENT, + }); + + } + } + + this.wgslSource = await this._loadShaderSource( this.path ); + + this._parseVertexAttributes( this.wgslSource ); + + this._parseBindings( this.wgslSource ); + + if ( this._externalBindGroupLayout ) { + + this.bindGroupLayouts.set( 0, this._externalBindGroupLayout ); + + } else { + + this.bindGroupLayouts = this._createBindGroupLayoutsFromBindings(); + + } + + this._createAllBuffers(); + + this._createDefaultTexturesAndSamplers(); + + await this.createBindGroups(); + + this.binded = true; + + } + + const shaderModule = this.device.createShaderModule( { + + code: this.wgslSource, + + label: this.path + + } ); + + if ( stage === GPUShaderStage.COMPUTE ) { + + console.log("this.bindGroupLayouts", this.bindGroupLayouts); + + console.log("create compute shader pipeline."); + + this.computeStage = { module: shaderModule, entryPoint: this.entryPoint }; + + this.pipeline = this.device.createComputePipeline( { + + layout: this.device.createPipelineLayout( { + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + } ), + + compute: this.computeStage + + } ); + + await this.finalizeShader(); + + } else if ( stage === GPUShaderStage.VERTEX ) { + + //this.vertexStage = { module: shaderModule, entryPoint: this.entryPoint }; +console.log("Vertex buffers descriptors:", Array.from(this.vertexBuffersDescriptors.values())); + this.vertexStage = { + module: shaderModule, + entryPoint: this.entryPoint, + buffers: Array.from( this.vertexBuffersDescriptors.values() ) + }; + + } else if ( stage === GPUShaderStage.FRAGMENT ) { + + this.fragmentStage = { module: shaderModule, entryPoint: this.entryPoint }; + + } + + // If both vertex and fragment stages exist, create render 
pipeline + if ( this.vertexStage && this.fragmentStage ) { + + this._createAllBuffers(); + + console.log( this.device ); + + console.log("create fragment and vertex shader pipeline."); + + console.log("setup createRenderPipeline", this.vertexStage); + + console.log("this.device.createPipelineLayout", this.bindGroupLayouts.get(0) ); + + this.pipeline = this.device.createRenderPipeline( { + + layout: this.device.createPipelineLayout( { + + bindGroupLayouts: Array.from( this.bindGroupLayouts.values() ) + + } ), + + vertex: this.vertexStage, + + fragment: { + ...this.fragmentStage, + + targets: [ + { + //format: "bgra8unorm", + format: this.canvas.getContext().getPreferredFormat(), + + // No blending here for opaque rendering + blend: undefined, + writeMask: GPUColorWrite.ALL + } + ] // Specify the color target format here + + + }, + multisample: { + + count: this.sampleCount, // number of samples per pixel, typical values: 4 or 8 + + }, + primitive: { + + topology: this.topologyValue, + + frontFace: 'ccw' , + + cullMode: 'none', + + }, + + depthStencil: { + + format: "depth24plus", // <-- Add this to match render pass depthStencil + + depthWriteEnabled: true, + + depthCompare: "less", + + }, + + + + } ); + + this._createAllBuffers(); + + await this.finalizeShader(); + + } + + } + + convertToTypedArrayWithPadding(array, preferredType = null) { + + const isTypedArray = ( + array instanceof Uint16Array || + array instanceof Uint32Array || + array instanceof Float32Array + ); + + let typedArray; + let arrayType = preferredType; + + if (isTypedArray) { + if (preferredType === null) { + if (array instanceof Uint16Array) arrayType = "uint16"; + else if (array instanceof Uint32Array) arrayType = "uint32"; + else if (array instanceof Float32Array) arrayType = "float32"; + } + + const remainder = array.byteLength % 4; + + if (remainder !== 0) { + const elementSize = (arrayType === "uint16") ? 
2 : 4; + const paddedByteLength = array.byteLength + (4 - remainder); + const paddedElementCount = paddedByteLength / elementSize; + + if (arrayType === "uint16") { + const paddedArray = new Uint16Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else if (arrayType === "uint32") { + const paddedArray = new Uint32Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else if (arrayType === "float32") { + const paddedArray = new Float32Array(paddedElementCount); + paddedArray.set(array); + typedArray = paddedArray; + } else { + throw new Error("Unsupported typed array type for padding"); + } + } else { + typedArray = array; + } + } else if (Array.isArray(array)) { + if (preferredType === null) { + // Default to Float32Array if no preferred type + arrayType = "float32"; + } else { + arrayType = preferredType; + } + + if (arrayType === "uint16") { + typedArray = new Uint16Array(array); + } else if (arrayType === "uint32") { + typedArray = new Uint32Array(array); + } else if (arrayType === "float32") { + typedArray = new Float32Array(array); + } else { + throw new Error("Unsupported preferred type"); + } + + const remainder = typedArray.byteLength % 4; + + if (remainder !== 0) { + const elementSize = (arrayType === "uint16") ? 
2 : 4; + const paddedByteLength = typedArray.byteLength + (4 - remainder); + const paddedElementCount = paddedByteLength / elementSize; + + if (arrayType === "uint16") { + const paddedArray = new Uint16Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else if (arrayType === "uint32") { + const paddedArray = new Uint32Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else if (arrayType === "float32") { + const paddedArray = new Float32Array(paddedElementCount); + paddedArray.set(typedArray); + typedArray = paddedArray; + } else { + throw new Error("Unsupported typed array type for padding"); + } + } + } else { + throw new Error("Input must be a TypedArray or Array of numbers"); + } + + return { typedArray, arrayType }; + } + + setIndices( indices ) { +console.log("indices", indices); + const isTypedArray = indices instanceof Uint16Array || indices instanceof Uint32Array; + + let typedArray; + let arrayType; + + console.log("isTypedArray", isTypedArray); + + if ( isTypedArray ) { + + arrayType = (indices instanceof Uint16Array) ? "uint16" : "uint32"; + + ({ typedArray } = this.convertToTypedArrayWithPadding( indices, arrayType ) ); + + } else if (Array.isArray(indices)) { + + let maxIndex = 0; + + for (let i = 0; i < indices.length; i++) { + if (indices[i] > maxIndex) maxIndex = indices[i]; + } + + arrayType = (maxIndex < 65536) ? 
"uint16" : "uint32"; + + + ({ typedArray } = this.convertToTypedArrayWithPadding(indices, arrayType)); + + + } else { + + throw new Error("Indices must be Uint16Array, Uint32Array, or Array of numbers"); + + } + + this.indexFormat = arrayType; + + this.indexBufferData = typedArray; + + if (this.indexBuffer) { + + this.indexBuffer.destroy(); + + } + + this.indexBuffer = this.device.createBuffer({ + size: this.indexBufferData.byteLength, + usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST, + }); + + this.device.queue.writeBuffer( + this.indexBuffer, + 0, + this.indexBufferData.buffer, + this.indexBufferData.byteOffset, + this.indexBufferData.byteLength + ); + + this.indexCount = this.indexBufferData.length; + + + } + + setAttribute( name, array ) { + + if ( this.attributeArrays === undefined ) { + + this.attributeArrays = new Map(); + + } + + const { typedArray, arrayType } = this.convertToTypedArrayWithPadding( array ); + + let attributeSize = 3; // Default size (vec3) + + // this is not proper + if ( name === "uv" ) { + + attributeSize = 2; + + } else if ( name === "position" || name === "normal" ) { + + attributeSize = 3; + + } else { + + // You can add more attribute names and sizes here + + } + + this.attributeArrays.set( name, { typedArray: typedArray, size: attributeSize } ); + + } + + getAllAttributeBuffers() { + + if ( !this.attributeBuffers ) { + + return new Map(); + + } + + return this.attributeBuffers; + + } + + vertexBuffers( passEncoder ) { + + if ( this.attributeArrays === undefined ) { + + console.warn("No attribute arrays to bind."); + + return; + + } + + if ( this.mergedAttributeBuffer === undefined ) { + + const attributeNames = Array.from( this.attributeArrays.keys() ); + + const arrays = attributeNames.map( name => this.attributeArrays.get( name ).typedArray ); + + const attributeSizes = attributeNames.map( name => (name === "uv") ? 
2 : 3 ); + + const vertexCount = arrays[0].length / attributeSizes[0]; + + const strideFloats = attributeSizes.reduce( (sum, size) => sum + size, 0 ); + + const mergedArray = new Float32Array( vertexCount * strideFloats ); + + + for ( let i = 0; i < vertexCount; i++ ) { + + let offset = i * strideFloats; + + for ( let attrIndex = 0; attrIndex < arrays.length; attrIndex++ ) { + + const arr = arrays[attrIndex]; + + const size = attributeSizes[attrIndex]; + + const baseIndex = i * size; + + for ( let c = 0; c < size; c++ ) { + + mergedArray[ offset++ ] = arr[ baseIndex + c ]; + + } + + } + + } + + const byteLength = mergedArray.byteLength; + + const buffer = this.device.createBuffer({ + + size : byteLength, + + usage : GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, + + mappedAtCreation : true, + + label : "mergedAttributes" + + }); + + const mapping = new Float32Array( buffer.getMappedRange() ); + + mapping.set( mergedArray ); + + buffer.unmap(); + + this.device.queue.writeBuffer( + + buffer, + + 0, + + mergedArray.buffer, + + mergedArray.byteOffset, + + mergedArray.byteLength + + ); + + this.mergedAttributeBuffer = buffer; + + // Store for pipeline setup + this.attributeStrideFloats = strideFloats; + + this.attributeSizes = attributeSizes; + + this.attributeNames = attributeNames; + + } + + // Bind the merged buffer at slot 0 + + passEncoder.setVertexBuffer( 0, this.mergedAttributeBuffer ); + + } + + setCanvas( canvas ) { + + this.canvas = canvas; + + } + + async renderToCanvas( x, y = 0, z = 0, test ) { + + const canvas = this.canvas; + + const context = canvas.getContext("webgpu"); + + + const commandEncoder = this.device.createCommandEncoder() + + var renderPassDescriptor = { + colorAttachments: [{ + view: context.getCurrentTextureView(), + loadOp: "clear", + storeOp: "store", + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1 } + }] + , + depthStencilAttachment: { + view: this.depthTexture.createView(), + depthLoadOp: "clear", + depthStoreOp: "store", + 
depthClearValue: 1.0, + } + }; + + const renderPass = commandEncoder.beginRenderPass(renderPassDescriptor) + + renderPass.setPipeline(this.pipeline) + //renderPass.setViewport(0, 0, canvas.width, canvas.height, 0, 1) + //renderPass.setScissorRect(0, 0, canvas.width, canvas.height) + + + //renderPass.setVertexBuffer(0, positionBuffer) + //renderPass.setVertexBuffer(1, colorBuffer) + //renderPass.setIndexBuffer(indexBuffer, 'uint16') + //renderPass.drawIndexed(3) + this.vertexBuffers( renderPass ); + + + for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) { + + renderPass.setBindGroup( bindingGroupIndex, bindGroup ); + + } + + if( this.indexBuffer ) { + + renderPass.setIndexBuffer( this.indexBuffer, this.indexFormat ); + + renderPass.drawIndexed( this.indexCount, y, z ); + + } else { + //console.log(x, y, z); + renderPass.draw( x, y, z ); + + } + + + renderPass.end() + + + this.device.queue.submit([ commandEncoder.finish() ]) + + + return; + +/* + + const commandEncoder = this.device.createCommandEncoder(); + + var renderPassDescriptor; + + + if (this.sampleCount === 1) { + + renderPassDescriptor = { + colorAttachments: [{ + view: context.getCurrentTexture().createView(), + loadOp: "clear", + storeOp: "store", + clearValue: { r: 0.3, g: 0.3, b: 0.3, a: 1 } + }], + depthStencilAttachment: { + view: this.depthTexture.createView(), + depthLoadOp: "clear", + depthStoreOp: "store", + depthClearValue: 1.0, + } + }; + + } else { + + renderPassDescriptor = { + colorAttachments: [{ + view: this.multisampledColorTexture.createView(), + resolveTarget: context.getCurrentTexture().createView(), + loadOp: "clear", + storeOp: "store", + clearValue: { r: 0, g: 0, b: 0, a: 1 } + }], + depthStencilAttachment: { + view: this.multisampledDepthTexture.createView(), + depthLoadOp: "clear", + depthStoreOp: "store", + depthClearValue: 1.0, + } + }; + } + + const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor); + + 
passEncoder.setPipeline(this.pipeline); + + this.vertexBuffers( passEncoder ); + + + for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) { + + passEncoder.setBindGroup( bindingGroupIndex, bindGroup ); + + } + + if( this.indexBuffer ) { + + passEncoder.setIndexBuffer( this.indexBuffer, this.indexFormat ); + + passEncoder.drawIndexed( this.indexCount, y, z ); + + } else { + //console.log(x, y, z); + passEncoder.draw( x, y, z ); + + } + + + + passEncoder.end(); + + this.device.queue.submit([commandEncoder.finish()]); +*/ + } + + async finalizeShader() { + + //await this._createPipeline(); + + await this.createBindGroups(); + + } + + async debugBuffer( bufferName, bufferType ) { + + + + const buffer = this.getBuffer( bufferName ); + + const dataSize = buffer.size; + + const debugReadBuffer = this.device.createBuffer( { + size: dataSize, + usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ + } ); + + const debugEncoder = this.device.createCommandEncoder(); + + console.log( "print buffer", buffer ); + + debugEncoder.copyBufferToBuffer( + buffer, + 0, + debugReadBuffer, + 0, + dataSize + ); + + this.device.queue.submit( [ debugEncoder.finish() ] ); + + await this.device.queue.onSubmittedWorkDone(); + + await debugReadBuffer.mapAsync( GPUMapMode.READ ); + + const copyArrayBuffer = debugReadBuffer.getMappedRange(); + + let result; + + const bindingInfo = this.bindings.find( b => b.varName === bufferName ); + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + //console.log("wgslType", bindingInfo, wgslType, bufferName); + + switch ( wgslType ) { + + case "array": + result = new Float32Array( copyArrayBuffer ); + break; + case "array": + result = new Uint32Array( copyArrayBuffer ); + break; + case "vec3": + + break; + case "vec4": + result = new Float32Array( copyArrayBuffer ); + break; + case "u32": + case "vec2": + case "vec3": + case "vec4": + result = new Uint32Array( copyArrayBuffer ); + break; + + + case "array>": + case 
"array": + case "array>": + case "array>": + result = new Float32Array( copyArrayBuffer ); + break; + case "vec3": + case "vec4": + result = new Int32Array( copyArrayBuffer ); + break; + + default: + console.error( "Unknown WGSL type in setVariable: " + wgslType, "Using Float32Array as fallback." ); + + result = new Float32Array( copyArrayBuffer ); + } + + // use wgslType here to determine appropriate TypedArray + } + const length = result.length; + var regularArray = new Array(); + + for (var i = 0; i < length; i++) { + + regularArray.push(result[i]); + + } + + + console.log("Debugging array"); + console.log(""); + console.log("Shader: ", this.path); + console.log("Variable Name: ", bufferName); + + console.log( "Buffer Size: ", this.getBuffer(bufferName).size ); + console.log( "TypedArray Size: ", length ); + + + if ( bindingInfo ) { + + const wgslType = bindingInfo.varType; + + console.log( "type: ", wgslType ); + + } + + console.log(regularArray); + + debugReadBuffer.unmap(); + + return regularArray; + + } + + +} + + diff --git a/framework/eventManager.js b/framework/eventManager.js new file mode 100644 index 0000000..882febc --- /dev/null +++ b/framework/eventManager.js @@ -0,0 +1,113 @@ +// eventManager.js + +export default class EventManager { + + isDragging = false; + + lastX = 0; + + lastY = 0; + + camera; + + canvas; + + setCanvas( canvas ) { + + this.canvas = canvas; + + + //this.registerEventListeners(); + + //this.handleResize(); + + } + + setup( canvas, camera ) { + + this.canvas = canvas; + + this.camera = camera; + + //this.registerEventListeners(); + + //this.handleResize(); + + } + + registerEventListeners() { + + this.canvas.addEventListener( "mousedown", this.onMouseDown.bind(this) ); + + this.canvas.addEventListener( "mouseup", this.onMouseUp.bind(this) ); + + this.canvas.addEventListener( "mouseleave", this.onMouseLeave.bind(this) ); + + this.canvas.addEventListener( "mousemove", this.onMouseMove.bind(this) ); + + this.canvas.addEventListener( 
"wheel", this.onWheel.bind(this), { passive: false } ); + + } + + resize( event ) { + + this.canvas.width = event.width; + + this.canvas.height = event.height; + + //this.canvas.width = window.innerWidth; + + //this.canvas.height = window.innerHeight; + + } + + mousedown( event ) { + + console.log("mouseDownHandler"); + + this.isDragging = true; + + this.lastX = event.clientX; + + this.lastY = event.clientY; + + } + + mouseup( event ) { + + this.isDragging = false; + + } + + mouseleave( event ) { + + this.isDragging = false; + + } + + mousemove( event ) { + + if ( !this.isDragging ) return; + + const deltaX = ( event.clientX - this.lastX ) * 0.005; + + const deltaY = ( event.clientY - this.lastY ) * 0.005; + + this.camera.rotate( deltaX, -deltaY ); + + this.lastX = event.clientX; + + this.lastY = event.clientY; + + } + + wheel( event ) { + + + const delta = event.deltaY * 0.01; + + this.camera.zoom( delta ); + + } + +} diff --git a/framework/eventManager_node.js b/framework/eventManager_node.js new file mode 100644 index 0000000..73522b7 --- /dev/null +++ b/framework/eventManager_node.js @@ -0,0 +1,214 @@ +// eventManager.js + +import sdl from '@kmamal/sdl' + +var isNode = true; + +if ( typeof window === 'undefined' ) { + + //isNode = false; + +} + +console.log("isNode", isNode); + +export default class EventManager { + + isDragging = false; + + lastX = 0; + + lastY = 0; + + camera; + + canvas; + + setCanvas( canvas ) { + + this.canvas = canvas; + + + //this.registerEventListeners(); + + //this.handleResize(); + + } + + setup( canvas, camera ) { + + this.canvas = canvas; + + this.camera = camera; + + //this.registerEventListeners(); + + //this.handleResize(); + + } + + registerEventListeners() { + + this.canvas.addEventListener( "mousedown", this.onMouseDown.bind(this) ); + + this.canvas.addEventListener( "mouseup", this.onMouseUp.bind(this) ); + + this.canvas.addEventListener( "mouseleave", this.onMouseLeave.bind(this) ); + + this.canvas.addEventListener( 
"mousemove", this.onMouseMove.bind(this) ); + + this.canvas.addEventListener( "wheel", this.onWheel.bind(this), { passive: false } ); + + } + + registerEventListenersNode() { + + var that = this; + + this.canvas.on('mouseMove', function( event ) { + + that.mousemove( event ) + + }); + + this.canvas.on('mouseButtonDown', function( event ) { + + that.mousedown( event ) + + }); + + this.canvas.on('mouseButtonUp', function( event ) { + + that.mouseup( event ) + + }); + + +/* + this.canvas.on( "mouseButtonDown", this.onMouseDown.bind(this) ); + + this.canvas.on( "mouseButtonUp", this.onMouseUp.bind(this) ); + + //this.canvas.on( "mouseleave", this.onMouseLeave.bind(this) ); + + this.canvas.on( "mouseMove", this.onMouseMove.bind(this) ); + + this.canvas.on( "mouseWheel", this.onWheel.bind(this), { passive: false } ); +*/ + } + + resize( event ) { + + this.canvas.width = event.width; + + this.canvas.height = event.height; + + //this.canvas.width = window.innerWidth; + + //this.canvas.height = window.innerHeight; + + } + + mousedown( event ) { + + + + this.isDragging = true; + + if( isNode ) { + + var mouseX = event.x; + + var mouseY = event.y; + + } else { + + var mouseX = event.clientX; + + var mouseY = event.clientY; + + } + + //console.log("mouseDownHandler", mouseX, mouseY); + + if( isNode ) { + + this.lastX = mouseX; + + this.lastY = mouseY; + + } else { + + this.lastX = mouseX; + + this.lastY = mouseY; + + } + + + + + + } + + mouseup( event ) { + + this.isDragging = false; + + } + + mouseleave( event ) { + + this.isDragging = false; + + } + + mousemove( event ) { + + if( isNode ) { + + var mouseX = event.x; + + var mouseY = event.y; + + } else { + + var mouseX = event.clientX; + + var mouseY = event.clientY; + + } + + + + + if ( !this.isDragging ) return; + + const deltaX = ( mouseX - this.lastX ) * 0.005; + + const deltaY = ( mouseY - this.lastY ) * 0.005; + + //console.log("mousemove", mouseX, mouseY); + + this.camera.rotate( deltaX, -deltaY ); + + this.lastX = 
mouseX; + + this.lastY = mouseY; + + + + + + } + + wheel( event ) { + + + const delta = event.deltaY * 0.01; + + this.camera.zoom( delta ); + + } + +} diff --git a/framework/package.json b/framework/package.json new file mode 100644 index 0000000..4d93719 --- /dev/null +++ b/framework/package.json @@ -0,0 +1,6 @@ +{ + "type": "module", + "dependencies": { + + } +} diff --git a/index.html b/index.html new file mode 100644 index 0000000..cec2867 --- /dev/null +++ b/index.html @@ -0,0 +1,22 @@ + + + + + WebGPU MNIST Training (Batch 64) + + + +

WebGPU MNIST Training — Batch 64

+ +
+
+ + + diff --git a/mnist_loader.js b/mnist_loader.js new file mode 100644 index 0000000..43b3af5 --- /dev/null +++ b/mnist_loader.js @@ -0,0 +1,27 @@ +export async function loadMNIST(device, sampleCount = 1000) { + async function loadFile(path) { + const res = await fetch(path); + return new Uint8Array(await res.arrayBuffer()); + } + + const imgRaw = await loadFile("../../models/train-images-idx3-ubyte"); + const lblRaw = await loadFile("../../models/train-labels-idx1-ubyte"); + + const header = 16; + const fullCount = lblRaw.length - 8; + const count = Math.min(sampleCount, fullCount); + + const images = new Float32Array(count * 28 * 28); + const labels = new Uint32Array(count); + + for (let i = 0; i < count; i++) { + labels[i] = lblRaw[8 + i]; + const src = header + i * 784; + const dst = i * 784; + for (let j = 0; j < 784; j++) { + images[dst + j] = imgRaw[src + j] / 255; + } + } + + return { images, labels, count }; +} diff --git a/models/t10k-images-idx3-ubyte b/models/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/models/t10k-images-idx3-ubyte differ diff --git a/models/t10k-labels-idx1-ubyte b/models/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/models/t10k-labels-idx1-ubyte differ diff --git a/models/train-images-idx3-ubyte b/models/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/models/train-images-idx3-ubyte differ diff --git a/models/train-labels-idx1-ubyte b/models/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/models/train-labels-idx1-ubyte differ diff --git a/package.json b/package.json new file mode 100644 index 0000000..c45d947 --- /dev/null +++ b/package.json @@ -0,0 +1,3 @@ +{ + "type":"module" +} \ No newline at end of file diff --git a/server.js b/server.js new file mode 100644 index 0000000..07da6e7 --- /dev/null +++ b/server.js @@ -0,0 +1,168 @@ 
+import http from "http"; + +import { readdir } from "fs/promises"; +import { stat } from "fs/promises"; +import { readFile } from "fs/promises"; +import { join } from "path"; +import { dirname } from "path"; +import { fileURLToPath } from "url"; + + +class App +{ + + + constructor( ) + { + + const selfPath = fileURLToPath( import.meta.url ); + + this.rootPath = dirname( selfPath ); + this.httpServer = null; + + } + + + async start( ) + { + + this.httpServer = http.createServer( this.handleRequest.bind( this ) ); + + this.httpServer.listen( 2020 ); + + console.log("Server started on port http://localhost:2020/"); + + } + + + async handleRequest( req, res ) + { + + const requestedPath = decodeURI( req.url ); + const fullPath = join( this.rootPath, requestedPath ); + + const exists = await this.checkFileExists( fullPath ); + + if ( !exists ) + { + + res.statusCode = 404; + res.end( "Not Found" ); + return; + + } + + const stats = await stat( fullPath ); + + if ( stats.isDirectory( ) ) + { + + const indexPath = join( fullPath, "index.html" ); + const indexExists = await this.checkFileExists( indexPath ); + + if ( indexExists ) + { + + await this.sendFile( indexPath, res ); + return; + + } + + await this.sendDirectoryListing( fullPath, requestedPath, res ); + return; + + } + + await this.sendFile( fullPath, res ); + + } + + + async sendFile( path, res ) + { + + const contentType = this.getContentType( path ); + const fileData = await readFile( path ); + + res.setHeader( "Content-Type", contentType ); + res.statusCode = 200; + + res.end( fileData ); + + } + + + async sendDirectoryListing( dirPath, urlPath, res ) + { + + const entries = await readdir( dirPath, { withFileTypes : true } ); + + let html = "

Index of " + urlPath + "

    "; + + let i = 0; + + while ( i < entries.length ) + { + + const e = entries[ i ].name; + const link = urlPath.endsWith( "/" ) + ? urlPath + e + : urlPath + "/" + e; + + html = html + "
  • " + e + "
  • "; + + i = i + 1; + + } + + html = html + "
"; + + res.setHeader( "Content-Type", "text/html" ); + res.statusCode = 200; + + res.end( html ); + + } + + + async checkFileExists( path ) + { + + const exists = await stat( path ) + .then( function( ) { return true; } ) + .catch( function( ) { return false; } ); + + return exists; + + } + + + getContentType( path ) + { + + const lower = path.toLowerCase( ); + + if ( lower.endsWith( ".html" ) ) return "text/html"; + if ( lower.endsWith( ".css" ) ) return "text/css"; + if ( lower.endsWith( ".js" ) ) return "text/javascript"; + if ( lower.endsWith( ".json" ) ) return "application/json"; + if ( lower.endsWith( ".wasm" ) ) return "application/wasm"; + if ( lower.endsWith( ".png" ) ) return "image/png"; + if ( lower.endsWith( ".jpg" ) ) return "image/jpeg"; + if ( lower.endsWith( ".jpeg" ) ) return "image/jpeg"; + if ( lower.endsWith( ".gif" ) ) return "image/gif"; + if ( lower.endsWith( ".svg" ) ) return "image/svg+xml"; + if ( lower.endsWith( ".wgsl" ) ) return "text/plain"; + if ( lower.endsWith( ".txt" ) ) return "text/plain"; + + return "application/octet-stream"; + + } + + +} + + +const app = new App( ); + +await app.start( ); diff --git a/shaders/neural/forward_dense_layer1.wgsl b/shaders/neural/forward_dense_layer1.wgsl new file mode 100644 index 0000000..107d1c6 --- /dev/null +++ b/shaders/neural/forward_dense_layer1.wgsl @@ -0,0 +1,33 @@ +@group(0) @binding(0) var inputImages : array; +@group(0) @binding(1) var weightLayer1 : array; +@group(0) @binding(2) var biasLayer1 : array; +@group(0) @binding(3) var hiddenActivations : array; +@group(0) @binding(4) var layer1Sizes : vec3; +// x = inputDim, y = hiddenDim, z = batchSize + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id : vec3) { + let index = global_id.x; + + let inputDim = layer1Sizes.x; + let hiddenDim = layer1Sizes.y; + let batchSize = layer1Sizes.z; + + let totalElements = batchSize * hiddenDim; + if (index >= totalElements) { return; } + + let sampleIndex = index / 
hiddenDim; + let hiddenIndex = index % hiddenDim; + + var sum : f32 = biasLayer1[hiddenIndex]; + + let baseInputIndex = sampleIndex * inputDim; + let baseWeightIndex = hiddenIndex * inputDim; + + for (var i : u32 = 0u; i < inputDim; i++) { + sum += inputImages[baseInputIndex + i] * + weightLayer1[baseWeightIndex + i]; + } + + hiddenActivations[index] = max(sum, 0.0); +} diff --git a/shaders/neural/forward_dense_layer2.wgsl b/shaders/neural/forward_dense_layer2.wgsl new file mode 100644 index 0000000..89fea96 --- /dev/null +++ b/shaders/neural/forward_dense_layer2.wgsl @@ -0,0 +1,33 @@ +@group(0) @binding(0) var hiddenActivations : array; +@group(0) @binding(1) var weightLayer2 : array; +@group(0) @binding(2) var biasLayer2 : array; +@group(0) @binding(3) var outputLogits : array; +@group(0) @binding(4) var layer2Sizes : vec3; +// x = hiddenDim, y = numClasses, z = batchSize + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) id : vec3) { + let index = id.x; + + let hiddenDim = layer2Sizes.x; + let numClasses = layer2Sizes.y; + let batchSize = layer2Sizes.z; + + let total = batchSize * numClasses; + if (index >= total) { return; } + + let sampleIndex = index / numClasses; + let classIndex = index % numClasses; + + var sum : f32 = biasLayer2[classIndex]; + + let baseHiddenIndex = sampleIndex * hiddenDim; + let baseWeightIndex = classIndex * hiddenDim; + + for (var i : u32 = 0u; i < hiddenDim; i++) { + sum += hiddenActivations[baseHiddenIndex + i] * + weightLayer2[baseWeightIndex + i]; + } + + outputLogits[index] = sum; +} diff --git a/shaders/neural/gradients_dense_layer1.wgsl b/shaders/neural/gradients_dense_layer1.wgsl new file mode 100644 index 0000000..2dd5b17 --- /dev/null +++ b/shaders/neural/gradients_dense_layer1.wgsl @@ -0,0 +1,85 @@ +@group(0) @binding(0) +var inputImages : array; +// [batchSize * inputDimension] + +@group(0) @binding(1) +var gradientHidden : array; +// [batchSize * hiddenDimension] + +@group(0) @binding(2) +var 
gradientWeightLayer1 : array; +// [inputDimension * hiddenDimension] + +@group(0) @binding(3) +var gradientBiasLayer1 : array; +// [hiddenDimension] + +@group(0) @binding(4) +var layer1Sizes : vec3; +// x = inputDimension, y = hiddenDimension, z = batchSize + + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) id : vec3) { + + let index = id.x; + + let inputDimension = layer1Sizes.x; + let hiddenDimension = layer1Sizes.y; + let batchSize = layer1Sizes.z; + + // We launch: + // inputDimension * hiddenDimension threads for weight gradients + // + hiddenDimension threads for bias gradients + let totalWeightElements = inputDimension * hiddenDimension; + let totalBiasElements = hiddenDimension; + let totalElements = totalWeightElements + totalBiasElements; + + if (index >= totalElements) { + return; + } + + // ---- Region 1: Weight gradients dW1 (0 .. totalWeightElements-1) ---- + if (index < totalWeightElements) { + + // Map flat index into (inputIndex, hiddenIndex) + let hiddenIndex : u32 = index / inputDimension; + let inputIndex : u32 = index % inputDimension; + + var sum : f32 = 0.0; + + // dW1[inputIndex, hiddenIndex] = sum over batch of + // inputImages[b, inputIndex] * gradientHidden[b, hiddenIndex] + for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) { + + let inputValue = inputImages[ batchIndex * inputDimension + inputIndex ]; + + let hiddenGradient = gradientHidden[ batchIndex * hiddenDimension + hiddenIndex ]; + + sum = sum + inputValue * hiddenGradient; + + } + + gradientWeightLayer1[ index ] = sum; + return; + } + + // ---- Region 2: Bias gradients db1 (totalWeightElements .. 
end) ---- + let biasIndex : u32 = index - totalWeightElements; + + if (biasIndex < hiddenDimension) { + + var sumBias : f32 = 0.0; + + // db1[hiddenIndex] = sum over batch of gradientHidden[b, hiddenIndex] + for (var batchIndex : u32 = 0u; batchIndex < batchSize; batchIndex = batchIndex + 1u) { + + let hiddenGradient = gradientHidden[ batchIndex * hiddenDimension + biasIndex ]; + + sumBias = sumBias + hiddenGradient; + + } + + gradientBiasLayer1[ biasIndex ] = sumBias; + } +} diff --git a/shaders/neural/gradients_dense_layer2.wgsl b/shaders/neural/gradients_dense_layer2.wgsl new file mode 100644 index 0000000..d81cfb4 --- /dev/null +++ b/shaders/neural/gradients_dense_layer2.wgsl @@ -0,0 +1,81 @@ +// gradients_dense_layer2.wgsl + +@group(0) @binding(0) +var hiddenActivations : array; // [batch * hiddenDim] + +@group(0) @binding(1) +var gradientLogits : array; // [batch * numClasses] + +@group(0) @binding(2) +var weightLayer2 : array; // [hiddenDim * numClasses] <-- Missing before + +@group(0) @binding(3) +var gradientWeightLayer2 : array; // [hiddenDim * numClasses] + +@group(0) @binding(4) +var gradientBiasLayer2 : array; // [numClasses] + +@group(0) @binding(5) +var gradientHidden : array; // [batch * hiddenDim] + +@group(0) @binding(6) +var layer2Sizes : vec3; +// x = hiddenDim, y = numClasses, z = batchSize + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) id : vec3) { + + let globalIndex = id.x; + + let hiddenDim = layer2Sizes.x; + let numClasses = layer2Sizes.y; + let batchSize = layer2Sizes.z; + + // + // (1) Compute gradients for W2 and b2 + // + let totalWeights = hiddenDim * numClasses; + + if (globalIndex < totalWeights) { + let classIndex = globalIndex / hiddenDim; + let hiddenIndex = globalIndex % hiddenDim; + + var sum : f32 = 0.0; + + for (var b : u32 = 0u; b < batchSize; b++) { + let activationValue = hiddenActivations[b * hiddenDim + hiddenIndex]; + let gradLogitValue = gradientLogits[b * numClasses + classIndex]; + sum += 
activationValue * gradLogitValue; + } + + gradientWeightLayer2[globalIndex] = sum; + + if (hiddenIndex == 0u) { + var biasSum : f32 = 0.0; + for (var b2 : u32 = 0u; b2 < batchSize; b2++) { + biasSum += gradientLogits[b2 * numClasses + classIndex]; + } + gradientBiasLayer2[classIndex] = biasSum; + } + } + + // + // (2) Compute gradientHidden for layer1 + // + let dhIndex = globalIndex - totalWeights; + + if (dhIndex < batchSize * hiddenDim) { + let b = dhIndex / hiddenDim; + let hiddenIndex = dhIndex % hiddenDim; + + var sum2 : f32 = 0.0; + + for (var c : u32 = 0u; c < numClasses; c++) { + let gradLogitValue = gradientLogits[b * numClasses + c]; + let weightValue = weightLayer2[c * hiddenDim + hiddenIndex]; + sum2 += gradLogitValue * weightValue; + } + + gradientHidden[dhIndex] = sum2; + } +} diff --git a/shaders/neural/sgd_optimizer_update.wgsl b/shaders/neural/sgd_optimizer_update.wgsl new file mode 100644 index 0000000..f1b873b --- /dev/null +++ b/shaders/neural/sgd_optimizer_update.wgsl @@ -0,0 +1,39 @@ +@group(0) @binding(0) var weightLayer1 : array; +@group(0) @binding(1) var biasLayer1 : array; +@group(0) @binding(2) var gradientWeightLayer1 : array; +@group(0) @binding(3) var gradientBiasLayer1 : array; + +@group(0) @binding(4) var weightLayer2 : array; +@group(0) @binding(5) var biasLayer2 : array; +@group(0) @binding(6) var gradientWeightLayer2 : array; +@group(0) @binding(7) var gradientBiasLayer2 : array; + +@group(0) @binding(8) var learningRate : f32; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) id: vec3) { + let index = id.x; + + if (index < arrayLength(&weightLayer1)) { + weightLayer1[index] -= gradientWeightLayer1[index] * learningRate; + return; + } + + let b1Index = index - arrayLength(&weightLayer1); + if (b1Index < arrayLength(&biasLayer1)) { + biasLayer1[b1Index] -= gradientBiasLayer1[b1Index] * learningRate; + return; + } + + let w2Index = index - arrayLength(&weightLayer1) - arrayLength(&biasLayer1); + if (w2Index < 
arrayLength(&weightLayer2)) { + weightLayer2[w2Index] -= gradientWeightLayer2[w2Index] * learningRate; + return; + } + + let b2Index = w2Index - arrayLength(&weightLayer2); + if (b2Index < arrayLength(&biasLayer2)) { + biasLayer2[b2Index] -= gradientBiasLayer2[b2Index] * learningRate; + return; + } +} diff --git a/shaders/neural/softmax_cross_entropy_backward.wgsl b/shaders/neural/softmax_cross_entropy_backward.wgsl new file mode 100644 index 0000000..7b27a49 --- /dev/null +++ b/shaders/neural/softmax_cross_entropy_backward.wgsl @@ -0,0 +1,39 @@ +@group(0) @binding(0) var outputLogits : array; +@group(0) @binding(1) var targetLabels : array; +@group(0) @binding(2) var gradientLogits : array; +@group(0) @binding(3) var crossEntropyLoss : array; +@group(0) @binding(4) var layer2Sizes : vec3; +// x = hiddenDim, y = numClasses, z = batchSize + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) id: vec3) { + let sampleIndex = id.x; + let batchSize = layer2Sizes.z; + let numClasses = layer2Sizes.y; + + if (sampleIndex >= batchSize) { return; } + + let base = sampleIndex * numClasses; + let label = targetLabels[sampleIndex]; + + var maxLogit : f32 = -1e9; + for (var c : u32 = 0u; c < numClasses; c++) { + let v = outputLogits[base + c]; + if (v > maxLogit) { maxLogit = v; } + } + + var sumExp : f32 = 0.0; + for (var c : u32 = 0u; c < numClasses; c++) { + sumExp += exp(outputLogits[base + c] - maxLogit); + } + + var loss : f32 = 0.0; + for (var c : u32 = 0u; c < numClasses; c++) { + let sm = exp(outputLogits[base + c] - maxLogit) / sumExp; + let oneHot = f32(c == label); + gradientLogits[base + c] = sm - oneHot; + if (oneHot == 1.0) { loss = -log(sm); } + } + + crossEntropyLoss[sampleIndex] = loss; +} diff --git a/source/index_old.html b/source/index_old.html new file mode 100644 index 0000000..652a25b --- /dev/null +++ b/source/index_old.html @@ -0,0 +1,14 @@ + + + + + WebGPU Training Demo + + + + +

WebGPU Linear Regression Training

+ +

+
+
\ No newline at end of file
diff --git a/source/loadMNIST.js b/source/loadMNIST.js
new file mode 100644
index 0000000..dd85265
--- /dev/null
+++ b/source/loadMNIST.js
@@ -0,0 +1,40 @@
+// Fetches `url` and resolves to the response body as an ArrayBuffer; throws Error on a non-OK status.
+export async function loadBinary(url) {
+    const reply = await fetch(url);
+    if (reply.ok) return reply.arrayBuffer();
+    throw new Error("Failed to load " + url);
+}
+
+// Loads the MNIST 28x28 grayscale images + labels; resolves to { images, labels, sampleCount }.
+// The two files are independent, so both downloads are started together instead of back to back.
+export async function loadMNIST() {
+    const [imgBuf, lblBuf] = await Promise.all([
+        loadBinary("/models/MNIST-2k-images.bin"),
+        loadBinary("/models/MNIST-2k-labels.bin"),
+    ]);
+    const images = new Float32Array(imgBuf); // already normalized [0..1]
+    const labels = new Uint8Array(lblBuf); // one byte per sample
+    const sampleCount = labels.length;
+    console.log("MNIST Loaded:", sampleCount, "samples");
+    return { images, labels, sampleCount };
+}
+
+// Debug helper: draws one 28x28 grayscale sample (`img784`, 784 values scaled by 255) on `canvas`.
+// Resizes the canvas to 28x28 and writes opaque gray RGBA pixels through its "2d" context.
+export function showImage(img784, canvas) {
+    const side = 28;
+    canvas.width = side;
+    canvas.height = side;
+
+    const context2d = canvas.getContext("2d");
+    const frame = context2d.createImageData(side, side);
+    for (let pixel = 0; pixel < side * side; pixel++) {
+        const gray = Math.floor(img784[pixel] * 255);
+        const offset = pixel * 4;
+        frame.data[offset] = gray;
+        frame.data[offset + 1] = gray;
+        frame.data[offset + 2] = gray;
+        frame.data[offset + 3] = 255;
+    }
+    context2d.putImageData(frame, 0, 0);
+}
diff --git a/source/mnist_loader.js b/source/mnist_loader.js
new file mode 100644
index 0000000..43b3af5
--- /dev/null
+++ b/source/mnist_loader.js
@@ -0,0 +1,27 @@
+// Loads up to `sampleCount` MNIST training samples from the raw IDX files (16-byte image header,
+// 8-byte label header). Resolves to { images, labels, count }: `images` is a Float32Array of
+// count*784 pixels scaled to [0..1], `labels` a Uint32Array. `device` is unused (kept for callers).
+// Throws Error if a file cannot be fetched or the image file is too short for `count` samples.
+export async function loadMNIST(device, sampleCount = 1000) {
+    async function loadFile(path) {
+        const res = await fetch(path);
+        // Without this check a 404/500 body would be decoded as if it were pixel data.
+        if (!res.ok) throw new Error("Failed to load " + path);
+        return new Uint8Array(await res.arrayBuffer());
+    }
+    const [imgRaw, lblRaw] = await Promise.all([
+        loadFile("../../models/train-images-idx3-ubyte"),
+        loadFile("../../models/train-labels-idx1-ubyte"),
+    ]);
+    const header = 16;
+    const pixelsPerImage = 28 * 28;
+    const count = Math.min(sampleCount, lblRaw.length - 8);
+    if (imgRaw.length < header + count * pixelsPerImage) throw new Error("MNIST image file too short for " + count + " samples");
+    const images = new Float32Array(count * pixelsPerImage);
+    const labels = new Uint32Array(count);
+    labels.set(lblRaw.subarray(8, 8 + count));
+    for (let i = 0; i < images.length; i++) {
+        images[i] = imgRaw[header + i] / 255; // samples are contiguous after the header
+    }
+    return { images, labels, count };
+}
diff --git a/source/train_images.html b/source/train_images.html
new file mode 100644
index 0000000..b81403b
--- /dev/null
+++ b/source/train_images.html
@@ -0,0 +1,22 @@
+
+
+
+    
+    WebGPU MNIST Training (Batch 64)
+    
+
+
+    

WebGPU MNIST Training — Batch 64

+ +
+
+ + + diff --git a/source/training.js b/source/training.js new file mode 100644 index 0000000..2040372 --- /dev/null +++ b/source/training.js @@ -0,0 +1,111 @@ +import Shader from "../../framework/WebGpu.js"; + +const N = 32; // training samples +const inputSize = 3; +const outputSize = 1; +const lr = 0.01; +const epochs = 50; + +function generateTrainingData() { + const X = new Float32Array(N * inputSize); + const y = new Float32Array(N); + + for (let i = 0; i < N; i++) { + const x1 = Math.random() * 2 - 1; + const x2 = Math.random() * 2 - 1; + const x3 = Math.random() * 2 - 1; + + const label = 3 * x1 + 2 * x2 - 1 * x3 + 0.5; + + X[i * 3 + 0] = x1; + X[i * 3 + 1] = x2; + X[i * 3 + 2] = x3; + + y[i] = label; + } + return { X, y }; +} + +async function main() { + if (!navigator.gpu) { + alert("WebGPU not supported!"); + return; + } + + const adapter = await navigator.gpu.requestAdapter(); + const device = await adapter.requestDevice(); + + const { X, y } = generateTrainingData(); + + // --- 1: Create shaders --- + const forward = new Shader(device); + await forward.setup("../../shaders/neural/forward.wgsl"); + + const lossBackward = new Shader(device); + await lossBackward.setup("../../shaders/neural/loss_backward.wgsl"); + + const weightGrad = new Shader(device); + await weightGrad.setup("../../shaders/neural/weight_grad.wgsl"); + + const sgdUpdate = new Shader(device); + await sgdUpdate.setup("../../shaders/neural/sgd_update.wgsl"); + + // --- 2: Initialize model parameters --- + let W = new Float32Array(inputSize).map(() => Math.random() * 0.1); + let b = new Float32Array([0]); + + forward.setVariable("input", X); + forward.setVariable("W", W); + forward.setVariable("b", b); + forward.setVariable("N", N); + forward.setVariable("y", y); + + + + // link GPU buffers between shaders + lossBackward.setBuffer("y", forward.getBuffer("y")); + lossBackward.setBuffer("pred", forward.getBuffer("pred")); + lossBackward.setVariable("N", N); + + weightGrad.setBuffer("input", 
forward.getBuffer("input")); + weightGrad.setBuffer("dPred", lossBackward.getBuffer("dPred")); + weightGrad.setVariable("N", N); + + sgdUpdate.setBuffer("W", forward.getBuffer("W")); + sgdUpdate.setBuffer("b", forward.getBuffer("b")); + sgdUpdate.setBuffer("dW", weightGrad.getBuffer("dW")); + sgdUpdate.setBuffer("db", weightGrad.getBuffer("db")); + sgdUpdate.setVariable("lr", lr); + + const log = document.getElementById("log"); + + document.getElementById("trainBtn").onclick = async () => { + for (let epoch = 0; epoch < epochs; epoch++) { + await forward.execute(N); + await lossBackward.execute(N); + await weightGrad.execute(N); + await sgdUpdate.execute(inputSize); + + const lossArr = await lossBackward.debugBuffer("loss"); + const loss = lossArr.reduce((a, b) => a + b, 0) / N; + + log.textContent += `Epoch ${epoch}: loss = ${loss.toFixed(4)}\n`; + + // Debug learned weights every 5 epochs + if (epoch % 5 === 0) { + const w = await forward.debugBuffer("W"); + console.log( + `Epoch ${epoch}: W = ${w.map(v => v.toFixed(2)).join(", ")}` + ); + } + } + + const learnedW = await forward.debugBuffer("W"); + const learnedB = await forward.debugBuffer("b"); + + log.textContent += `\nLearned W: ${learnedW.map(v=>v.toFixed(3)).join(", ")}\n`; + log.textContent += `Learned b: ${learnedB[0].toFixed(3)}\n`; + }; +} + +main(); diff --git a/source/training2.js b/source/training2.js new file mode 100644 index 0000000..2a41ebb --- /dev/null +++ b/source/training2.js @@ -0,0 +1,223 @@ +import Shader from "../../framework/WebGpu.js"; +import { loadMNIST } from "./mnist_loader.js"; + +const batchSize = 1000; +const inputDimension = 28 * 28; +const hiddenDimension = 128; +const numberOfClasses = 10; +const learningRateValue = 0.05; +const numberOfEpochs = 20; + +function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) { + const context = canvas.getContext("2d"); + const image = context.createImageData(28, 28); + + for (let i = 0; i < inputDimension; i++) { + const 
grayscale = pixelData[i] * 255; + const pixelIndex = i * 4; + image.data[pixelIndex + 0] = grayscale; + image.data[pixelIndex + 1] = grayscale; + image.data[pixelIndex + 2] = grayscale; + image.data[pixelIndex + 3] = 255; + } + + context.putImageData(image, 0, 0); + canvas.nextLabel.textContent = + `GT: ${trueLabel} → Pred: ${predictedLabel}`; +} + +async function main() { + if (!navigator.gpu) { + alert("WebGPU not supported"); + return; + } + + const adapter = await navigator.gpu.requestAdapter(); + const device = await adapter.requestDevice(); + + // Load MNIST subset → inputImages & targetLabels are GPU buffers + const mnist = await loadMNIST(device, batchSize); + const inputImages = mnist.images; + const targetLabels = mnist.labels; + + const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]); + const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]); + + // + // === Create Shader Instances === + // + const forwardDenseLayer1 = new Shader(device); + await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl"); + + const forwardDenseLayer2 = new Shader(device); + await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl"); + + const softmaxCrossEntropyBackward = new Shader(device); + await softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl"); + + const gradientsDenseLayer2 = new Shader(device); + await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl"); + + const gradientsDenseLayer1 = new Shader(device); + await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl"); + + const sgdOptimizerUpdate = new Shader(device); + await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl"); + + // + // === Model Parameters === + // + const weightLayer1 = new Float32Array(inputDimension * hiddenDimension) + .map(() => (Math.random() * Math.sqrt(2.0 / inputDimension)) - 
(Math.sqrt(2.0 / inputDimension) / 2)); + + + const biasLayer1 = new Float32Array(hiddenDimension); + + const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses) + .map(() => (Math.random() * Math.sqrt(2.0 / hiddenDimension)) - (Math.sqrt(2.0 / hiddenDimension) / 2)); + + + const biasLayer2 = new Float32Array(numberOfClasses); + + // + // === Bind Buffers to Shaders === + // + forwardDenseLayer1.setVariable("inputImages", inputImages); + forwardDenseLayer1.setVariable("weightLayer1", weightLayer1); + forwardDenseLayer1.setVariable("biasLayer1", biasLayer1); + forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + forwardDenseLayer2.setVariable("weightLayer2", weightLayer2); + forwardDenseLayer2.setVariable("biasLayer2", biasLayer2); + forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits")); + softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels); + softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits")); + //gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null); + gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + //gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null); + //gradientsDenseLayer2.setBuffer("gradientHidden", null); + gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer1.setVariable("inputImages", inputImages); + gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden")); + 
//gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null); + //gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null); + gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1")); + sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1")); + + sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2")); + + sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform"); + + // + // === UI === + // + const logOutput = document.getElementById("log"); + const previewContainer = document.getElementById("preview"); + const numberOfPreviewImages = 6; + + const previewIndices = Array.from({ length: numberOfPreviewImages }, + () => Math.floor(Math.random() * batchSize)); + + const previewCanvases = previewIndices.map(() => { + const container = document.createElement("div"); + container.className = "previewItem"; + const canvas = document.createElement("canvas"); + canvas.width = canvas.height = 28; + const label = document.createElement("div"); + canvas.nextLabel = label; + container.appendChild(canvas); + container.appendChild(label); + previewContainer.appendChild(container); + return canvas; + }); + + // + // === Training === + // + document.getElementById("trainBtn").onclick = async () => { + for (let epoch = 0; epoch < numberOfEpochs; 
epoch++) { + + await forwardDenseLayer1.execute(batchSize); + await forwardDenseLayer2.execute(batchSize); + await softmaxCrossEntropyBackward.execute(batchSize); + + const wgSize = 64; + + // Grad Layer 2 + await gradientsDenseLayer2.execute( + Math.ceil((batchSize * numberOfClasses + hiddenDimension * numberOfClasses) / wgSize) + ); + + // Grad Layer 1 + await gradientsDenseLayer1.execute( + Math.ceil((batchSize * hiddenDimension + inputDimension * hiddenDimension) / wgSize) + ); + + // SGD + await sgdOptimizerUpdate.execute( + Math.ceil((weightLayer1.length + biasLayer1.length + + weightLayer2.length + biasLayer2.length) / wgSize) + ); + + // Loss + const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss"); + const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize; + logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`; + + // Accuracy + const logits = await forwardDenseLayer2.debugBuffer("outputLogits"); + let correctCount = 0; + + for (let i = 0; i < batchSize; i++) { + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[i*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + if (bestClass === targetLabels[i]) correctCount++; + } + + const accuracy = correctCount / batchSize; + logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`; + + // Update preview images + for (let k = 0; k < numberOfPreviewImages; k++) { + const index = previewIndices[k]; + const image = inputImages.subarray( + index*inputDimension, (index+1)*inputDimension + ); + + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[index*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + + drawMNISTPreview( + previewCanvases[k], + image, + targetLabels[index], + bestClass + ); + } + } + }; +} + +main(); diff --git 
a/source/training2_working.js b/source/training2_working.js new file mode 100644 index 0000000..fd4e71d --- /dev/null +++ b/source/training2_working.js @@ -0,0 +1,212 @@ +import Shader from "../../framework/WebGpu.js"; +import { loadMNIST } from "./mnist_loader.js"; + +const batchSize = 1000; +const inputDimension = 28 * 28; +const hiddenDimension = 128; +const numberOfClasses = 10; +const learningRateValue = 0.05; +const numberOfEpochs = 20; + +function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) { + const context = canvas.getContext("2d"); + const image = context.createImageData(28, 28); + + for (let i = 0; i < inputDimension; i++) { + const grayscale = pixelData[i] * 255; + const pixelIndex = i * 4; + image.data[pixelIndex + 0] = grayscale; + image.data[pixelIndex + 1] = grayscale; + image.data[pixelIndex + 2] = grayscale; + image.data[pixelIndex + 3] = 255; + } + + context.putImageData(image, 0, 0); + canvas.nextLabel.textContent = + `GT: ${trueLabel} → Pred: ${predictedLabel}`; +} + +async function main() { + if (!navigator.gpu) { + alert("WebGPU not supported"); + return; + } + + const adapter = await navigator.gpu.requestAdapter(); + const device = await adapter.requestDevice(); + + // Load MNIST subset → inputImages & targetLabels are GPU buffers + const mnist = await loadMNIST(device, batchSize); + const inputImages = mnist.images; + const targetLabels = mnist.labels; + + const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]); + const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]); + + // + // === Create Shader Instances === + // + const forwardDenseLayer1 = new Shader(device); + await forwardDenseLayer1.setup("../../shaders/neural/forward_dense_layer1.wgsl"); + + const forwardDenseLayer2 = new Shader(device); + await forwardDenseLayer2.setup("../../shaders/neural/forward_dense_layer2.wgsl"); + + const softmaxCrossEntropyBackward = new Shader(device); + await 
softmaxCrossEntropyBackward.setup("../../shaders/neural/softmax_cross_entropy_backward.wgsl"); + + const gradientsDenseLayer2 = new Shader(device); + await gradientsDenseLayer2.setup("../../shaders/neural/gradients_dense_layer2.wgsl"); + + const gradientsDenseLayer1 = new Shader(device); + await gradientsDenseLayer1.setup("../../shaders/neural/gradients_dense_layer1.wgsl"); + + const sgdOptimizerUpdate = new Shader(device); + await sgdOptimizerUpdate.setup("../../shaders/neural/sgd_optimizer_update.wgsl"); + + // + // === Model Parameters === + // + const weightLayer1 = new Float32Array(inputDimension * hiddenDimension).map(() => (Math.random() - 0.5) * 0.05); + const biasLayer1 = new Float32Array(hiddenDimension); + + const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses).map(() => (Math.random() - 0.5) * 0.05); + const biasLayer2 = new Float32Array(numberOfClasses); + + // + // === Bind Buffers to Shaders === + // + forwardDenseLayer1.setVariable("inputImages", inputImages); + forwardDenseLayer1.setVariable("weightLayer1", weightLayer1); + forwardDenseLayer1.setVariable("biasLayer1", biasLayer1); + forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + forwardDenseLayer2.setVariable("weightLayer2", weightLayer2); + forwardDenseLayer2.setVariable("biasLayer2", biasLayer2); + forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits")); + softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels); + softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits")); + 
//gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null); + gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + //gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null); + //gradientsDenseLayer2.setBuffer("gradientHidden", null); + gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer1.setVariable("inputImages", inputImages); + gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden")); + //gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null); + //gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null); + gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1")); + sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1")); + + sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", gradientsDenseLayer2.getBuffer("gradientBiasLayer2")); + + sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform"); + + // + // === UI === + // + const logOutput = document.getElementById("log"); + const previewContainer = document.getElementById("preview"); + const numberOfPreviewImages = 6; + + const previewIndices = Array.from({ length: numberOfPreviewImages }, + () => Math.floor(Math.random() * batchSize)); + + const previewCanvases = previewIndices.map(() => { 
+ const container = document.createElement("div"); + container.className = "previewItem"; + const canvas = document.createElement("canvas"); + canvas.width = canvas.height = 28; + const label = document.createElement("div"); + canvas.nextLabel = label; + container.appendChild(canvas); + container.appendChild(label); + previewContainer.appendChild(container); + return canvas; + }); + + // + // === Training === + // + document.getElementById("trainBtn").onclick = async () => { + for (let epoch = 0; epoch < numberOfEpochs; epoch++) { + + await forwardDenseLayer1.execute(batchSize); + await forwardDenseLayer2.execute(batchSize); + await softmaxCrossEntropyBackward.execute(batchSize); + + await gradientsDenseLayer2.execute( + hiddenDimension * numberOfClasses + numberOfClasses + batchSize * hiddenDimension + ); + + await gradientsDenseLayer1.execute( + inputDimension * hiddenDimension + hiddenDimension + ); + + await sgdOptimizerUpdate.execute( + weightLayer1.length + biasLayer1.length + + weightLayer2.length + biasLayer2.length + ); + + // Loss + const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss"); + const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize; + logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`; + + // Accuracy + const logits = await forwardDenseLayer2.debugBuffer("outputLogits"); + let correctCount = 0; + + for (let i = 0; i < batchSize; i++) { + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[i*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + if (bestClass === targetLabels[i]) correctCount++; + } + + const accuracy = correctCount / batchSize; + logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`; + + // Update preview images + for (let k = 0; k < numberOfPreviewImages; k++) { + const index = previewIndices[k]; + const image = inputImages.subarray( + 
index*inputDimension, (index+1)*inputDimension + ); + + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[index*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + + drawMNISTPreview( + previewCanvases[k], + image, + targetLabels[index], + bestClass + ); + } + } + }; +} + +main(); diff --git a/training.js b/training.js new file mode 100644 index 0000000..89860ee --- /dev/null +++ b/training.js @@ -0,0 +1,223 @@ +import Shader from "./framework/WebGpu.js"; +import { loadMNIST } from "./mnist_loader.js"; + +const batchSize = 1000; +const inputDimension = 28 * 28; +const hiddenDimension = 128; +const numberOfClasses = 10; +const learningRateValue = 0.05; +const numberOfEpochs = 20; + +function drawMNISTPreview(canvas, pixelData, trueLabel, predictedLabel) { + const context = canvas.getContext("2d"); + const image = context.createImageData(28, 28); + + for (let i = 0; i < inputDimension; i++) { + const grayscale = pixelData[i] * 255; + const pixelIndex = i * 4; + image.data[pixelIndex + 0] = grayscale; + image.data[pixelIndex + 1] = grayscale; + image.data[pixelIndex + 2] = grayscale; + image.data[pixelIndex + 3] = 255; + } + + context.putImageData(image, 0, 0); + canvas.nextLabel.textContent = + `GT: ${trueLabel} → Pred: ${predictedLabel}`; +} + +async function main() { + if (!navigator.gpu) { + alert("WebGPU not supported"); + return; + } + + const adapter = await navigator.gpu.requestAdapter(); + const device = await adapter.requestDevice(); + + // Load MNIST subset → inputImages & targetLabels are GPU buffers + const mnist = await loadMNIST(device, batchSize); + const inputImages = mnist.images; + const targetLabels = mnist.labels; + + const layer1Sizes = new Uint32Array([inputDimension, hiddenDimension, batchSize]); + const layer2Sizes = new Uint32Array([hiddenDimension, numberOfClasses, batchSize]); + + // + // === Create Shader Instances === + // + const 
forwardDenseLayer1 = new Shader(device); + await forwardDenseLayer1.setup("./shaders/neural/forward_dense_layer1.wgsl"); + + const forwardDenseLayer2 = new Shader(device); + await forwardDenseLayer2.setup("./shaders/neural/forward_dense_layer2.wgsl"); + + const softmaxCrossEntropyBackward = new Shader(device); + await softmaxCrossEntropyBackward.setup("./shaders/neural/softmax_cross_entropy_backward.wgsl"); + + const gradientsDenseLayer2 = new Shader(device); + await gradientsDenseLayer2.setup("./shaders/neural/gradients_dense_layer2.wgsl"); + + const gradientsDenseLayer1 = new Shader(device); + await gradientsDenseLayer1.setup("./shaders/neural/gradients_dense_layer1.wgsl"); + + const sgdOptimizerUpdate = new Shader(device); + await sgdOptimizerUpdate.setup("./shaders/neural/sgd_optimizer_update.wgsl"); + + // + // === Model Parameters === + // + const weightLayer1 = new Float32Array(inputDimension * hiddenDimension) + .map(() => (Math.random() * Math.sqrt(2.0 / inputDimension)) - (Math.sqrt(2.0 / inputDimension) / 2)); + + + const biasLayer1 = new Float32Array(hiddenDimension); + + const weightLayer2 = new Float32Array(hiddenDimension * numberOfClasses) + .map(() => (Math.random() * Math.sqrt(2.0 / hiddenDimension)) - (Math.sqrt(2.0 / hiddenDimension) / 2)); + + + const biasLayer2 = new Float32Array(numberOfClasses); + + // + // === Bind Buffers to Shaders === + // + forwardDenseLayer1.setVariable("inputImages", inputImages); + forwardDenseLayer1.setVariable("weightLayer1", weightLayer1); + forwardDenseLayer1.setVariable("biasLayer1", biasLayer1); + forwardDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + forwardDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + forwardDenseLayer2.setVariable("weightLayer2", weightLayer2); + forwardDenseLayer2.setVariable("biasLayer2", biasLayer2); + forwardDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + 
softmaxCrossEntropyBackward.setBuffer("outputLogits", forwardDenseLayer2.getBuffer("outputLogits")); + softmaxCrossEntropyBackward.setVariable("targetLabels", targetLabels); + softmaxCrossEntropyBackward.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer2.setBuffer("hiddenActivations", forwardDenseLayer1.getBuffer("hiddenActivations")); + gradientsDenseLayer2.setBuffer("gradientLogits", softmaxCrossEntropyBackward.getBuffer("gradientLogits")); + //gradientsDenseLayer2.setBuffer("gradientWeightLayer2", null); + gradientsDenseLayer2.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + //gradientsDenseLayer2.setBuffer("gradientBiasLayer2", null); + //gradientsDenseLayer2.setBuffer("gradientHidden", null); + gradientsDenseLayer2.setVariable("layer2Sizes", layer2Sizes, "uniform"); + + gradientsDenseLayer1.setVariable("inputImages", inputImages); + gradientsDenseLayer1.setBuffer("gradientHidden", gradientsDenseLayer2.getBuffer("gradientHidden")); + //gradientsDenseLayer1.setBuffer("gradientWeightLayer1", null); + //gradientsDenseLayer1.setBuffer("gradientBiasLayer1", null); + gradientsDenseLayer1.setVariable("layer1Sizes", layer1Sizes, "uniform"); + + sgdOptimizerUpdate.setBuffer("weightLayer1", forwardDenseLayer1.getBuffer("weightLayer1")); + sgdOptimizerUpdate.setBuffer("biasLayer1", forwardDenseLayer1.getBuffer("biasLayer1")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer1", gradientsDenseLayer1.getBuffer("gradientWeightLayer1")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer1", gradientsDenseLayer1.getBuffer("gradientBiasLayer1")); + + sgdOptimizerUpdate.setBuffer("weightLayer2", forwardDenseLayer2.getBuffer("weightLayer2")); + sgdOptimizerUpdate.setBuffer("biasLayer2", forwardDenseLayer2.getBuffer("biasLayer2")); + sgdOptimizerUpdate.setBuffer("gradientWeightLayer2", gradientsDenseLayer2.getBuffer("gradientWeightLayer2")); + sgdOptimizerUpdate.setBuffer("gradientBiasLayer2", 
gradientsDenseLayer2.getBuffer("gradientBiasLayer2")); + + sgdOptimizerUpdate.setVariable("learningRate", new Float32Array([learningRateValue]), "uniform"); + + // + // === UI === + // + const logOutput = document.getElementById("log"); + const previewContainer = document.getElementById("preview"); + const numberOfPreviewImages = 6; + + const previewIndices = Array.from({ length: numberOfPreviewImages }, + () => Math.floor(Math.random() * batchSize)); + + const previewCanvases = previewIndices.map(() => { + const container = document.createElement("div"); + container.className = "previewItem"; + const canvas = document.createElement("canvas"); + canvas.width = canvas.height = 28; + const label = document.createElement("div"); + canvas.nextLabel = label; + container.appendChild(canvas); + container.appendChild(label); + previewContainer.appendChild(container); + return canvas; + }); + + // + // === Training === + // + document.getElementById("trainBtn").onclick = async () => { + for (let epoch = 0; epoch < numberOfEpochs; epoch++) { + + await forwardDenseLayer1.execute(batchSize); + await forwardDenseLayer2.execute(batchSize); + await softmaxCrossEntropyBackward.execute(batchSize); + + const wgSize = 64; + + // Grad Layer 2 + await gradientsDenseLayer2.execute( + Math.ceil((batchSize * numberOfClasses + hiddenDimension * numberOfClasses) / wgSize) + ); + + // Grad Layer 1 + await gradientsDenseLayer1.execute( + Math.ceil((batchSize * hiddenDimension + inputDimension * hiddenDimension) / wgSize) + ); + + // SGD + await sgdOptimizerUpdate.execute( + Math.ceil((weightLayer1.length + biasLayer1.length + + weightLayer2.length + biasLayer2.length) / wgSize) + ); + + // Loss + const lossValues = await softmaxCrossEntropyBackward.debugBuffer("crossEntropyLoss"); + const averageLoss = lossValues.reduce((a,b)=>a+b,0) / batchSize; + logOutput.textContent += `Epoch ${epoch}: Loss = ${averageLoss.toFixed(4)}\n`; + + // Accuracy + const logits = await 
forwardDenseLayer2.debugBuffer("outputLogits"); + let correctCount = 0; + + for (let i = 0; i < batchSize; i++) { + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[i*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + if (bestClass === targetLabels[i]) correctCount++; + } + + const accuracy = correctCount / batchSize; + logOutput.textContent += `Accuracy: ${(accuracy*100).toFixed(2)}%\n`; + + // Update preview images + for (let k = 0; k < numberOfPreviewImages; k++) { + const index = previewIndices[k]; + const image = inputImages.subarray( + index*inputDimension, (index+1)*inputDimension + ); + + let bestValue = -1e9; + let bestClass = 0; + for (let c = 0; c < numberOfClasses; c++) { + const value = logits[index*numberOfClasses+c]; + if (value > bestValue) { bestValue = value; bestClass = c; } + } + + drawMNISTPreview( + previewCanvases[k], + image, + targetLabels[index], + bestClass + ); + } + } + }; +} + +main();