Files
WebGPU-FFT-Ocean-Demo/framework/WebGpu.js

2257 lines
46 KiB
JavaScript
Raw Normal View History

2025-12-31 14:22:45 +01:00
// Main thread only: expose two debug registries on `document` — every
// constructed Shader (document.shaders) and, per buffer name, the history of
// buffers passed through getBuffer()/setBuffer() (document.bufferMap).
// Inside a Worker there is no `document`, so nothing is registered.
if ( typeof document !== "undefined" ) {
	Object.assign( document, {
		shaders: [],
		bufferMap: {}
	} );
}
export default class Shader {
// Samples per pixel for render targets; 1 renders without MSAA (see addStage()).
sampleCount = 1;
// GPUDevice passed to the constructor; used to create every GPU object.
device;
// URL of the WGSL source file.
path;
// WGSL source text loaded from `path`.
wgslSource;
// Fallback single bind group layout, used when `bindGroupLayouts` is empty.
bindGroupLayout;
// Bind group layouts keyed by @group index.
bindGroupLayouts = new Map();
// Vertex buffer layouts derived from the vertex entry point's @location inputs.
vertexBuffersDescriptors = new Map();
// Texture views keyed by binding variable name.
textures = new Map();
// Compute or render pipeline created by addStage().
pipeline;
// Bind groups keyed by @group index; set to false when they must be rebuilt.
bindGroups = new Map();
// Entry point of the most recently added stage.
entryPoint = "main";
// Bindings reflected from the `@group(..) @binding(..) var ...` declarations.
bindings = [];
// GPU buffers keyed by binding (or vertex buffer) name.
buffers = new Map();
// GPUBufferUsage flags per buffer name, reused when setVariable() recreates a buffer at a new size.
bufferUsages = new Map();
// Byte size per buffer name.
bufferSizes = new Map();
// NOTE(review): attributeBuffers, hasLoaded and clearOnRender are not referenced in the
// visible part of this file; presumably used by rendering code further down — confirm.
attributeBuffers;
// Bind group layout supplied through setBindingGroupLayout(); installed as group 0 by addStage().
_externalBindGroupLayout = null;
hasLoaded = false;
// COMPUTE, or VERTEX | FRAGMENT once a render stage is added; selects binding visibility.
stage = GPUShaderStage.COMPUTE;
// True when there is no DOM (e.g. in a Web Worker); disables the `document` debug registries.
isInWorker = typeof document === "undefined";
// Primitive topology for render pipelines; set through the `topology` accessor.
topologyValue = "triangle-strip";
// Set once addStage() has created the bind group layouts and bind groups.
binded = false;
clearOnRender = true;
// GPUSampler objects keyed by binding variable name.
samplers = new Map();
constructor( device, path = false ) {
if( !this.isInWorker ) {
document.shaders.push( this );
}
console.log("\n\n \n Creating New Shader", path, "\n\n");
console.log("device", typeof device);
this.device = device;
if( path ) {
this.path = path;
}
}
set topology(value) {
if (this.binded) {
console.error("Define topology before setup", this);
}
this.topologyValue = value;
}
get topology() {
return this.topologyValue;
}
// (A fixed-size duplicate of _createAllBuffers() used to be defined here. It was dead
// code — in a class body a later method with the same name replaces an earlier one — so
// the size-aware _createAllBuffers() further down is the only implementation.)
_parseStructs(wgsl) {
this.structs = {};
const structRegex =
/struct\s+(\w+)\s*\{([^}]+)\}/g;
let match;
while ((match = structRegex.exec(wgsl)) !== null) {
const structName = match[1];
const body = match[2].trim();
const fields = [];
const fieldRegex =
/(\w+)\s*:\s*([\w<>\[\]]+)\s*,?/g;
let f;
while ((f = fieldRegex.exec(body)) !== null) {
const fieldName = f[1];
const fieldType = f[2];
fields.push({
name: fieldName,
type: fieldType
});
}
this.structs[structName] = {
name: structName,
fields: fields,
size: 0 // computed later
};
}
for (const name of Object.keys(this.structs)) {
this._computeStructSize(name);
}
console.log("structs reflection created:", this.structs);
}
_computeStructSize(structName) {
const struct = this.structs[structName];
if (!struct) return 0;
let offset = 0;
for (const f of struct.fields) {
const align = 4;
const size = 4;
// align field offset
offset = Math.ceil(offset / align) * align;
f.offset = offset;
f.size = size;
f.align = align;
offset += size;
}
struct.size = offset;
return offset;
}
_computeTypeSize(wgslType) {
if (wgslType === "f32" || wgslType === "i32" || wgslType === "u32") {
return 4;
}
// vecN<T>
const vecMatch = wgslType.match(/vec(\d+)<(\w+)>/);
if (vecMatch) {
const n = Number(vecMatch[1]);
const elemType = vecMatch[2];
return n * this._computeTypeSize(elemType);
}
// matCxR<T> (e.g. mat4x4<f32>)
const matMatch = wgslType.match(/mat(\d+)x(\d+)<(\w+)>/);
if (matMatch) {
const c = Number(matMatch[1]);
const r = Number(matMatch[2]);
const elemType = matMatch[3];
return c * r * this._computeTypeSize(elemType);
}
// array<T, N>
const arrFix = wgslType.match(/array<(\w+),\s*(\d+)>/);
if (arrFix) {
const element = arrFix[1];
const count = Number(arrFix[2]);
return this._computeTypeSize(element) * count;
}
// array<T> runtime sized
const arrDyn = wgslType.match(/array<(\w+)>/);
if (arrDyn) {
return 4;
}
// struct
if (this.structs[wgslType]) {
return this.structs[wgslType].size;
}
return 4;
}
_computeBindingSizes() {
for (const b of this.bindings) {
const t = b.varType;
// -----------------------------------------
// 1. Struct?
// -----------------------------------------
if (this.structs[t]) {
b.size = this.structs[t].size;
continue;
}
// -----------------------------------------
// 4. Texture / sampler (no buffer)
// -----------------------------------------
if (b.type === "texture" || b.type === "sampler") {
b.size = 0;
continue;
}
// -----------------------------------------
// 5. Use computed WGSL type size (scalars, vectors, matrices, arrays)
// -----------------------------------------
b.size = this._computeTypeSize(t);
}
}
_createAllBuffers() {
for (const b of this.bindings) {
if (this.buffers.has(b.varName))
continue;
let size = b.size;
if (size < 4) size = 4;
let usage;
switch (b.type) {
case "uniform":
usage = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
break;
case "read-only-storage":
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
break;
default:
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
const buffer = this.device.createBuffer({
size,
usage,
label: b.varName
});
this.buffers.set(b.varName, buffer);
this.bufferSizes.set(b.varName, size);
this.bufferUsages.set(b.varName, usage);
}
}
_createDefaultTexturesAndSamplers() {
for ( const b of this.bindings ) {
if ( b.type === "texture" && !this.textures.has( b.varName ) ) {
// Detect if the type is a texture array
const isTextureArray = b.varType.startsWith("texture_2d_array");
// Create 1x1 texture with 1 layer (depthOrArrayLayers = 1)
const defaultTexture = this.device.createTexture({
size: isTextureArray ? [1, 1, 2] : [1, 1, 1],
format: "rgba8unorm",
usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST,
label: b.varName + "_defaultTexture"
});
console.log("create texture ", isTextureArray, isTextureArray ? "2d-array" : "2d");
// Create texture view with correct dimension
const defaultTextureView = defaultTexture.createView({
dimension: isTextureArray ? "2d-array" : "2d"
});
console.log("create default texture?", b, defaultTextureView, isTextureArray ? "2d-array" : "2d");
this.textures.set(b.varName, defaultTextureView);
} else if ( b.type === "sampler" && !this.samplers.has( b.varName ) ) {
// Create default sampler
const defaultSampler = this.device.createSampler( {
magFilter: "linear",
minFilter: "linear",
label: b.varName + "_defaultSampler"
} );
this.samplers.set( b.varName, defaultSampler );
}
}
}
createBindGroups() {
this.bindGroups = new Map();
for ( const [ group, layout ] of this.bindGroupLayouts.entries() ) {
const bindings = this.bindings.filter( b => b.group === group );
const entries = new Array( bindings.length );
for ( let i = 0; i < bindings.length; i++ ) {
const b = bindings[ i ];
let resource;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
const gpuBuffer = this.buffers.get( b.varName );
if ( !gpuBuffer ) {
throw new Error( `Buffer missing for ${b.varName} when creating bind group.` );
}
resource = { buffer: gpuBuffer };
} else if ( b.type === "texture" ) {
const gpuTextureView = this.textures.get( b.varName );
//console.log("this.textures", this.textures);
if ( !gpuTextureView ) {
console.warn( `Texture missing for ${b.varName} when creating bind group.` );
} else {
//console.log( `Binding texture ${b.varName} with dimension:`, gpuTextureView.dimension );
}
resource = gpuTextureView;
} else if ( b.type === "sampler" ) {
const gpuSampler = this.samplers.get( b.varName );
if ( !gpuSampler ) {
//throw new Error( `Sampler missing for ${b.varName} when creating bind group.` );
}
resource = gpuSampler;
} else {
throw new Error( `Unknown binding type '${b.type}' for ${b.varName}` );
}
entries[ i ] = {
binding: b.binding,
resource: resource
};
}
const bindGroup = this.device.createBindGroup( {
layout: layout,
entries: entries
} );
this.bindGroups.set( group, bindGroup );
}
}
extractEntryPoints( wgslCode ) {
const entryPoints = [];
// Regex to match lines like: @vertex fn functionName(...) { ... }
const regex =
/@(\w+)(?:\s*@[\w_]+(?:\([^)]*\))?)*\s*fn\s+([\w_]+)\s*\(/gm;
let match;
while ((match = regex.exec(wgslCode)) !== null) {
const stage = match[1];
const name = match[2];
let type;
switch (stage) {
case "vertex":
type = GPUShaderStage.VERTEX;
break;
case "fragment":
type = GPUShaderStage.FRAGMENT;
break;
case "compute":
type = GPUShaderStage.COMPUTE;
break;
default:
type = "Unknown";
}
entryPoints.push({ name: name, type: type });
}
return entryPoints;
}
async setup( path ) {
this.path = path;
this.wgslSource = await this._loadShaderSource( this.path );
// 2. Parse metadata from WGSL
this._parseVertexAttributes(this.wgslSource);
this._parseBindings(this.wgslSource);
this._parseStructs(this.wgslSource);
// 3. Compute struct sizes BEFORE creating buffers
this._computeBindingSizes(); // if you implemented this
// ----------------------------------------------------------
// 4. CREATE BUFFERS BEFORE EVERYTHING ELSE
// ----------------------------------------------------------
this._createAllBuffers();
//console.log(this.wgslSource);
var entryPoints = this.extractEntryPoints( this.wgslSource );
console.log(entryPoints);
for (var i = 0; i < entryPoints.length; i++) {
var entryPoint = entryPoints[i];
await this.addStage( entryPoint.name, entryPoint.type );
}
//await this.addStage("computeMain", 4 );
//console.log( "load shader", this.wgslSource, entryPoints )
}
async _loadShaderSource( pathName ) {
const response = await fetch( pathName );
if ( !response.ok ) throw new Error( `Failed to load shader: ${ pathName }` );
return await response.text();
}
setVertexBuffer( name, dataArray ) {
const buffer = this.device.createBuffer({
size: dataArray.byteLength,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation: true
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( dataArray );
buffer.unmap();
this.buffers.set( name, buffer );
}
extractVertexFunctionParameters(wgsl) {
let vertexFnIndex = wgsl.indexOf('@vertex');
if (vertexFnIndex === -1) return '';
let fnIndex = wgsl.indexOf('fn', vertexFnIndex);
if (fnIndex === -1) return '';
let parenStart = wgsl.indexOf('(', fnIndex);
if (parenStart === -1) return '';
let parenCount = 1;
let i = parenStart + 1;
while (i < wgsl.length && parenCount > 0) {
const c = wgsl[i];
if (c === '(') parenCount++;
else if (c === ')') parenCount--;
i++;
}
if (parenCount !== 0) {
// unbalanced parentheses
return '';
}
// Extract parameter substring between parenStart+1 and i-1
let paramsStr = wgsl.substring(parenStart + 1, i - 1);
return paramsStr.trim();
}
extractVertexParamsLocations(wgsl) {
const paramsStr = this.extractVertexFunctionParameters(wgsl)
const paramRegex = /@location\s*\(\s*(\d+)\s*\)\s+(\w+)\s*:\s*([\w<>\d]+(\s*<[^>]+>)?)/g;
const results = new Array();
let paramMatch;
while ((paramMatch = paramRegex.exec(paramsStr)) !== null) {
const location = Number(paramMatch[1]);
const varName = paramMatch[2];
const type = paramMatch[3].trim();
results.push({ location, varName, type });
}
return results;
}
getFormatSize( format ) {
switch ( format ) {
case "float32": return 4;
case "float32x2": return 8;
case "float32x3": return 12;
case "float32x4": return 16;
case "uint32": return 4;
case "sint32": return 4;
default: throw new Error("Unsupported format size for: " + format);
}
}
mapTypeToVertexFormat( type ) {
// Map WGSL types to GPUVertexFormat strings (example subset)
switch ( type ) {
case "f32": return "float32";
case "vec2<f32>": return "float32x2";
case "vec3<f32>": return "float32x3";
case "vec4<f32>": return "float32x4";
case "u32": return "uint32";
case "i32": return "sint32";
default: throw new Error("Unsupported vertex attribute type: " + type);
}
}
_parseVertexAttributes( wgsl ) {
const locations = this.extractVertexParamsLocations( wgsl );
if (locations.length === 0) {
// No vertex attributes = no vertex buffer descriptors needed
return;
}
let offset = 0;
const attributes = new Array();
for ( let attr of locations ) {
const format = this.mapTypeToVertexFormat( attr.type );
attributes.push({
shaderLocation : attr.location,
offset : offset,
format : format
});
offset += this.getFormatSize( format );
}
const vertexBufferDescriptor = {
arrayStride: offset,
attributes: attributes,
stepMode: "vertex"
};
console.log("vertexBufferDescriptor", vertexBufferDescriptor);
this.vertexBuffersDescriptors.set("vertexBuffer0", vertexBufferDescriptor);
console.log("vertex buffer descriptors updated:", this.vertexBuffersDescriptors);
}
_parseBindings( wgsl ) {
console.log("parse bindings");
this.bindings = [];
// Make the `<...>` part optional by wrapping it in a non-capturing group with `?`
const regex = /@group\((\d+)\)\s+@binding\((\d+)\)\s+var(?:<(\w+)(?:,\s*(\w+))?>)?\s+(\w+)\s*:\s*([^;]+);/g;
let match;
while ((match = regex.exec(wgsl)) !== null) {
const group = Number(match[1]);
const binding = Number(match[2]);
const storageClass = match[3]; // e.g. "uniform", "storage", or undefined for textures/samplers
const access = match[4]; // e.g. "read", "write", or undefined
const varName = match[5];
const varType = match[6]; // WGSL type string like "mat4x4<f32>", "texture_2d<f32>", "sampler", etc.
let type;
if (storageClass === "storage") {
type = access === "read" ? "read-only-storage" : "storage";
} else if (storageClass === "uniform") {
type = "uniform";
} else if (varType.startsWith("texture_")) {
type = "texture";
} else if (varType === "sampler" || varType === "sampler_comparison") {
type = "sampler";
} else {
type = "uniform"; // fallback
}
this.bindings.push({ group, binding, type, varName, varType });
}
console.log("bindings", this.bindings);
this.bindings.sort((a, b) => a.binding - b.binding);
}
_createBindGroupLayoutsFromBindings() {
const groups = new Map();
for ( let i = 0; i < this.bindings.length; i++ ) {
const b = this.bindings[ i ];
if ( !groups.has( b.group ) ) groups.set( b.group, [] );
groups.get( b.group ).push( b );
}
const layouts = new Map();
for ( const groupEntry of groups.entries() ) {
const group = groupEntry[ 0 ];
const bindings = groupEntry[ 1 ];
let visibility;
if ( this.stage === GPUShaderStage.COMPUTE ) {
visibility = GPUShaderStage.COMPUTE;
} else {
visibility = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
const entries = new Array( bindings.length );
for ( let j = 0; j < bindings.length; j++ ) {
const b = bindings[ j ];
let entry;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: b.type }
};
} else if ( b.type === "texture" ) {
let viewDimension = "2d"; // default
if ( b.varType.startsWith("texture_2d_array") ) {
viewDimension = "2d-array";
} else if ( b.varType.startsWith("texture_cube") ) {
viewDimension = "cube";
} else if ( b.varType.startsWith("texture_3d") ) {
viewDimension = "3d";
}
console.log("viewDimension", viewDimension);
entry = {
binding: b.binding,
visibility: visibility,
texture: {
sampleType: "float",
viewDimension: viewDimension,
multisampled: false
}
};
} else if ( b.type === "sampler" ) {
entry = {
binding: b.binding,
visibility: visibility,
sampler: { type: "filtering" }
};
} else {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: "uniform" }
};
}
entries[ j ] = entry;
}
console.log("createBindGroupLayout ", entries);
const layout = this.device.createBindGroupLayout( { entries } );
layouts.set( group, layout );
}
return layouts;
}
async _createPipeline() {
const layoutsArray = this.bindGroupLayouts && this.bindGroupLayouts.size > 0
? Array.from( this.bindGroupLayouts.values() )
: [ this.bindGroupLayout ];
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path,
} );
const info = await shaderModule.compilationInfo();
if (info.messages.length > 0) {
for (const msg of info.messages) {
console[msg.type === "error" ? "error" : "warn"](
`[${msg.lineNum}:${msg.linePos}] ${msg.message}`
);
}
}
//console.log("this.bindGroupLayouts.values()", this.bindGroupLayouts.values());
this.pipeline = this.device.createComputePipeline({
layout: this.device.createPipelineLayout({
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
}),
compute: {
module: shaderModule,
entryPoint: this.entryPoint,
},
} );
}
/**
 * Supplies a pre-built bind group layout that addStage() installs as group 0 instead
 * of the layouts reflected from the WGSL bindings. It is only read on addStage()'s first
 * (not yet bound) run, so call this before adding stages.
 *
 * @param {GPUBindGroupLayout} bindGroupLayout
 */
setBindingGroupLayout( bindGroupLayout ) {
this._externalBindGroupLayout = bindGroupLayout;
}
getBindingGroupLayout( bindingGroupIndex = 0 ) {
if ( this.bindGroupLayouts ) {
return this.bindGroupLayouts.get( bindingGroupIndex );
}
return this.bindGroupLayout;
}
getTypedArrayByVariableName( name, flatData ) {
let typedArray;
const bindingInfo = this.bindings.find( b => b.varName === name );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
switch( wgslType ) {
case "vec3<f32>":
// flatData.push( 0 );
break;
}
switch ( wgslType ) {
case "f32":
case "vec2<f32>":
case "vec3<f32>":
case "vec4<f32>":
case "array<f32>":
typedArray = new Float32Array( flatData );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
case "array<u32>":
typedArray = new Uint32Array( flatData );
break;
case "i32":
case "vec2<i32>":
case "vec3<i32>":
case "vec4<i32>":
case "array<i32>":
typedArray = new Int32Array( flatData );
break;
default:
// find struct and then the datatype
typedArray = new Float32Array( flatData );
//console.error( "Unknown WGSL type in setVariable: " + wgslType );
}
//console.log("wgslType", wgslType);
}
return typedArray;
}
createTextureFromData(width, height, data = null) {
// 1. Infer format based on TypedArray
let format;
let bytesPerComponent;
if (data instanceof Float32Array) {
format = "rgba32float";
bytesPerComponent = 4;
}
else if (data instanceof Uint8Array) {
format = "rgba8unorm";
bytesPerComponent = 1;
}
else if (data instanceof Uint32Array) {
format = "rgba32uint";
bytesPerComponent = 4;
}
else if (data instanceof Int32Array) {
format = "rgba32sint";
bytesPerComponent = 4;
}
else {
throw new Error("Unsupported texture data type");
}
// 2. Compute expected size
const channels = 4;
const expected = width * height * channels;
if (data && data.length !== expected) {
throw new Error(
`Texture data length mismatch. Expected ${expected}, got ${data.length}`
);
}
// 3. Create texture
const texture = this.device.createTexture({
size: { width, height, depthOrArrayLayers: 1 },
format,
usage:
GPUTextureUsage.TEXTURE_BINDING |
GPUTextureUsage.STORAGE_BINDING |
GPUTextureUsage.COPY_DST
});
// 4. Upload if data provided
if (data) {
this.device.queue.writeTexture(
{
texture: texture
},
data,
{
bytesPerRow: width * channels * bytesPerComponent
},
{
width,
height,
depthOrArrayLayers: 1
}
);
}
return texture;
}
setVariable( name, dataObject ) {
if (dataObject instanceof GPUBuffer) {
// Direct buffer binding do NOT recreate, resize, or write.
this.setBuffer( name, dataObject );
return;
}
// TEXTURES -----------------------------------------------------
if (dataObject instanceof GPUTexture) {
this.textures.set(name, dataObject.createView());
this._invalidateBindGroups();
return;
}
// SAMPLERS -----------------------------------------------------
if (dataObject instanceof GPUSampler) {
this.samplers.set(name, dataObject);
this._invalidateBindGroups();
return;
}
// STRUCT CHECK -------------------------------------------------
const binding = this.bindings.find(b => b.varName === name);
const structInfo = binding ? this.structs[binding.varType] : undefined;
let typed;
// --------------------------------------------------------------
// 1. WGSL STRUCT → Uint32Array (using struct reflection)
// --------------------------------------------------------------
if (structInfo) {
const out = new Uint32Array(structInfo.size / 4); // aligned struct size in bytes
for (let i = 0; i < structInfo.fields.length; i++) {
const f = structInfo.fields[i];
let v = dataObject[f.name];
if (v === undefined) {
v = 0;
}
// only scalar u32/f32 fields for now
out[f.offset / 4] = v >>> 0;
}
typed = out;
}
// --------------------------------------------------------------
// 2. Already typed arrays
// --------------------------------------------------------------
else if (
dataObject instanceof Float32Array ||
dataObject instanceof Uint32Array ||
dataObject instanceof Int32Array
) {
typed = dataObject;
}
// --------------------------------------------------------------
// 3. Number or JS array → detect type
// --------------------------------------------------------------
else if (typeof dataObject === "number" || Array.isArray(dataObject)) {
const flat = this.flattenStructure(dataObject);
typed = this.getTypedArrayByVariableName(name, flat);
}
// --------------------------------------------------------------
// 4. Raw ArrayBuffer
// --------------------------------------------------------------
else if (dataObject instanceof ArrayBuffer) {
typed = new Float32Array(dataObject);
}
// --------------------------------------------------------------
// 5. Fallback
// --------------------------------------------------------------
else {
const flat = this.flattenStructure(dataObject);
typed = new Float32Array(flat);
}
// BUFFER CREATION / RESIZE -------------------------------------
const bytes = typed.byteLength;
const oldBuffer = this.buffers.get(name);
const oldUsage = this.bufferUsages.get(name);
const oldSize = oldBuffer ? oldBuffer.size : 0;
if (typeof oldUsage !== "number") {
throw new Error("Invalid buffer usage for '" + name + "'");
}
let buffer = oldBuffer;
// Resize if needed
if (!buffer || oldSize !== bytes) {
buffer = this.device.createBuffer({
label: name,
size: bytes,
usage: oldUsage,
mappedAtCreation: false
});
this.buffers.set(name, buffer);
this.bufferSizes.set(name, bytes);
this._invalidateBindGroups(); // <— IMPORTANT FIX
}
// BUFFER UPLOAD -------------------------------------------------
this.device.queue.writeBuffer(
buffer,
0,
typed.buffer,
typed.byteOffset,
typed.byteLength
);
}
/**
 * Drops the cached bind groups. The `false` sentinel makes createCommandEncoder()
 * and getBuffer() rebuild them through createBindGroups() on their next call.
 */
_invalidateBindGroups() {
this.bindGroups = false;
}
computeUsage(name) {
const binding = this.bindings.find(b => b.varName === name);
if (!binding) {
// No binding found → use safe fallback
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
switch (binding.type) {
case "uniform":
return GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
case "read-only-storage":
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ;
case "storage":
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
default:
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
}
flattenStructure( input, output = [] ) {
if ( Array.isArray( input ) ) {
for ( const element of input ) {
this.flattenStructure( element, output );
}
} else if ( typeof input === "object" && input !== null ) {
for ( const key of Object.keys( input ) ) {
this.flattenStructure( input[ key ], output );
}
} else if ( typeof input === "number" ) {
output.push( input );
} else if ( typeof input === "boolean" ) {
output.push( input ? 1 : 0 );
} else {
throw new Error( "Unsupported data type in shader variable: " + typeof input );
}
return output;
}
copyBuffersFromShader( sourceShader ) {
for ( const [ name, buffer ] of sourceShader.buffers.entries() ) {
this.buffers.set( name, buffer );
}
}
updateUniforms( data ) {
const arrayBuffer = new ArrayBuffer(12); // 3 x u32 = 3 * 4 bytes
const dataView = new DataView(arrayBuffer);
dataView.setUint32(0, data.k, true); // offset 0
dataView.setUint32(4, data.j, true); // offset 4
dataView.setUint32(8, data.totalCount, true); // offset 8
this.device.queue.writeBuffer(
this.uniformBuffer, // your GPUBuffer holding uniforms
0, // offset in buffer
arrayBuffer
);
}
createCommandEncoder( x, y, z ) {
// AUTO-REBUILD BIND GROUPS IF NEEDED
if (!this.bindGroups || this.bindGroups.size === 0) {
this.createBindGroups();
}
const commandEncoder = this.device.createCommandEncoder({
label: this.path
});
const passEncoder = commandEncoder.beginComputePass({
label: this.path
});
if( !this.pipeline ) {
console.log("Pipeline missing, Maybe you forget to call addStage( ) ", this);
}
passEncoder.setPipeline( this.pipeline );
for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
passEncoder.dispatchWorkgroups( x, y, z );
passEncoder.end();
return commandEncoder;
}
async execute( x = 1, y = 1, z = 1 ) {
const commandEncoder = this.createCommandEncoder( x, y, z );
const commandBuffer = commandEncoder.finish();
this.device.queue.submit( [ commandBuffer ] );
await this.device.queue.onSubmittedWorkDone(); // critical!
}
registerBuffer( buffer, name, type ) {
if( !this.isInWorker ) {
if( !document.bufferMap[name] ) {
document.bufferMap[name] = new Array();
}
var bufferHistoryPoint = { buffer: buffer, shader: this, label : buffer.label, type: type };
document.bufferMap[name].push( bufferHistoryPoint );
}
}
getBuffer( name ) {
// If no bind groups yet, create them now
if (!this.bindGroups || this.bindGroups.size === 0) {
this.createBindGroups();
}
var buffer = this.buffers.get( name );
this.registerBuffer( buffer, name, "get" );
return buffer;
}
async setBuffer( name, buffer ) {
this.registerBuffer( buffer, name, "set" );
if( !this.isInWorker ) {
//var bufferOwner = document.bufferOwners.get( buffer );
//console.log("setBuffer Origin:", bufferOwner, " this shader", this);
}
const bindingInfo = this.bindings.find( b => b.varName === name );
if( !bindingInfo ) {
console.error("No binding to buffer ", name, " in shader ", this );
}
//console.log("set buffer", name, buffer);
this.buffers.set(name, buffer);
await this.createBindGroups(); // always refresh bindings
}
async addStage( entryPoint, stage ) {
if ( !this.path ) {
throw new Error("Shader path is not set");
}
this.entryPoint = entryPoint;
if( stage != GPUShaderStage.COMPUTE ) {
this.stage = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
if ( !this.binded ) {
if( stage != GPUShaderStage.COMPUTE ) {
if( this.sampleCount == 1 ) {
this.depthTexture = this.device.createTexture({
size: {
width: this.canvas.width,
height: this.canvas.height,
depthOrArrayLayers: 1
},
sampleCount: 1, // Must be 1 for non-multisampled rendering
format: "depth24plus", // Or "depth24plus-stencil8" if stencil is needed
usage: GPUTextureUsage.RENDER_ATTACHMENT
});
} else {
this.multisampledColorTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount,
format: "bgra8unorm",
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
this.multisampledDepthTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount, // must match multisample count of color texture and pipeline
format: "depth24plus", // or "depth24plus-stencil8" if you need stencil
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
}
}
this.wgslSource = await this._loadShaderSource( this.path );
//this._parseVertexAttributes( this.wgslSource );
//this._parseBindings( this.wgslSource );
//this._parseStructs( this.wgslSource );
if ( this._externalBindGroupLayout ) {
this.bindGroupLayouts.set( 0, this._externalBindGroupLayout );
} else {
this.bindGroupLayouts = this._createBindGroupLayoutsFromBindings();
}
this._createDefaultTexturesAndSamplers();
await this.createBindGroups();
this.binded = true;
}
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path
} );
if ( stage === GPUShaderStage.COMPUTE ) {
//console.log("this.bindGroupLayouts", this.bindGroupLayouts);
//console.log("create compute shader pipeline.");
this.computeStage = { module: shaderModule, entryPoint: this.entryPoint };
this.pipeline = this.device.createComputePipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
compute: this.computeStage
} );
await this.finalizeShader();
} else if ( stage === GPUShaderStage.VERTEX ) {
//this.vertexStage = { module: shaderModule, entryPoint: this.entryPoint };
console.log("Vertex buffers descriptors:", Array.from(this.vertexBuffersDescriptors.values()));
const vertexBuffers = this.vertexBuffersDescriptors.size > 0
? [...this.vertexBuffersDescriptors.values()]
: [];
this.vertexStage = {
module: shaderModule,
entryPoint: this.entryPoint,
buffers: vertexBuffers
};
} else if ( stage === GPUShaderStage.FRAGMENT ) {
this.fragmentStage = { module: shaderModule, entryPoint: this.entryPoint };
}
// If both vertex and fragment stages exist, create render pipeline
if ( this.vertexStage && this.fragmentStage ) {
this._createAllBuffers();
console.log( this.device );
console.log("create fragment and vertex shader pipeline.");
console.log("setup createRenderPipeline", this.vertexStage);
console.log("this.device.createPipelineLayout", this.bindGroupLayouts.get(0) );
this.pipeline = this.device.createRenderPipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
vertex: this.vertexStage,
fragment: {
...this.fragmentStage,
targets: [
{
format: "bgra8unorm",
// No blending here for opaque rendering
blend: undefined,
writeMask: GPUColorWrite.ALL
}
] // Specify the color target format here
},
multisample: {
count: this.sampleCount, // number of samples per pixel, typical values: 4 or 8
},
primitive: {
topology: this.topologyValue,
frontFace: 'ccw' ,
cullMode: 'none',
},
depthStencil: {
format: "depth24plus", // <-- Add this to match render pass depthStencil
depthWriteEnabled: true,
depthCompare: "less",
},
} );
this._createAllBuffers();
await this.finalizeShader();
}
}
convertToTypedArrayWithPadding(array, preferredType = null) {
const isTypedArray = (
array instanceof Uint16Array ||
array instanceof Uint32Array ||
array instanceof Float32Array
);
let typedArray;
let arrayType = preferredType;
if (isTypedArray) {
if (preferredType === null) {
if (array instanceof Uint16Array) arrayType = "uint16";
else if (array instanceof Uint32Array) arrayType = "uint32";
else if (array instanceof Float32Array) arrayType = "float32";
}
const remainder = array.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = array.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
} else {
typedArray = array;
}
} else if (Array.isArray(array)) {
if (preferredType === null) {
// Default to Float32Array if no preferred type
arrayType = "float32";
} else {
arrayType = preferredType;
}
if (arrayType === "uint16") {
typedArray = new Uint16Array(array);
} else if (arrayType === "uint32") {
typedArray = new Uint32Array(array);
} else if (arrayType === "float32") {
typedArray = new Float32Array(array);
} else {
throw new Error("Unsupported preferred type");
}
const remainder = typedArray.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = typedArray.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
}
} else {
throw new Error("Input must be a TypedArray or Array of numbers");
}
return { typedArray, arrayType };
}
setIndices( indices ) {
console.log("indices", indices);
const isTypedArray = indices instanceof Uint16Array || indices instanceof Uint32Array;
let typedArray;
let arrayType;
console.log("isTypedArray", isTypedArray);
if ( isTypedArray ) {
arrayType = (indices instanceof Uint16Array) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding( indices, arrayType ) );
} else if (Array.isArray(indices)) {
let maxIndex = 0;
for (let i = 0; i < indices.length; i++) {
if (indices[i] > maxIndex) maxIndex = indices[i];
}
arrayType = (maxIndex < 65536) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding(indices, arrayType));
} else {
throw new Error("Indices must be Uint16Array, Uint32Array, or Array of numbers");
}
this.indexFormat = arrayType;
this.indexBufferData = typedArray;
if (this.indexBuffer) {
this.indexBuffer.destroy();
}
this.indexBuffer = this.device.createBuffer({
size: this.indexBufferData.byteLength,
usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST,
});
this.device.queue.writeBuffer(
this.indexBuffer,
0,
this.indexBufferData.buffer,
this.indexBufferData.byteOffset,
this.indexBufferData.byteLength
);
this.indexCount = this.indexBufferData.length;
}
setAttribute( name, array ) {
if ( this.attributeArrays === undefined ) {
this.attributeArrays = new Map();
}
const { typedArray, arrayType } = this.convertToTypedArrayWithPadding( array );
let attributeSize = 3; // Default size (vec3)
// this is not proper
if ( name === "uv" ) {
attributeSize = 2;
} else if ( name === "position" || name === "normal" ) {
attributeSize = 3;
} else {
// You can add more attribute names and sizes here
}
this.attributeArrays.set( name, { typedArray: typedArray, size: attributeSize } );
}
getAllAttributeBuffers() {
if ( !this.attributeBuffers ) {
return new Map();
}
return this.attributeBuffers;
}
vertexBuffers( passEncoder ) {
if ( this.attributeArrays === undefined ) {
console.warn("No attribute arrays to bind.");
return;
}
if ( this.mergedAttributeBuffer === undefined ) {
const attributeNames = Array.from( this.attributeArrays.keys() );
const arrays = attributeNames.map( name => this.attributeArrays.get( name ).typedArray );
const attributeSizes = attributeNames.map( name => (name === "uv") ? 2 : 3 );
const vertexCount = arrays[0].length / attributeSizes[0];
const strideFloats = attributeSizes.reduce( (sum, size) => sum + size, 0 );
const mergedArray = new Float32Array( vertexCount * strideFloats );
for ( let i = 0; i < vertexCount; i++ ) {
let offset = i * strideFloats;
for ( let attrIndex = 0; attrIndex < arrays.length; attrIndex++ ) {
const arr = arrays[attrIndex];
const size = attributeSizes[attrIndex];
const baseIndex = i * size;
for ( let c = 0; c < size; c++ ) {
mergedArray[ offset++ ] = arr[ baseIndex + c ];
}
}
}
const byteLength = mergedArray.byteLength;
const buffer = this.device.createBuffer({
size : byteLength,
usage : GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation : true,
label : "mergedAttributes"
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( mergedArray );
buffer.unmap();
this.device.queue.writeBuffer(
buffer,
0,
mergedArray.buffer,
mergedArray.byteOffset,
mergedArray.byteLength
);
this.mergedAttributeBuffer = buffer;
// Store for pipeline setup
this.attributeStrideFloats = strideFloats;
this.attributeSizes = attributeSizes;
this.attributeNames = attributeNames;
}
// Bind the merged buffer at slot 0
passEncoder.setVertexBuffer( 0, this.mergedAttributeBuffer );
}
setCanvas( canvas ) {
this.canvas = canvas;
}
encodeRender( passEncoder, x, y = 0, z = 0 ) {
if ( !this.bindGroups || this.bindGroups.size === 0 ) {
this.createBindGroups();
}
passEncoder.setPipeline( this.pipeline );
this.vertexBuffers( passEncoder );
for ( const [ bindingGroupIndex, bindGroup ] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
if ( this.indexBuffer ) {
passEncoder.setIndexBuffer( this.indexBuffer, this.indexFormat );
passEncoder.drawIndexed( this.indexCount, y, z );
} else {
passEncoder.draw( x, y, z );
}
}
async renderToCanvas( x, y = 0, z = 0 ) {
// AUTO-REBUILD BIND GROUPS IF NEEDED
if ( !this.bindGroups || this.bindGroups.size === 0 ) {
this.createBindGroups();
}
const canvas = this.canvas;
const context = canvas.getContext( "webgpu" );
const format = navigator.gpu.getPreferredCanvasFormat();
const depthTexture = this.device.createTexture( {
size: [ this.canvas.width, this.canvas.height, 1 ],
sampleCount: this.sampleCount,
format: "depth24plus",
usage: GPUTextureUsage.RENDER_ATTACHMENT
} );
const commandEncoder = this.device.createCommandEncoder();
let renderPassDescriptor;
const colorLoadOp = this.clearOnRender ? "clear" : "load";
const depthLoadOp = this.clearOnRender ? "clear" : "load";
const clearColor = { r: 0.3, g: 0.3, b: 0.3, a: 1 };
if ( this.sampleCount === 1 ) {
renderPassDescriptor = {
colorAttachments: [ {
view: context.getCurrentTexture().createView(),
loadOp: colorLoadOp,
storeOp: "store",
clearValue: clearColor
} ],
depthStencilAttachment: {
view: this.depthTexture.createView(),
depthLoadOp: depthLoadOp,
depthStoreOp: "store",
depthClearValue: 1.0
}
};
} else {
renderPassDescriptor = {
colorAttachments: [ {
view: this.multisampledColorTexture.createView(),
resolveTarget: context.getCurrentTexture().createView(),
loadOp: colorLoadOp,
storeOp: "store",
clearValue: { r: 0, g: 0, b: 0, a: 1 }
} ],
depthStencilAttachment: {
view: this.multisampledDepthTexture.createView(),
depthLoadOp: depthLoadOp,
depthStoreOp: "store",
depthClearValue: 1.0
}
};
}
const passEncoder = commandEncoder.beginRenderPass( renderPassDescriptor );
this.encodeRender( passEncoder, x, y, z );
passEncoder.end();
this.device.queue.submit( [ commandEncoder.finish() ] );
}
async finalizeShader() {
//await this._createPipeline();
await this.createBindGroups();
}
async debugBuffer( bufferName, bufferType ) {
const buffer = this.getBuffer( bufferName );
const dataSize = buffer.size;
const debugReadBuffer = this.device.createBuffer( {
size: dataSize,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
} );
const debugEncoder = this.device.createCommandEncoder();
//console.log( "print buffer", buffer );
debugEncoder.copyBufferToBuffer(
buffer,
0,
debugReadBuffer,
0,
dataSize
);
this.device.queue.submit( [ debugEncoder.finish() ] );
await this.device.queue.onSubmittedWorkDone();
await debugReadBuffer.mapAsync( GPUMapMode.READ );
const copyArrayBuffer = debugReadBuffer.getMappedRange();
let result;
const bindingInfo = this.bindings.find( b => b.varName === bufferName );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
//console.log("wgslType", bindingInfo, wgslType, bufferName);
// === STRUCT SUPPORT ========================================================
if (this.structs && this.structs[wgslType]) {
const structInfo = this.structs[wgslType];
const view = new DataView(copyArrayBuffer);
const resultObj = {};
for (const field of structInfo.fields) {
const off = field.offset;
switch (field.type) {
case "u32":
resultObj[field.name] = view.getUint32(off, true);
break;
case "i32":
resultObj[field.name] = view.getInt32(off, true);
break;
case "f32":
resultObj[field.name] = view.getFloat32(off, true);
break;
default:
console.warn("Unhandled struct field type:", field.type);
resultObj[field.name] = null;
break;
}
}
debugReadBuffer.unmap();
return resultObj;
}
switch ( wgslType ) {
case "array<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "array<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "vec3<f32>":
break;
case "vec4<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "Tensor":
result = new Float32Array(copyArrayBuffer);
break;
case "array<vec4<f32>>":
case "array<Vector>":
case "array<array<f32, 64>>":
case "array<vec3<f32>>":
result = new Float32Array( copyArrayBuffer );
break;
case "vec3<i32>":
case "vec4<i32>":
result = new Int32Array( copyArrayBuffer );
break;
default:
console.error( "Unknown WGSL type in setVariable: " + wgslType, "Using Float32Array as fallback." );
result = new Float32Array( copyArrayBuffer );
}
// use wgslType here to determine appropriate TypedArray
}
const length = result.length;
var regularArray = new Array();
for (var i = 0; i < length; i++) {
regularArray.push(result[i]);
}
/*
console.log("Debugging array");
console.log("");
console.log("Shader: ", this.path);
console.log("Variable Name: ", bufferName);
console.log( "Buffer Size: ", this.getBuffer(bufferName).size );
console.log( "TypedArray Size: ", length );
*/
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
//console.log( "type: ", wgslType );
}
//console.log(regularArray);
debugReadBuffer.unmap();
return regularArray;
}
}