Files
WebGPU-Neural-Network-Trainer/framework/WebGpu_node.js

1961 lines
38 KiB
JavaScript
Raw Normal View History

2025-11-19 10:42:46 +01:00
import { readFileSync } from "node:fs";
import gpu from '@kmamal/gpu'
// In a browser (`document` exists) keep global registries of all Shader
// instances and of buffer hand-offs between shaders, for inspection from the
// dev-tools console. In Node or a worker there is no `document`: skip them.
if ( typeof document !== "undefined" ) {
  document.shaders = [];
  document.bufferMap = {};
}
// Bring the WebGPU flag objects and classes exposed by @kmamal/gpu into
// module scope under their standard browser names.
const {
  GPUShaderStage,
  GPUTextureUsage,
  GPUBufferUsage,
  GPUColorWrite,
  GPUTexture,
  GPUSampler,
  GPUMapMode,
} = gpu;
export default class Shader {
// Samples per pixel for render pipelines; 1 disables multisampling.
sampleCount = 1;
// GPUDevice every resource of this shader is created on.
device;
// Path of the WGSL source file; also used as a debug label.
path;
// WGSL source text loaded from `path`.
wgslSource;
// Single-layout fallback returned by getBindingGroupLayout(); not assigned in this class.
bindGroupLayout;
// group index -> GPUBindGroupLayout.
bindGroupLayouts = new Map();
// name -> vertex buffer layout, filled by _parseVertexAttributes().
vertexBuffersDescriptors = new Map();
// binding name -> GPUTextureView (declared once; it was previously declared twice).
textures = new Map();
// Compute or render pipeline created by addStage().
pipeline;
// group index -> GPUBindGroup.
bindGroups = new Map();
// Entry point of the most recently added stage.
entryPoint = "main";
// Parsed resource declarations: { group, binding, type, varName, varType }.
bindings = [];
// binding name -> GPUBuffer.
buffers = new Map();
// Returned by getAllAttributeBuffers(); not assigned in this class.
attributeBuffers ;
// binding name -> usage flags, so a resized buffer is recreated with the same usage.
bufferUsages = new Map();
// Layout given via setBindingGroupLayout(), used as group 0 instead of a parsed one.
_externalBindGroupLayout = null;
// Not read anywhere in this class.
hasLoaded = false;
// Stage flags that decide bind group layout visibility.
stage = GPUShaderStage.COMPUTE;
// True without a DOM (Node or worker); disables the document registries.
isInWorker = typeof document === "undefined";
// Primitive topology for the render pipeline; set it before addStage().
topologyValue = "triangle-strip";
// Set once bindings, layouts, buffers and bind groups have been created.
binded = false;
// binding name -> GPUSampler.
samplers = new Map();
constructor( device, path = false ) {
if( !this.isInWorker ) {
document.shaders.push( this );
}
console.log("\n\n \n Creating New Shader", path, "\n\n");
console.log("device", typeof device);
this.device = device;
if( path ) {
this.path = path;
}
}
set topology(value) {
if (this.binded) {
console.error("Define topology before setup", this);
}
this.topologyValue = value;
}
get topology() {
return this.topologyValue;
}
_createAllBuffers() {
console.log("_createAllBuffers", this.bindings, this.buffers);
for ( const b of this.bindings ) {
//console.log("has buffer", b.varName, this.buffers.has( b.varName ));
if ( this.buffers.has( b.varName ) ) continue;
let size = 240000 * 3 * 4;
if ( b.type === "uniform" ) {
size = 4 * 4;
}
let usage;
if ( b.type === "uniform" ) {
usage = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST" );
} else if ( b.type === "read-only-storage" ) {
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST;
console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST" );
} else {
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
console.log( "buffer ", b.varName ," usage flags GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC" );
}
const buffer = this.device.createBuffer( {
size,
usage,
label: b.varName
} );
if( !this.isInWorker ) {
//document.bufferOwners.set( buffer, this );
}
this.buffers.set( b.varName, buffer );
this.bufferUsages.set( b.varName, usage );
//console.log( `Created buffer for '${b.varName}' size=${size} usage=0x${usage.toString(16)}` );
}
}
_createDefaultTexturesAndSamplers() {
for ( const b of this.bindings ) {
if ( b.type === "texture" && !this.textures.has( b.varName ) ) {
// Detect if the type is a texture array
const isTextureArray = b.varType.startsWith("texture_2d_array");
// Create 1x1 texture with 1 layer (depthOrArrayLayers = 1)
const defaultTexture = this.device.createTexture({
size: isTextureArray ? [1, 1, 2] : [1, 1, 1],
format: "rgba8unorm",
usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST,
label: b.varName + "_defaultTexture"
});
console.log("create texture ", isTextureArray, isTextureArray ? "2d-array" : "2d");
// Create texture view with correct dimension
const defaultTextureView = defaultTexture.createView({
dimension: isTextureArray ? "2d-array" : "2d"
});
console.log("create default texture?", b, defaultTextureView, isTextureArray ? "2d-array" : "2d");
this.textures.set(b.varName, defaultTextureView);
} else if ( b.type === "sampler" && !this.samplers.has( b.varName ) ) {
// Create default sampler
const defaultSampler = this.device.createSampler( {
magFilter: "linear",
minFilter: "linear",
label: b.varName + "_defaultSampler"
} );
this.samplers.set( b.varName, defaultSampler );
}
}
}
createBindGroups() {
this.bindGroups = new Map();
for ( const [ group, layout ] of this.bindGroupLayouts.entries() ) {
const bindings = this.bindings.filter( b => b.group === group );
const entries = new Array( bindings.length );
for ( let i = 0; i < bindings.length; i++ ) {
const b = bindings[ i ];
let resource;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
const gpuBuffer = this.buffers.get( b.varName );
if ( !gpuBuffer ) {
throw new Error( `Buffer missing for ${b.varName} when creating bind group.` );
}
resource = { buffer: gpuBuffer };
} else if ( b.type === "texture" ) {
const gpuTextureView = this.textures.get( b.varName );
console.log("this.textures", this.textures);
if ( !gpuTextureView ) {
console.warn( `Texture missing for ${b.varName} when creating bind group.` );
} else {
console.log( `Binding texture ${b.varName} with dimension:`, gpuTextureView.dimension );
}
resource = gpuTextureView;
} else if ( b.type === "sampler" ) {
const gpuSampler = this.samplers.get( b.varName );
if ( !gpuSampler ) {
//throw new Error( `Sampler missing for ${b.varName} when creating bind group.` );
}
resource = gpuSampler;
} else {
throw new Error( `Unknown binding type '${b.type}' for ${b.varName}` );
}
entries[ i ] = {
binding: b.binding,
resource: resource
};
}
const bindGroup = this.device.createBindGroup( {
layout: layout,
entries: entries
} );
this.bindGroups.set( group, bindGroup );
}
}
extractEntryPoints( wgslCode ) {
const entryPoints = [];
// Regex to match lines like: @vertex fn functionName(...) { ... }
const regex = /@(\w+)(?:\s+@\w+(?:\([^)]*\))?)*\s+fn\s+(\w+)\s*\(/g;
let match;
while ((match = regex.exec(wgslCode)) !== null) {
const stage = match[1];
const name = match[2];
let type;
switch (stage) {
case "vertex":
type = GPUShaderStage.VERTEX;
break;
case "fragment":
type = GPUShaderStage.FRAGMENT;
break;
case "compute":
type = GPUShaderStage.COMPUTE;
break;
default:
type = "Unknown";
}
entryPoints.push({ name: name, type: type });
}
return entryPoints;
}
async setup( path ) {
this.path = path;
this.wgslSource = await this._loadShaderSource( this.path );
//console.log(this.wgslSource);
var entryPoints = this.extractEntryPoints( this.wgslSource );
for (var i = 0; i < entryPoints.length; i++) {
var entryPoint = entryPoints[i];
await this.addStage( entryPoint.name, entryPoint.type );
}
//await this.addStage("computeMain", 4 );
//console.log( "load shader", this.wgslSource, entryPoints )
}
async _loadShaderSource( pathName ) {
//const response = await fetch( pathName );
if ( typeof window === 'undefined' ) {
const vertexShaderCode = await readFileSync( pathName, 'utf8' )
return vertexShaderCode;
} else {
const vertexShaderFile = path.join(__dirname, 'vertex.wgsl')
if ( !response.ok ) throw new Error( `Failed to load shader: ${ pathName }` );
return await response.text();
}
}
setVertexBuffer( name, dataArray ) {
const buffer = this.device.createBuffer({
size: dataArray.byteLength,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation: true
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( dataArray );
buffer.unmap();
this.buffers.set( name, buffer );
}
extractVertexFunctionParameters(wgsl) {
let vertexFnIndex = wgsl.indexOf('@vertex');
if (vertexFnIndex === -1) return '';
let fnIndex = wgsl.indexOf('fn', vertexFnIndex);
if (fnIndex === -1) return '';
let parenStart = wgsl.indexOf('(', fnIndex);
if (parenStart === -1) return '';
let parenCount = 1;
let i = parenStart + 1;
while (i < wgsl.length && parenCount > 0) {
const c = wgsl[i];
if (c === '(') parenCount++;
else if (c === ')') parenCount--;
i++;
}
if (parenCount !== 0) {
// unbalanced parentheses
return '';
}
// Extract parameter substring between parenStart+1 and i-1
let paramsStr = wgsl.substring(parenStart + 1, i - 1);
return paramsStr.trim();
}
extractVertexParamsLocations(wgsl) {
const paramsStr = this.extractVertexFunctionParameters(wgsl)
const paramRegex = /@location\s*\(\s*(\d+)\s*\)\s+(\w+)\s*:\s*([\w<>\d]+(\s*<[^>]+>)?)/g;
const results = new Array();
let paramMatch;
while ((paramMatch = paramRegex.exec(paramsStr)) !== null) {
const location = Number(paramMatch[1]);
const varName = paramMatch[2];
const type = paramMatch[3].trim();
results.push({ location, varName, type });
}
return results;
}
getFormatSize( format ) {
switch ( format ) {
case "float32": return 4;
case "float32x2": return 8;
case "float32x3": return 12;
case "float32x4": return 16;
case "uint32": return 4;
case "sint32": return 4;
default: throw new Error("Unsupported format size for: " + format);
}
}
mapTypeToVertexFormat( type ) {
// Map WGSL types to GPUVertexFormat strings (example subset)
switch ( type ) {
case "f32": return "float32";
case "vec2<f32>": return "float32x2";
case "vec3<f32>": return "float32x3";
case "vec4<f32>": return "float32x4";
case "u32": return "uint32";
case "i32": return "sint32";
default: throw new Error("Unsupported vertex attribute type: " + type);
}
}
_parseVertexAttributes( wgsl ) {
const locations = this.extractVertexParamsLocations( wgsl );
let offset = 0;
const attributes = new Array();
for ( let attr of locations ) {
const format = this.mapTypeToVertexFormat( attr.type );
attributes.push({
shaderLocation : attr.location,
offset : offset,
format : format
});
offset += this.getFormatSize( format );
}
const vertexBufferDescriptor = {
arrayStride: offset,
attributes: attributes,
stepMode: "vertex"
};
console.log("vertexBufferDescriptor", vertexBufferDescriptor);
this.vertexBuffersDescriptors.set("vertexBuffer0", vertexBufferDescriptor);
console.log("vertex buffer descriptors updated:", this.vertexBuffersDescriptors);
}
_parseBindings( wgsl ) {
console.log("parse bindings");
this.bindings = [];
// Make the `<...>` part optional by wrapping it in a non-capturing group with `?`
const regex = /@group\((\d+)\)\s+@binding\((\d+)\)\s+var(?:<(\w+)(?:,\s*(\w+))?>)?\s+(\w+)\s*:\s*([^;]+);/g;
let match;
while ((match = regex.exec(wgsl)) !== null) {
const group = Number(match[1]);
const binding = Number(match[2]);
const storageClass = match[3]; // e.g. "uniform", "storage", or undefined for textures/samplers
const access = match[4]; // e.g. "read", "write", or undefined
const varName = match[5];
const varType = match[6]; // WGSL type string like "mat4x4<f32>", "texture_2d<f32>", "sampler", etc.
let type;
if (storageClass === "storage") {
type = access === "read" ? "read-only-storage" : "storage";
} else if (storageClass === "uniform") {
type = "uniform";
} else if (varType.startsWith("texture_")) {
type = "texture";
} else if (varType === "sampler" || varType === "sampler_comparison") {
type = "sampler";
} else {
type = "uniform"; // fallback
}
this.bindings.push({ group, binding, type, varName, varType });
}
console.log("bindings", this.bindings);
this.bindings.sort((a, b) => a.binding - b.binding);
}
_createBindGroupLayoutsFromBindings() {
const groups = new Map();
for ( let i = 0; i < this.bindings.length; i++ ) {
const b = this.bindings[ i ];
if ( !groups.has( b.group ) ) groups.set( b.group, [] );
groups.get( b.group ).push( b );
}
const layouts = new Map();
for ( const groupEntry of groups.entries() ) {
const group = groupEntry[ 0 ];
const bindings = groupEntry[ 1 ];
let visibility;
if ( this.stage === GPUShaderStage.COMPUTE ) {
visibility = GPUShaderStage.COMPUTE;
} else {
visibility = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
const entries = new Array( bindings.length );
for ( let j = 0; j < bindings.length; j++ ) {
const b = bindings[ j ];
let entry;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: b.type }
};
} else if ( b.type === "texture" ) {
let viewDimension = "2d"; // default
if ( b.varType.startsWith("texture_2d_array") ) {
viewDimension = "2d-array";
} else if ( b.varType.startsWith("texture_cube") ) {
viewDimension = "cube";
} else if ( b.varType.startsWith("texture_3d") ) {
viewDimension = "3d";
}
console.log("viewDimension", viewDimension);
entry = {
binding: b.binding,
visibility: visibility,
texture: {
sampleType: "float",
viewDimension: viewDimension
}
};
} else if ( b.type === "sampler" ) {
entry = {
binding: b.binding,
visibility: visibility,
sampler: { type: "filtering" }
};
} else {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: "uniform" }
};
}
entries[ j ] = entry;
}
console.log("createBindGroupLayout ", entries);
const layout = this.device.createBindGroupLayout( { entries } );
layouts.set( group, layout );
}
return layouts;
}
async _createPipeline() {
const layoutsArray = this.bindGroupLayouts && this.bindGroupLayouts.size > 0
? Array.from( this.bindGroupLayouts.values() )
: [ this.bindGroupLayout ];
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path,
} );
const info = await shaderModule.compilationInfo();
if (info.messages.length > 0) {
for (const msg of info.messages) {
console[msg.type === "error" ? "error" : "warn"](
`[${msg.lineNum}:${msg.linePos}] ${msg.message}`
);
}
}
console.log("this.bindGroupLayouts.values()", this.bindGroupLayouts.values());
this.pipeline = this.device.createComputePipeline({
layout: this.device.createPipelineLayout({
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
}),
compute: {
module: shaderModule,
entryPoint: this.entryPoint,
},
} );
}
setBindingGroupLayout( bindGroupLayout ) {
this._externalBindGroupLayout = bindGroupLayout;
}
getBindingGroupLayout( bindingGroupIndex = 0 ) {
if ( this.bindGroupLayouts ) {
return this.bindGroupLayouts.get( bindingGroupIndex );
}
return this.bindGroupLayout;
}
getTypedArrayByVariableName( name, flatData ) {
let typedArray;
const bindingInfo = this.bindings.find( b => b.varName === name );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
switch( wgslType ) {
case "vec3<f32>":
// flatData.push( 0 );
break;
}
switch ( wgslType ) {
case "f32":
case "vec2<f32>":
case "vec3<f32>":
case "vec4<f32>":
case "array<f32>":
typedArray = new Float32Array( flatData );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
case "array<u32>":
typedArray = new Uint32Array( flatData );
break;
case "i32":
case "vec2<i32>":
case "vec3<i32>":
case "vec4<i32>":
case "array<i32>":
typedArray = new Int32Array( flatData );
break;
default:
// find struct and then the datatype
typedArray = new Float32Array( flatData );
//console.error( "Unknown WGSL type in setVariable: " + wgslType );
}
//console.log("wgslType", wgslType);
}
return typedArray;
}
setVariable( name, dataObject, customBufferUsage ) {
if ( dataObject instanceof GPUTexture ) {
// Handle texture: create and store texture view
const textureView = dataObject.createView();
this.textures.set( name, textureView );
// Recreate bind groups since texture changed
this.createBindGroups();
return;
}
if ( dataObject instanceof GPUSampler ) {
// Handle sampler
this.samplers.set( name, dataObject );
// Recreate bind groups since sampler changed
this.createBindGroups();
return;
}
const flatData = this.flattenStructure( dataObject );
//console.log( "flattenStructure", name, flatData );
const sizeBytes = flatData.length * 4;
let buffer = this.buffers.get( name );
if ( !buffer ) {
throw new Error( "Buffer not found for variable: " + name + ". Buffers must be created in setup. >>" + this.path );
}
if ( sizeBytes > buffer.size ) {
console.log( `Resizing buffer '${name}' from ${buffer.size} to ${sizeBytes} bytes.` );
var usage = this.bufferUsages.get( name ) || ( GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST );
if ( customBufferUsage ) {
usage = customBufferUsage;
}
const newBuffer = this.device.createBuffer( {
size : sizeBytes,
usage : usage,
label : name
} );
if( !this.isInWorker ) {
//document.bufferOwners.set( newBuffer, this );
}
this.buffers.set( name, newBuffer );
buffer = newBuffer;
// Recreate bind groups to bind new buffer
//console.log( "setVariable()" );
this.createBindGroups();
}
let typedArray;
if ( typeof dataObject === "number" ) {
typedArray = this.getTypedArrayByVariableName( name, flatData );
} else if( Array.isArray( dataObject ) ) {
typedArray = this.getTypedArrayByVariableName( name, flatData );
} else if ( dataObject instanceof Uint32Array ) {
typedArray = dataObject;
} else if ( typeof dataObject.numTokens === "number" && Number.isInteger( dataObject.numTokens ) ) {
typedArray = new Uint32Array( [ dataObject.numTokens ] );
} else if ( dataObject instanceof Int32Array || dataObject instanceof Float32Array ) {
typedArray = dataObject;
} else {
typedArray = new Float32Array( flatData );
}
if( name == "instancePositions") {
console.log("instancePositions", flatData, dataObject, typedArray.byteOffset, typedArray.byteLength, typedArray.buffer);
}
this.device.queue.writeBuffer(
buffer,
0,
typedArray.buffer,
typedArray.byteOffset,
typedArray.byteLength
);
//console.log( `Buffer data updated: ${name}, byteLength=${typedArray.byteLength}` );
}
flattenStructure( input, output = [] ) {
if ( Array.isArray( input ) ) {
for ( const element of input ) {
this.flattenStructure( element, output );
}
} else if ( typeof input === "object" && input !== null ) {
for ( const key of Object.keys( input ) ) {
this.flattenStructure( input[ key ], output );
}
} else if ( typeof input === "number" ) {
output.push( input );
} else if ( typeof input === "boolean" ) {
output.push( input ? 1 : 0 );
} else {
throw new Error( "Unsupported data type in shader variable: " + typeof input );
}
return output;
}
copyBuffersFromShader( sourceShader ) {
for ( const [ name, buffer ] of sourceShader.buffers.entries() ) {
this.buffers.set( name, buffer );
}
}
updateUniforms( data ) {
const arrayBuffer = new ArrayBuffer(12); // 3 x u32 = 3 * 4 bytes
const dataView = new DataView(arrayBuffer);
dataView.setUint32(0, data.k, true); // offset 0
dataView.setUint32(4, data.j, true); // offset 4
dataView.setUint32(8, data.totalCount, true); // offset 8
this.device.queue.writeBuffer(
this.uniformBuffer, // your GPUBuffer holding uniforms
0, // offset in buffer
arrayBuffer
);
}
createCommandEncoder( x, y, z ) {
const commandEncoder = this.device.createCommandEncoder({
label: this.path
});
const passEncoder = commandEncoder.beginComputePass({
label: this.path
});
if( !this.pipeline ) {
console.log("Pipeline missing, Maybe you forget to call addStage( ) ", this);
}
passEncoder.setPipeline( this.pipeline );
for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
passEncoder.dispatchWorkgroups( x, y, z );
passEncoder.end();
return commandEncoder;
}
async execute( x = 1, y = 1, z = 1 ) {
const commandEncoder = this.createCommandEncoder( x, y, z );
const commandBuffer = commandEncoder.finish();
this.device.queue.submit( [ commandBuffer ] );
await this.device.queue.onSubmittedWorkDone(); // critical!
}
registerBuffer( buffer, name, type ) {
if( !this.isInWorker ) {
if( !document.bufferMap[name] ) {
document.bufferMap[name] = new Array();
}
var bufferHistoryPoint = { buffer: buffer, shader: this, label : buffer.label, type: type };
document.bufferMap[name].push( bufferHistoryPoint );
}
}
getBuffer( name ) {
var buffer = this.buffers.get( name );
this.registerBuffer( buffer, name, "get" );
return buffer;
}
setBuffer( name, buffer ) {
this.registerBuffer( buffer, name, "set" );
if( !this.isInWorker ) {
//var bufferOwner = document.bufferOwners.get( buffer );
//console.log("setBuffer Origin:", bufferOwner, " this shader", this);
}
const bindingInfo = this.bindings.find( b => b.varName === name );
if( !bindingInfo ) {
console.error("No binding to buffer ", name, " in shader ", this );
}
this.buffers.set(name, buffer);
this.createBindGroups(); // always refresh bindings
}
/**
 * Register one shader entry point and build the pipeline once the required
 * stages are present.
 *
 * The first call (this.binded false) also performs one-time setup:
 * - For a non-compute stage, creates depth (and, when sampleCount != 1,
 *   multisampled color/depth) textures sized from this.canvas.
 * - Loads the WGSL source.
 * - Parses vertex attributes and bindings.
 * - Builds the bind group layouts. If setBindingGroupLayout() was used, its
 *   layout becomes group 0 and no other groups are generated.
 * - Creates buffers, placeholder textures/samplers and bind groups.
 *
 * Per stage:
 * - COMPUTE: creates the compute pipeline right away, then rebuilds the
 *   bind groups.
 * - VERTEX / FRAGMENT: stores the stage. Once both exist it creates the
 *   render pipeline (topology this.topologyValue, depth24plus "less" depth
 *   test) and rebuilds the bind groups.
 *
 * @param {string} entryPoint - WGSL function name.
 * @param {number} stage - GPUShaderStage flag of that function.
 * @throws {Error} if this.path is not set.
 */
async addStage( entryPoint, stage ) {
if ( !this.path ) {
throw new Error("Shader path is not set");
}
this.entryPoint = entryPoint;
// Any non-compute stage makes the (not yet built) layouts VERTEX | FRAGMENT visible.
if( stage != GPUShaderStage.COMPUTE ) {
this.stage = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
// One-time setup on the first call for this shader.
if ( !this.binded ) {
// Render attachments are sized from this.canvas: setCanvas() must come first.
if( stage != GPUShaderStage.COMPUTE ) {
if( this.sampleCount == 1 ) {
this.depthTexture = this.device.createTexture({
size: {
width: this.canvas.width,
height: this.canvas.height,
depthOrArrayLayers: 1
},
sampleCount: 1, // Must be 1 for non-multisampled rendering
format: "depth24plus", // Or "depth24plus-stencil8" if stencil is needed
usage: GPUTextureUsage.RENDER_ATTACHMENT
});
} else {
// MSAA: no this.depthTexture is created on this path, but renderToCanvas() uses it.
this.multisampledColorTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount,
format: "bgra8unorm",
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
this.multisampledDepthTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount, // must match multisample count of color texture and pipeline
format: "depth24plus", // or "depth24plus-stencil8" if you need stencil
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
}
}
// Reflect the WGSL source: vertex inputs, then @group/@binding resources.
this.wgslSource = await this._loadShaderSource( this.path );
this._parseVertexAttributes( this.wgslSource );
this._parseBindings( this.wgslSource );
if ( this._externalBindGroupLayout ) {
this.bindGroupLayouts.set( 0, this._externalBindGroupLayout );
} else {
this.bindGroupLayouts = this._createBindGroupLayoutsFromBindings();
}
// Backing resources so bind groups exist before the caller sets any data.
this._createAllBuffers();
this._createDefaultTexturesAndSamplers();
await this.createBindGroups();
this.binded = true;
}
// A new module is compiled from the whole source on every call.
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path
} );
if ( stage === GPUShaderStage.COMPUTE ) {
// Compute: a single stage is enough for a complete pipeline.
console.log("this.bindGroupLayouts", this.bindGroupLayouts);
console.log("create compute shader pipeline.");
this.computeStage = { module: shaderModule, entryPoint: this.entryPoint };
this.pipeline = this.device.createComputePipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
compute: this.computeStage
} );
await this.finalizeShader();
} else if ( stage === GPUShaderStage.VERTEX ) {
// Vertex: keep the stage together with the vertex buffer layout parsed above.
//this.vertexStage = { module: shaderModule, entryPoint: this.entryPoint };
console.log("Vertex buffers descriptors:", Array.from(this.vertexBuffersDescriptors.values()));
this.vertexStage = {
module: shaderModule,
entryPoint: this.entryPoint,
buffers: Array.from( this.vertexBuffersDescriptors.values() )
};
} else if ( stage === GPUShaderStage.FRAGMENT ) {
this.fragmentStage = { module: shaderModule, entryPoint: this.entryPoint };
}
// If both vertex and fragment stages exist, create render pipeline
if ( this.vertexStage && this.fragmentStage ) {
// _createAllBuffers() only creates buffers that do not exist yet.
this._createAllBuffers();
console.log( this.device );
console.log("create fragment and vertex shader pipeline.");
console.log("setup createRenderPipeline", this.vertexStage);
console.log("this.device.createPipelineLayout", this.bindGroupLayouts.get(0) );
this.pipeline = this.device.createRenderPipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
vertex: this.vertexStage,
fragment: {
...this.fragmentStage,
targets: [
{
//format: "bgra8unorm",
// NOTE(review): getPreferredFormat() is not in the current WebGPU spec
// (replaced by navigator.gpu.getPreferredCanvasFormat()); presumably the
// canvas implementation in use provides it. Confirm.
format: this.canvas.getContext().getPreferredFormat(),
// No blending here for opaque rendering
blend: undefined,
writeMask: GPUColorWrite.ALL
}
] // Specify the color target format here
},
multisample: {
count: this.sampleCount, // number of samples per pixel, typical values: 4 or 8
},
primitive: {
topology: this.topologyValue,
frontFace: 'ccw' ,
cullMode: 'none',
},
depthStencil: {
format: "depth24plus", // <-- Add this to match render pass depthStencil
depthWriteEnabled: true,
depthCompare: "less",
},
} );
this._createAllBuffers();
await this.finalizeShader();
}
}
convertToTypedArrayWithPadding(array, preferredType = null) {
const isTypedArray = (
array instanceof Uint16Array ||
array instanceof Uint32Array ||
array instanceof Float32Array
);
let typedArray;
let arrayType = preferredType;
if (isTypedArray) {
if (preferredType === null) {
if (array instanceof Uint16Array) arrayType = "uint16";
else if (array instanceof Uint32Array) arrayType = "uint32";
else if (array instanceof Float32Array) arrayType = "float32";
}
const remainder = array.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = array.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
} else {
typedArray = array;
}
} else if (Array.isArray(array)) {
if (preferredType === null) {
// Default to Float32Array if no preferred type
arrayType = "float32";
} else {
arrayType = preferredType;
}
if (arrayType === "uint16") {
typedArray = new Uint16Array(array);
} else if (arrayType === "uint32") {
typedArray = new Uint32Array(array);
} else if (arrayType === "float32") {
typedArray = new Float32Array(array);
} else {
throw new Error("Unsupported preferred type");
}
const remainder = typedArray.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = typedArray.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
}
} else {
throw new Error("Input must be a TypedArray or Array of numbers");
}
return { typedArray, arrayType };
}
setIndices( indices ) {
console.log("indices", indices);
const isTypedArray = indices instanceof Uint16Array || indices instanceof Uint32Array;
let typedArray;
let arrayType;
console.log("isTypedArray", isTypedArray);
if ( isTypedArray ) {
arrayType = (indices instanceof Uint16Array) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding( indices, arrayType ) );
} else if (Array.isArray(indices)) {
let maxIndex = 0;
for (let i = 0; i < indices.length; i++) {
if (indices[i] > maxIndex) maxIndex = indices[i];
}
arrayType = (maxIndex < 65536) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding(indices, arrayType));
} else {
throw new Error("Indices must be Uint16Array, Uint32Array, or Array of numbers");
}
this.indexFormat = arrayType;
this.indexBufferData = typedArray;
if (this.indexBuffer) {
this.indexBuffer.destroy();
}
this.indexBuffer = this.device.createBuffer({
size: this.indexBufferData.byteLength,
usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST,
});
this.device.queue.writeBuffer(
this.indexBuffer,
0,
this.indexBufferData.buffer,
this.indexBufferData.byteOffset,
this.indexBufferData.byteLength
);
this.indexCount = this.indexBufferData.length;
}
setAttribute( name, array ) {
if ( this.attributeArrays === undefined ) {
this.attributeArrays = new Map();
}
const { typedArray, arrayType } = this.convertToTypedArrayWithPadding( array );
let attributeSize = 3; // Default size (vec3)
// this is not proper
if ( name === "uv" ) {
attributeSize = 2;
} else if ( name === "position" || name === "normal" ) {
attributeSize = 3;
} else {
// You can add more attribute names and sizes here
}
this.attributeArrays.set( name, { typedArray: typedArray, size: attributeSize } );
}
getAllAttributeBuffers() {
if ( !this.attributeBuffers ) {
return new Map();
}
return this.attributeBuffers;
}
vertexBuffers( passEncoder ) {
if ( this.attributeArrays === undefined ) {
console.warn("No attribute arrays to bind.");
return;
}
if ( this.mergedAttributeBuffer === undefined ) {
const attributeNames = Array.from( this.attributeArrays.keys() );
const arrays = attributeNames.map( name => this.attributeArrays.get( name ).typedArray );
const attributeSizes = attributeNames.map( name => (name === "uv") ? 2 : 3 );
const vertexCount = arrays[0].length / attributeSizes[0];
const strideFloats = attributeSizes.reduce( (sum, size) => sum + size, 0 );
const mergedArray = new Float32Array( vertexCount * strideFloats );
for ( let i = 0; i < vertexCount; i++ ) {
let offset = i * strideFloats;
for ( let attrIndex = 0; attrIndex < arrays.length; attrIndex++ ) {
const arr = arrays[attrIndex];
const size = attributeSizes[attrIndex];
const baseIndex = i * size;
for ( let c = 0; c < size; c++ ) {
mergedArray[ offset++ ] = arr[ baseIndex + c ];
}
}
}
const byteLength = mergedArray.byteLength;
const buffer = this.device.createBuffer({
size : byteLength,
usage : GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation : true,
label : "mergedAttributes"
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( mergedArray );
buffer.unmap();
this.device.queue.writeBuffer(
buffer,
0,
mergedArray.buffer,
mergedArray.byteOffset,
mergedArray.byteLength
);
this.mergedAttributeBuffer = buffer;
// Store for pipeline setup
this.attributeStrideFloats = strideFloats;
this.attributeSizes = attributeSizes;
this.attributeNames = attributeNames;
}
// Bind the merged buffer at slot 0
passEncoder.setVertexBuffer( 0, this.mergedAttributeBuffer );
}
setCanvas( canvas ) {
this.canvas = canvas;
}
async renderToCanvas( x, y = 0, z = 0, test ) {
const canvas = this.canvas;
const context = canvas.getContext("webgpu");
const commandEncoder = this.device.createCommandEncoder()
var renderPassDescriptor = {
colorAttachments: [{
view: context.getCurrentTextureView(),
loadOp: "clear",
storeOp: "store",
clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1 }
}]
,
depthStencilAttachment: {
view: this.depthTexture.createView(),
depthLoadOp: "clear",
depthStoreOp: "store",
depthClearValue: 1.0,
}
};
const renderPass = commandEncoder.beginRenderPass(renderPassDescriptor)
renderPass.setPipeline(this.pipeline)
//renderPass.setViewport(0, 0, canvas.width, canvas.height, 0, 1)
//renderPass.setScissorRect(0, 0, canvas.width, canvas.height)
//renderPass.setVertexBuffer(0, positionBuffer)
//renderPass.setVertexBuffer(1, colorBuffer)
//renderPass.setIndexBuffer(indexBuffer, 'uint16')
//renderPass.drawIndexed(3)
this.vertexBuffers( renderPass );
for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) {
renderPass.setBindGroup( bindingGroupIndex, bindGroup );
}
if( this.indexBuffer ) {
renderPass.setIndexBuffer( this.indexBuffer, this.indexFormat );
renderPass.drawIndexed( this.indexCount, y, z );
} else {
//console.log(x, y, z);
renderPass.draw( x, y, z );
}
renderPass.end()
this.device.queue.submit([ commandEncoder.finish() ])
return;
/*
const commandEncoder = this.device.createCommandEncoder();
var renderPassDescriptor;
if (this.sampleCount === 1) {
renderPassDescriptor = {
colorAttachments: [{
view: context.getCurrentTexture().createView(),
loadOp: "clear",
storeOp: "store",
clearValue: { r: 0.3, g: 0.3, b: 0.3, a: 1 }
}],
depthStencilAttachment: {
view: this.depthTexture.createView(),
depthLoadOp: "clear",
depthStoreOp: "store",
depthClearValue: 1.0,
}
};
} else {
renderPassDescriptor = {
colorAttachments: [{
view: this.multisampledColorTexture.createView(),
resolveTarget: context.getCurrentTexture().createView(),
loadOp: "clear",
storeOp: "store",
clearValue: { r: 0, g: 0, b: 0, a: 1 }
}],
depthStencilAttachment: {
view: this.multisampledDepthTexture.createView(),
depthLoadOp: "clear",
depthStoreOp: "store",
depthClearValue: 1.0,
}
};
}
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(this.pipeline);
this.vertexBuffers( passEncoder );
for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
if( this.indexBuffer ) {
passEncoder.setIndexBuffer( this.indexBuffer, this.indexFormat );
passEncoder.drawIndexed( this.indexCount, y, z );
} else {
//console.log(x, y, z);
passEncoder.draw( x, y, z );
}
passEncoder.end();
this.device.queue.submit([commandEncoder.finish()]);
*/
}
async finalizeShader() {
//await this._createPipeline();
await this.createBindGroups();
}
async debugBuffer( bufferName, bufferType ) {
const buffer = this.getBuffer( bufferName );
const dataSize = buffer.size;
const debugReadBuffer = this.device.createBuffer( {
size: dataSize,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
} );
const debugEncoder = this.device.createCommandEncoder();
console.log( "print buffer", buffer );
debugEncoder.copyBufferToBuffer(
buffer,
0,
debugReadBuffer,
0,
dataSize
);
this.device.queue.submit( [ debugEncoder.finish() ] );
await this.device.queue.onSubmittedWorkDone();
await debugReadBuffer.mapAsync( GPUMapMode.READ );
const copyArrayBuffer = debugReadBuffer.getMappedRange();
let result;
const bindingInfo = this.bindings.find( b => b.varName === bufferName );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
//console.log("wgslType", bindingInfo, wgslType, bufferName);
switch ( wgslType ) {
case "array<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "array<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "vec3<f32>":
break;
case "vec4<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "array<vec4<f32>>":
case "array<Vector>":
case "array<array<f32, 64>>":
case "array<vec3<f32>>":
result = new Float32Array( copyArrayBuffer );
break;
case "vec3<i32>":
case "vec4<i32>":
result = new Int32Array( copyArrayBuffer );
break;
default:
console.error( "Unknown WGSL type in setVariable: " + wgslType, "Using Float32Array as fallback." );
result = new Float32Array( copyArrayBuffer );
}
// use wgslType here to determine appropriate TypedArray
}
const length = result.length;
var regularArray = new Array();
for (var i = 0; i < length; i++) {
regularArray.push(result[i]);
}
console.log("Debugging array");
console.log("");
console.log("Shader: ", this.path);
console.log("Variable Name: ", bufferName);
console.log( "Buffer Size: ", this.getBuffer(bufferName).size );
console.log( "TypedArray Size: ", length );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
console.log( "type: ", wgslType );
}
console.log(regularArray);
debugReadBuffer.unmap();
return regularArray;
}
}