/*
 * Files
 * WebGPU-FFT-Ocean-Demo/framework/WebGpu.js
 * 2025-12-31 14:22:45 +01:00
 *
 * 2257 lines
 * 46 KiB
 * JavaScript
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */
// Page-wide registries: every Shader instance and a per-name history of buffer get/set
// calls. Skipped in workers, where `document` does not exist.
if ( typeof document !== "undefined" ) {
  document.shaders = [];
  document.bufferMap = {};
}
export default class Shader {
// Multisample count for render pipelines / render targets (1 = no MSAA).
sampleCount = 1;
device;
// URL of the .wgsl file this shader is loaded from.
path;
wgslSource;
// Single-layout fallback, used only when no per-group layouts exist.
bindGroupLayout;
// @group index -> GPUBindGroupLayout
bindGroupLayouts = new Map();
vertexBuffersDescriptors = new Map();
// binding name -> GPUTextureView
textures = new Map();
pipeline;
// @group index -> GPUBindGroup; set to `false` when the cache is stale.
bindGroups = new Map();
entryPoint = "main";
// Reflected @group/@binding declarations from the WGSL source.
bindings = new Array();
// binding name -> GPUBuffer
buffers = new Map();
bufferUsages = new Map(); // store usage per buffer to recreate correctly
bufferSizes = new Map();
// Reflected WGSL structs keyed by struct name (rebuilt by _parseStructs()).
structs = {};
attributeBuffers ;
_externalBindGroupLayout = null;
hasLoaded = false;
stage = GPUShaderStage.COMPUTE;
isInWorker = typeof document === "undefined";
topologyValue = "triangle-strip";
binded = false;
clearOnRender = true;
samplers = new Map();
constructor( device, path = false ) {
if( !this.isInWorker ) {
document.shaders.push( this );
}
console.log("\n\n \n Creating New Shader", path, "\n\n");
console.log("device", typeof device);
this.device = device;
if( path ) {
this.path = path;
}
}
set topology(value) {
if (this.binded) {
console.error("Define topology before setup", this);
}
this.topologyValue = value;
}
get topology() {
return this.topologyValue;
}
// NOTE: a first, shadowed definition of `_createAllBuffers()` used to live here. In a
// JavaScript class body the last definition of a method name wins, so it never ran; the
// effective `_createAllBuffers()` is the one defined after `_computeBindingSizes()`.
_parseStructs(wgsl) {
this.structs = {};
const structRegex =
/struct\s+(\w+)\s*\{([^}]+)\}/g;
let match;
while ((match = structRegex.exec(wgsl)) !== null) {
const structName = match[1];
const body = match[2].trim();
const fields = [];
const fieldRegex =
/(\w+)\s*:\s*([\w<>\[\]]+)\s*,?/g;
let f;
while ((f = fieldRegex.exec(body)) !== null) {
const fieldName = f[1];
const fieldType = f[2];
fields.push({
name: fieldName,
type: fieldType
});
}
this.structs[structName] = {
name: structName,
fields: fields,
size: 0 // computed later
};
}
for (const name of Object.keys(this.structs)) {
this._computeStructSize(name);
}
console.log("structs reflection created:", this.structs);
}
_computeStructSize(structName) {
const struct = this.structs[structName];
if (!struct) return 0;
let offset = 0;
for (const f of struct.fields) {
const align = 4;
const size = 4;
// align field offset
offset = Math.ceil(offset / align) * align;
f.offset = offset;
f.size = size;
f.align = align;
offset += size;
}
struct.size = offset;
return offset;
}
_computeTypeSize(wgslType) {
if (wgslType === "f32" || wgslType === "i32" || wgslType === "u32") {
return 4;
}
// vecN<T>
const vecMatch = wgslType.match(/vec(\d+)<(\w+)>/);
if (vecMatch) {
const n = Number(vecMatch[1]);
const elemType = vecMatch[2];
return n * this._computeTypeSize(elemType);
}
// matCxR<T> (e.g. mat4x4<f32>)
const matMatch = wgslType.match(/mat(\d+)x(\d+)<(\w+)>/);
if (matMatch) {
const c = Number(matMatch[1]);
const r = Number(matMatch[2]);
const elemType = matMatch[3];
return c * r * this._computeTypeSize(elemType);
}
// array<T, N>
const arrFix = wgslType.match(/array<(\w+),\s*(\d+)>/);
if (arrFix) {
const element = arrFix[1];
const count = Number(arrFix[2]);
return this._computeTypeSize(element) * count;
}
// array<T> runtime sized
const arrDyn = wgslType.match(/array<(\w+)>/);
if (arrDyn) {
return 4;
}
// struct
if (this.structs[wgslType]) {
return this.structs[wgslType].size;
}
return 4;
}
_computeBindingSizes() {
for (const b of this.bindings) {
const t = b.varType;
// -----------------------------------------
// 1. Struct?
// -----------------------------------------
if (this.structs[t]) {
b.size = this.structs[t].size;
continue;
}
// -----------------------------------------
// 4. Texture / sampler (no buffer)
// -----------------------------------------
if (b.type === "texture" || b.type === "sampler") {
b.size = 0;
continue;
}
// -----------------------------------------
// 5. Use computed WGSL type size (scalars, vectors, matrices, arrays)
// -----------------------------------------
b.size = this._computeTypeSize(t);
}
}
_createAllBuffers() {
for (const b of this.bindings) {
if (this.buffers.has(b.varName))
continue;
let size = b.size;
if (size < 4) size = 4;
let usage;
switch (b.type) {
case "uniform":
usage = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
break;
case "read-only-storage":
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
break;
default:
usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
const buffer = this.device.createBuffer({
size,
usage,
label: b.varName
});
this.buffers.set(b.varName, buffer);
this.bufferSizes.set(b.varName, size);
this.bufferUsages.set(b.varName, usage);
}
}
_createDefaultTexturesAndSamplers() {
for ( const b of this.bindings ) {
if ( b.type === "texture" && !this.textures.has( b.varName ) ) {
// Detect if the type is a texture array
const isTextureArray = b.varType.startsWith("texture_2d_array");
// Create 1x1 texture with 1 layer (depthOrArrayLayers = 1)
const defaultTexture = this.device.createTexture({
size: isTextureArray ? [1, 1, 2] : [1, 1, 1],
format: "rgba8unorm",
usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST,
label: b.varName + "_defaultTexture"
});
console.log("create texture ", isTextureArray, isTextureArray ? "2d-array" : "2d");
// Create texture view with correct dimension
const defaultTextureView = defaultTexture.createView({
dimension: isTextureArray ? "2d-array" : "2d"
});
console.log("create default texture?", b, defaultTextureView, isTextureArray ? "2d-array" : "2d");
this.textures.set(b.varName, defaultTextureView);
} else if ( b.type === "sampler" && !this.samplers.has( b.varName ) ) {
// Create default sampler
const defaultSampler = this.device.createSampler( {
magFilter: "linear",
minFilter: "linear",
label: b.varName + "_defaultSampler"
} );
this.samplers.set( b.varName, defaultSampler );
}
}
}
createBindGroups() {
this.bindGroups = new Map();
for ( const [ group, layout ] of this.bindGroupLayouts.entries() ) {
const bindings = this.bindings.filter( b => b.group === group );
const entries = new Array( bindings.length );
for ( let i = 0; i < bindings.length; i++ ) {
const b = bindings[ i ];
let resource;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
const gpuBuffer = this.buffers.get( b.varName );
if ( !gpuBuffer ) {
throw new Error( `Buffer missing for ${b.varName} when creating bind group.` );
}
resource = { buffer: gpuBuffer };
} else if ( b.type === "texture" ) {
const gpuTextureView = this.textures.get( b.varName );
//console.log("this.textures", this.textures);
if ( !gpuTextureView ) {
console.warn( `Texture missing for ${b.varName} when creating bind group.` );
} else {
//console.log( `Binding texture ${b.varName} with dimension:`, gpuTextureView.dimension );
}
resource = gpuTextureView;
} else if ( b.type === "sampler" ) {
const gpuSampler = this.samplers.get( b.varName );
if ( !gpuSampler ) {
//throw new Error( `Sampler missing for ${b.varName} when creating bind group.` );
}
resource = gpuSampler;
} else {
throw new Error( `Unknown binding type '${b.type}' for ${b.varName}` );
}
entries[ i ] = {
binding: b.binding,
resource: resource
};
}
const bindGroup = this.device.createBindGroup( {
layout: layout,
entries: entries
} );
this.bindGroups.set( group, bindGroup );
}
}
extractEntryPoints( wgslCode ) {
const entryPoints = [];
// Regex to match lines like: @vertex fn functionName(...) { ... }
const regex =
/@(\w+)(?:\s*@[\w_]+(?:\([^)]*\))?)*\s*fn\s+([\w_]+)\s*\(/gm;
let match;
while ((match = regex.exec(wgslCode)) !== null) {
const stage = match[1];
const name = match[2];
let type;
switch (stage) {
case "vertex":
type = GPUShaderStage.VERTEX;
break;
case "fragment":
type = GPUShaderStage.FRAGMENT;
break;
case "compute":
type = GPUShaderStage.COMPUTE;
break;
default:
type = "Unknown";
}
entryPoints.push({ name: name, type: type });
}
return entryPoints;
}
async setup( path ) {
this.path = path;
this.wgslSource = await this._loadShaderSource( this.path );
// 2. Parse metadata from WGSL
this._parseVertexAttributes(this.wgslSource);
this._parseBindings(this.wgslSource);
this._parseStructs(this.wgslSource);
// 3. Compute struct sizes BEFORE creating buffers
this._computeBindingSizes(); // if you implemented this
// ----------------------------------------------------------
// 4. CREATE BUFFERS BEFORE EVERYTHING ELSE
// ----------------------------------------------------------
this._createAllBuffers();
//console.log(this.wgslSource);
var entryPoints = this.extractEntryPoints( this.wgslSource );
console.log(entryPoints);
for (var i = 0; i < entryPoints.length; i++) {
var entryPoint = entryPoints[i];
await this.addStage( entryPoint.name, entryPoint.type );
}
//await this.addStage("computeMain", 4 );
//console.log( "load shader", this.wgslSource, entryPoints )
}
async _loadShaderSource( pathName ) {
const response = await fetch( pathName );
if ( !response.ok ) throw new Error( `Failed to load shader: ${ pathName }` );
return await response.text();
}
setVertexBuffer( name, dataArray ) {
const buffer = this.device.createBuffer({
size: dataArray.byteLength,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation: true
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( dataArray );
buffer.unmap();
this.buffers.set( name, buffer );
}
extractVertexFunctionParameters(wgsl) {
let vertexFnIndex = wgsl.indexOf('@vertex');
if (vertexFnIndex === -1) return '';
let fnIndex = wgsl.indexOf('fn', vertexFnIndex);
if (fnIndex === -1) return '';
let parenStart = wgsl.indexOf('(', fnIndex);
if (parenStart === -1) return '';
let parenCount = 1;
let i = parenStart + 1;
while (i < wgsl.length && parenCount > 0) {
const c = wgsl[i];
if (c === '(') parenCount++;
else if (c === ')') parenCount--;
i++;
}
if (parenCount !== 0) {
// unbalanced parentheses
return '';
}
// Extract parameter substring between parenStart+1 and i-1
let paramsStr = wgsl.substring(parenStart + 1, i - 1);
return paramsStr.trim();
}
extractVertexParamsLocations(wgsl) {
const paramsStr = this.extractVertexFunctionParameters(wgsl)
const paramRegex = /@location\s*\(\s*(\d+)\s*\)\s+(\w+)\s*:\s*([\w<>\d]+(\s*<[^>]+>)?)/g;
const results = new Array();
let paramMatch;
while ((paramMatch = paramRegex.exec(paramsStr)) !== null) {
const location = Number(paramMatch[1]);
const varName = paramMatch[2];
const type = paramMatch[3].trim();
results.push({ location, varName, type });
}
return results;
}
getFormatSize( format ) {
switch ( format ) {
case "float32": return 4;
case "float32x2": return 8;
case "float32x3": return 12;
case "float32x4": return 16;
case "uint32": return 4;
case "sint32": return 4;
default: throw new Error("Unsupported format size for: " + format);
}
}
mapTypeToVertexFormat( type ) {
// Map WGSL types to GPUVertexFormat strings (example subset)
switch ( type ) {
case "f32": return "float32";
case "vec2<f32>": return "float32x2";
case "vec3<f32>": return "float32x3";
case "vec4<f32>": return "float32x4";
case "u32": return "uint32";
case "i32": return "sint32";
default: throw new Error("Unsupported vertex attribute type: " + type);
}
}
_parseVertexAttributes( wgsl ) {
const locations = this.extractVertexParamsLocations( wgsl );
if (locations.length === 0) {
// No vertex attributes = no vertex buffer descriptors needed
return;
}
let offset = 0;
const attributes = new Array();
for ( let attr of locations ) {
const format = this.mapTypeToVertexFormat( attr.type );
attributes.push({
shaderLocation : attr.location,
offset : offset,
format : format
});
offset += this.getFormatSize( format );
}
const vertexBufferDescriptor = {
arrayStride: offset,
attributes: attributes,
stepMode: "vertex"
};
console.log("vertexBufferDescriptor", vertexBufferDescriptor);
this.vertexBuffersDescriptors.set("vertexBuffer0", vertexBufferDescriptor);
console.log("vertex buffer descriptors updated:", this.vertexBuffersDescriptors);
}
_parseBindings( wgsl ) {
console.log("parse bindings");
this.bindings = [];
// Make the `<...>` part optional by wrapping it in a non-capturing group with `?`
const regex = /@group\((\d+)\)\s+@binding\((\d+)\)\s+var(?:<(\w+)(?:,\s*(\w+))?>)?\s+(\w+)\s*:\s*([^;]+);/g;
let match;
while ((match = regex.exec(wgsl)) !== null) {
const group = Number(match[1]);
const binding = Number(match[2]);
const storageClass = match[3]; // e.g. "uniform", "storage", or undefined for textures/samplers
const access = match[4]; // e.g. "read", "write", or undefined
const varName = match[5];
const varType = match[6]; // WGSL type string like "mat4x4<f32>", "texture_2d<f32>", "sampler", etc.
let type;
if (storageClass === "storage") {
type = access === "read" ? "read-only-storage" : "storage";
} else if (storageClass === "uniform") {
type = "uniform";
} else if (varType.startsWith("texture_")) {
type = "texture";
} else if (varType === "sampler" || varType === "sampler_comparison") {
type = "sampler";
} else {
type = "uniform"; // fallback
}
this.bindings.push({ group, binding, type, varName, varType });
}
console.log("bindings", this.bindings);
this.bindings.sort((a, b) => a.binding - b.binding);
}
_createBindGroupLayoutsFromBindings() {
const groups = new Map();
for ( let i = 0; i < this.bindings.length; i++ ) {
const b = this.bindings[ i ];
if ( !groups.has( b.group ) ) groups.set( b.group, [] );
groups.get( b.group ).push( b );
}
const layouts = new Map();
for ( const groupEntry of groups.entries() ) {
const group = groupEntry[ 0 ];
const bindings = groupEntry[ 1 ];
let visibility;
if ( this.stage === GPUShaderStage.COMPUTE ) {
visibility = GPUShaderStage.COMPUTE;
} else {
visibility = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
const entries = new Array( bindings.length );
for ( let j = 0; j < bindings.length; j++ ) {
const b = bindings[ j ];
let entry;
if ( b.type === "uniform" || b.type === "read-only-storage" || b.type === "storage" ) {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: b.type }
};
} else if ( b.type === "texture" ) {
let viewDimension = "2d"; // default
if ( b.varType.startsWith("texture_2d_array") ) {
viewDimension = "2d-array";
} else if ( b.varType.startsWith("texture_cube") ) {
viewDimension = "cube";
} else if ( b.varType.startsWith("texture_3d") ) {
viewDimension = "3d";
}
console.log("viewDimension", viewDimension);
entry = {
binding: b.binding,
visibility: visibility,
texture: {
sampleType: "float",
viewDimension: viewDimension,
multisampled: false
}
};
} else if ( b.type === "sampler" ) {
entry = {
binding: b.binding,
visibility: visibility,
sampler: { type: "filtering" }
};
} else {
entry = {
binding: b.binding,
visibility: visibility,
buffer: { type: "uniform" }
};
}
entries[ j ] = entry;
}
console.log("createBindGroupLayout ", entries);
const layout = this.device.createBindGroupLayout( { entries } );
layouts.set( group, layout );
}
return layouts;
}
async _createPipeline() {
const layoutsArray = this.bindGroupLayouts && this.bindGroupLayouts.size > 0
? Array.from( this.bindGroupLayouts.values() )
: [ this.bindGroupLayout ];
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path,
} );
const info = await shaderModule.compilationInfo();
if (info.messages.length > 0) {
for (const msg of info.messages) {
console[msg.type === "error" ? "error" : "warn"](
`[${msg.lineNum}:${msg.linePos}] ${msg.message}`
);
}
}
//console.log("this.bindGroupLayouts.values()", this.bindGroupLayouts.values());
this.pipeline = this.device.createComputePipeline({
layout: this.device.createPipelineLayout({
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
}),
compute: {
module: shaderModule,
entryPoint: this.entryPoint,
},
} );
}
setBindingGroupLayout( bindGroupLayout ) {
this._externalBindGroupLayout = bindGroupLayout;
}
getBindingGroupLayout( bindingGroupIndex = 0 ) {
if ( this.bindGroupLayouts ) {
return this.bindGroupLayouts.get( bindingGroupIndex );
}
return this.bindGroupLayout;
}
getTypedArrayByVariableName( name, flatData ) {
let typedArray;
const bindingInfo = this.bindings.find( b => b.varName === name );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
switch( wgslType ) {
case "vec3<f32>":
// flatData.push( 0 );
break;
}
switch ( wgslType ) {
case "f32":
case "vec2<f32>":
case "vec3<f32>":
case "vec4<f32>":
case "array<f32>":
typedArray = new Float32Array( flatData );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
case "array<u32>":
typedArray = new Uint32Array( flatData );
break;
case "i32":
case "vec2<i32>":
case "vec3<i32>":
case "vec4<i32>":
case "array<i32>":
typedArray = new Int32Array( flatData );
break;
default:
// find struct and then the datatype
typedArray = new Float32Array( flatData );
//console.error( "Unknown WGSL type in setVariable: " + wgslType );
}
//console.log("wgslType", wgslType);
}
return typedArray;
}
createTextureFromData(width, height, data = null) {
// 1. Infer format based on TypedArray
let format;
let bytesPerComponent;
if (data instanceof Float32Array) {
format = "rgba32float";
bytesPerComponent = 4;
}
else if (data instanceof Uint8Array) {
format = "rgba8unorm";
bytesPerComponent = 1;
}
else if (data instanceof Uint32Array) {
format = "rgba32uint";
bytesPerComponent = 4;
}
else if (data instanceof Int32Array) {
format = "rgba32sint";
bytesPerComponent = 4;
}
else {
throw new Error("Unsupported texture data type");
}
// 2. Compute expected size
const channels = 4;
const expected = width * height * channels;
if (data && data.length !== expected) {
throw new Error(
`Texture data length mismatch. Expected ${expected}, got ${data.length}`
);
}
// 3. Create texture
const texture = this.device.createTexture({
size: { width, height, depthOrArrayLayers: 1 },
format,
usage:
GPUTextureUsage.TEXTURE_BINDING |
GPUTextureUsage.STORAGE_BINDING |
GPUTextureUsage.COPY_DST
});
// 4. Upload if data provided
if (data) {
this.device.queue.writeTexture(
{
texture: texture
},
data,
{
bytesPerRow: width * channels * bytesPerComponent
},
{
width,
height,
depthOrArrayLayers: 1
}
);
}
return texture;
}
setVariable( name, dataObject ) {
if (dataObject instanceof GPUBuffer) {
// Direct buffer binding do NOT recreate, resize, or write.
this.setBuffer( name, dataObject );
return;
}
// TEXTURES -----------------------------------------------------
if (dataObject instanceof GPUTexture) {
this.textures.set(name, dataObject.createView());
this._invalidateBindGroups();
return;
}
// SAMPLERS -----------------------------------------------------
if (dataObject instanceof GPUSampler) {
this.samplers.set(name, dataObject);
this._invalidateBindGroups();
return;
}
// STRUCT CHECK -------------------------------------------------
const binding = this.bindings.find(b => b.varName === name);
const structInfo = binding ? this.structs[binding.varType] : undefined;
let typed;
// --------------------------------------------------------------
// 1. WGSL STRUCT → Uint32Array (using struct reflection)
// --------------------------------------------------------------
if (structInfo) {
const out = new Uint32Array(structInfo.size / 4); // aligned struct size in bytes
for (let i = 0; i < structInfo.fields.length; i++) {
const f = structInfo.fields[i];
let v = dataObject[f.name];
if (v === undefined) {
v = 0;
}
// only scalar u32/f32 fields for now
out[f.offset / 4] = v >>> 0;
}
typed = out;
}
// --------------------------------------------------------------
// 2. Already typed arrays
// --------------------------------------------------------------
else if (
dataObject instanceof Float32Array ||
dataObject instanceof Uint32Array ||
dataObject instanceof Int32Array
) {
typed = dataObject;
}
// --------------------------------------------------------------
// 3. Number or JS array → detect type
// --------------------------------------------------------------
else if (typeof dataObject === "number" || Array.isArray(dataObject)) {
const flat = this.flattenStructure(dataObject);
typed = this.getTypedArrayByVariableName(name, flat);
}
// --------------------------------------------------------------
// 4. Raw ArrayBuffer
// --------------------------------------------------------------
else if (dataObject instanceof ArrayBuffer) {
typed = new Float32Array(dataObject);
}
// --------------------------------------------------------------
// 5. Fallback
// --------------------------------------------------------------
else {
const flat = this.flattenStructure(dataObject);
typed = new Float32Array(flat);
}
// BUFFER CREATION / RESIZE -------------------------------------
const bytes = typed.byteLength;
const oldBuffer = this.buffers.get(name);
const oldUsage = this.bufferUsages.get(name);
const oldSize = oldBuffer ? oldBuffer.size : 0;
if (typeof oldUsage !== "number") {
throw new Error("Invalid buffer usage for '" + name + "'");
}
let buffer = oldBuffer;
// Resize if needed
if (!buffer || oldSize !== bytes) {
buffer = this.device.createBuffer({
label: name,
size: bytes,
usage: oldUsage,
mappedAtCreation: false
});
this.buffers.set(name, buffer);
this.bufferSizes.set(name, bytes);
this._invalidateBindGroups(); // <— IMPORTANT FIX
}
// BUFFER UPLOAD -------------------------------------------------
this.device.queue.writeBuffer(
buffer,
0,
typed.buffer,
typed.byteOffset,
typed.byteLength
);
}
_invalidateBindGroups() {
this.bindGroups = false;
}
computeUsage(name) {
const binding = this.bindings.find(b => b.varName === name);
if (!binding) {
// No binding found → use safe fallback
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
switch (binding.type) {
case "uniform":
return GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
case "read-only-storage":
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ;
case "storage":
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
default:
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
}
}
flattenStructure( input, output = [] ) {
if ( Array.isArray( input ) ) {
for ( const element of input ) {
this.flattenStructure( element, output );
}
} else if ( typeof input === "object" && input !== null ) {
for ( const key of Object.keys( input ) ) {
this.flattenStructure( input[ key ], output );
}
} else if ( typeof input === "number" ) {
output.push( input );
} else if ( typeof input === "boolean" ) {
output.push( input ? 1 : 0 );
} else {
throw new Error( "Unsupported data type in shader variable: " + typeof input );
}
return output;
}
copyBuffersFromShader( sourceShader ) {
for ( const [ name, buffer ] of sourceShader.buffers.entries() ) {
this.buffers.set( name, buffer );
}
}
updateUniforms( data ) {
const arrayBuffer = new ArrayBuffer(12); // 3 x u32 = 3 * 4 bytes
const dataView = new DataView(arrayBuffer);
dataView.setUint32(0, data.k, true); // offset 0
dataView.setUint32(4, data.j, true); // offset 4
dataView.setUint32(8, data.totalCount, true); // offset 8
this.device.queue.writeBuffer(
this.uniformBuffer, // your GPUBuffer holding uniforms
0, // offset in buffer
arrayBuffer
);
}
createCommandEncoder( x, y, z ) {
// AUTO-REBUILD BIND GROUPS IF NEEDED
if (!this.bindGroups || this.bindGroups.size === 0) {
this.createBindGroups();
}
const commandEncoder = this.device.createCommandEncoder({
label: this.path
});
const passEncoder = commandEncoder.beginComputePass({
label: this.path
});
if( !this.pipeline ) {
console.log("Pipeline missing, Maybe you forget to call addStage( ) ", this);
}
passEncoder.setPipeline( this.pipeline );
for ( const [bindingGroupIndex, bindGroup] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
passEncoder.dispatchWorkgroups( x, y, z );
passEncoder.end();
return commandEncoder;
}
async execute( x = 1, y = 1, z = 1 ) {
const commandEncoder = this.createCommandEncoder( x, y, z );
const commandBuffer = commandEncoder.finish();
this.device.queue.submit( [ commandBuffer ] );
await this.device.queue.onSubmittedWorkDone(); // critical!
}
registerBuffer( buffer, name, type ) {
if( !this.isInWorker ) {
if( !document.bufferMap[name] ) {
document.bufferMap[name] = new Array();
}
var bufferHistoryPoint = { buffer: buffer, shader: this, label : buffer.label, type: type };
document.bufferMap[name].push( bufferHistoryPoint );
}
}
getBuffer( name ) {
// If no bind groups yet, create them now
if (!this.bindGroups || this.bindGroups.size === 0) {
this.createBindGroups();
}
var buffer = this.buffers.get( name );
this.registerBuffer( buffer, name, "get" );
return buffer;
}
async setBuffer( name, buffer ) {
this.registerBuffer( buffer, name, "set" );
if( !this.isInWorker ) {
//var bufferOwner = document.bufferOwners.get( buffer );
//console.log("setBuffer Origin:", bufferOwner, " this shader", this);
}
const bindingInfo = this.bindings.find( b => b.varName === name );
if( !bindingInfo ) {
console.error("No binding to buffer ", name, " in shader ", this );
}
//console.log("set buffer", name, buffer);
this.buffers.set(name, buffer);
await this.createBindGroups(); // always refresh bindings
}
async addStage( entryPoint, stage ) {
if ( !this.path ) {
throw new Error("Shader path is not set");
}
this.entryPoint = entryPoint;
if( stage != GPUShaderStage.COMPUTE ) {
this.stage = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT;
}
if ( !this.binded ) {
if( stage != GPUShaderStage.COMPUTE ) {
if( this.sampleCount == 1 ) {
this.depthTexture = this.device.createTexture({
size: {
width: this.canvas.width,
height: this.canvas.height,
depthOrArrayLayers: 1
},
sampleCount: 1, // Must be 1 for non-multisampled rendering
format: "depth24plus", // Or "depth24plus-stencil8" if stencil is needed
usage: GPUTextureUsage.RENDER_ATTACHMENT
});
} else {
this.multisampledColorTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount,
format: "bgra8unorm",
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
this.multisampledDepthTexture = this.device.createTexture({
size: [this.canvas.width, this.canvas.height],
sampleCount: this.sampleCount, // must match multisample count of color texture and pipeline
format: "depth24plus", // or "depth24plus-stencil8" if you need stencil
usage: GPUTextureUsage.RENDER_ATTACHMENT,
});
}
}
this.wgslSource = await this._loadShaderSource( this.path );
//this._parseVertexAttributes( this.wgslSource );
//this._parseBindings( this.wgslSource );
//this._parseStructs( this.wgslSource );
if ( this._externalBindGroupLayout ) {
this.bindGroupLayouts.set( 0, this._externalBindGroupLayout );
} else {
this.bindGroupLayouts = this._createBindGroupLayoutsFromBindings();
}
this._createDefaultTexturesAndSamplers();
await this.createBindGroups();
this.binded = true;
}
const shaderModule = this.device.createShaderModule( {
code: this.wgslSource,
label: this.path
} );
if ( stage === GPUShaderStage.COMPUTE ) {
//console.log("this.bindGroupLayouts", this.bindGroupLayouts);
//console.log("create compute shader pipeline.");
this.computeStage = { module: shaderModule, entryPoint: this.entryPoint };
this.pipeline = this.device.createComputePipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
compute: this.computeStage
} );
await this.finalizeShader();
} else if ( stage === GPUShaderStage.VERTEX ) {
//this.vertexStage = { module: shaderModule, entryPoint: this.entryPoint };
console.log("Vertex buffers descriptors:", Array.from(this.vertexBuffersDescriptors.values()));
const vertexBuffers = this.vertexBuffersDescriptors.size > 0
? [...this.vertexBuffersDescriptors.values()]
: [];
this.vertexStage = {
module: shaderModule,
entryPoint: this.entryPoint,
buffers: vertexBuffers
};
} else if ( stage === GPUShaderStage.FRAGMENT ) {
this.fragmentStage = { module: shaderModule, entryPoint: this.entryPoint };
}
// If both vertex and fragment stages exist, create render pipeline
if ( this.vertexStage && this.fragmentStage ) {
this._createAllBuffers();
console.log( this.device );
console.log("create fragment and vertex shader pipeline.");
console.log("setup createRenderPipeline", this.vertexStage);
console.log("this.device.createPipelineLayout", this.bindGroupLayouts.get(0) );
this.pipeline = this.device.createRenderPipeline( {
layout: this.device.createPipelineLayout( {
bindGroupLayouts: Array.from( this.bindGroupLayouts.values() )
} ),
vertex: this.vertexStage,
fragment: {
...this.fragmentStage,
targets: [
{
format: "bgra8unorm",
// No blending here for opaque rendering
blend: undefined,
writeMask: GPUColorWrite.ALL
}
] // Specify the color target format here
},
multisample: {
count: this.sampleCount, // number of samples per pixel, typical values: 4 or 8
},
primitive: {
topology: this.topologyValue,
frontFace: 'ccw' ,
cullMode: 'none',
},
depthStencil: {
format: "depth24plus", // <-- Add this to match render pass depthStencil
depthWriteEnabled: true,
depthCompare: "less",
},
} );
this._createAllBuffers();
await this.finalizeShader();
}
}
convertToTypedArrayWithPadding(array, preferredType = null) {
const isTypedArray = (
array instanceof Uint16Array ||
array instanceof Uint32Array ||
array instanceof Float32Array
);
let typedArray;
let arrayType = preferredType;
if (isTypedArray) {
if (preferredType === null) {
if (array instanceof Uint16Array) arrayType = "uint16";
else if (array instanceof Uint32Array) arrayType = "uint32";
else if (array instanceof Float32Array) arrayType = "float32";
}
const remainder = array.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = array.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(array);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
} else {
typedArray = array;
}
} else if (Array.isArray(array)) {
if (preferredType === null) {
// Default to Float32Array if no preferred type
arrayType = "float32";
} else {
arrayType = preferredType;
}
if (arrayType === "uint16") {
typedArray = new Uint16Array(array);
} else if (arrayType === "uint32") {
typedArray = new Uint32Array(array);
} else if (arrayType === "float32") {
typedArray = new Float32Array(array);
} else {
throw new Error("Unsupported preferred type");
}
const remainder = typedArray.byteLength % 4;
if (remainder !== 0) {
const elementSize = (arrayType === "uint16") ? 2 : 4;
const paddedByteLength = typedArray.byteLength + (4 - remainder);
const paddedElementCount = paddedByteLength / elementSize;
if (arrayType === "uint16") {
const paddedArray = new Uint16Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "uint32") {
const paddedArray = new Uint32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else if (arrayType === "float32") {
const paddedArray = new Float32Array(paddedElementCount);
paddedArray.set(typedArray);
typedArray = paddedArray;
} else {
throw new Error("Unsupported typed array type for padding");
}
}
} else {
throw new Error("Input must be a TypedArray or Array of numbers");
}
return { typedArray, arrayType };
}
setIndices( indices ) {
console.log("indices", indices);
const isTypedArray = indices instanceof Uint16Array || indices instanceof Uint32Array;
let typedArray;
let arrayType;
console.log("isTypedArray", isTypedArray);
if ( isTypedArray ) {
arrayType = (indices instanceof Uint16Array) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding( indices, arrayType ) );
} else if (Array.isArray(indices)) {
let maxIndex = 0;
for (let i = 0; i < indices.length; i++) {
if (indices[i] > maxIndex) maxIndex = indices[i];
}
arrayType = (maxIndex < 65536) ? "uint16" : "uint32";
({ typedArray } = this.convertToTypedArrayWithPadding(indices, arrayType));
} else {
throw new Error("Indices must be Uint16Array, Uint32Array, or Array of numbers");
}
this.indexFormat = arrayType;
this.indexBufferData = typedArray;
if (this.indexBuffer) {
this.indexBuffer.destroy();
}
this.indexBuffer = this.device.createBuffer({
size: this.indexBufferData.byteLength,
usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST,
});
this.device.queue.writeBuffer(
this.indexBuffer,
0,
this.indexBufferData.buffer,
this.indexBufferData.byteOffset,
this.indexBufferData.byteLength
);
this.indexCount = this.indexBufferData.length;
}
setAttribute( name, array ) {
if ( this.attributeArrays === undefined ) {
this.attributeArrays = new Map();
}
const { typedArray, arrayType } = this.convertToTypedArrayWithPadding( array );
let attributeSize = 3; // Default size (vec3)
// this is not proper
if ( name === "uv" ) {
attributeSize = 2;
} else if ( name === "position" || name === "normal" ) {
attributeSize = 3;
} else {
// You can add more attribute names and sizes here
}
this.attributeArrays.set( name, { typedArray: typedArray, size: attributeSize } );
}
getAllAttributeBuffers() {
if ( !this.attributeBuffers ) {
return new Map();
}
return this.attributeBuffers;
}
vertexBuffers( passEncoder ) {
if ( this.attributeArrays === undefined ) {
console.warn("No attribute arrays to bind.");
return;
}
if ( this.mergedAttributeBuffer === undefined ) {
const attributeNames = Array.from( this.attributeArrays.keys() );
const arrays = attributeNames.map( name => this.attributeArrays.get( name ).typedArray );
const attributeSizes = attributeNames.map( name => (name === "uv") ? 2 : 3 );
const vertexCount = arrays[0].length / attributeSizes[0];
const strideFloats = attributeSizes.reduce( (sum, size) => sum + size, 0 );
const mergedArray = new Float32Array( vertexCount * strideFloats );
for ( let i = 0; i < vertexCount; i++ ) {
let offset = i * strideFloats;
for ( let attrIndex = 0; attrIndex < arrays.length; attrIndex++ ) {
const arr = arrays[attrIndex];
const size = attributeSizes[attrIndex];
const baseIndex = i * size;
for ( let c = 0; c < size; c++ ) {
mergedArray[ offset++ ] = arr[ baseIndex + c ];
}
}
}
const byteLength = mergedArray.byteLength;
const buffer = this.device.createBuffer({
size : byteLength,
usage : GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
mappedAtCreation : true,
label : "mergedAttributes"
});
const mapping = new Float32Array( buffer.getMappedRange() );
mapping.set( mergedArray );
buffer.unmap();
this.device.queue.writeBuffer(
buffer,
0,
mergedArray.buffer,
mergedArray.byteOffset,
mergedArray.byteLength
);
this.mergedAttributeBuffer = buffer;
// Store for pipeline setup
this.attributeStrideFloats = strideFloats;
this.attributeSizes = attributeSizes;
this.attributeNames = attributeNames;
}
// Bind the merged buffer at slot 0
passEncoder.setVertexBuffer( 0, this.mergedAttributeBuffer );
}
setCanvas( canvas ) {
this.canvas = canvas;
}
encodeRender( passEncoder, x, y = 0, z = 0 ) {
if ( !this.bindGroups || this.bindGroups.size === 0 ) {
this.createBindGroups();
}
passEncoder.setPipeline( this.pipeline );
this.vertexBuffers( passEncoder );
for ( const [ bindingGroupIndex, bindGroup ] of this.bindGroups.entries() ) {
passEncoder.setBindGroup( bindingGroupIndex, bindGroup );
}
if ( this.indexBuffer ) {
passEncoder.setIndexBuffer( this.indexBuffer, this.indexFormat );
passEncoder.drawIndexed( this.indexCount, y, z );
} else {
passEncoder.draw( x, y, z );
}
}
async renderToCanvas( x, y = 0, z = 0 ) {
// AUTO-REBUILD BIND GROUPS IF NEEDED
if ( !this.bindGroups || this.bindGroups.size === 0 ) {
this.createBindGroups();
}
const canvas = this.canvas;
const context = canvas.getContext( "webgpu" );
const format = navigator.gpu.getPreferredCanvasFormat();
const depthTexture = this.device.createTexture( {
size: [ this.canvas.width, this.canvas.height, 1 ],
sampleCount: this.sampleCount,
format: "depth24plus",
usage: GPUTextureUsage.RENDER_ATTACHMENT
} );
const commandEncoder = this.device.createCommandEncoder();
let renderPassDescriptor;
const colorLoadOp = this.clearOnRender ? "clear" : "load";
const depthLoadOp = this.clearOnRender ? "clear" : "load";
const clearColor = { r: 0.3, g: 0.3, b: 0.3, a: 1 };
if ( this.sampleCount === 1 ) {
renderPassDescriptor = {
colorAttachments: [ {
view: context.getCurrentTexture().createView(),
loadOp: colorLoadOp,
storeOp: "store",
clearValue: clearColor
} ],
depthStencilAttachment: {
view: this.depthTexture.createView(),
depthLoadOp: depthLoadOp,
depthStoreOp: "store",
depthClearValue: 1.0
}
};
} else {
renderPassDescriptor = {
colorAttachments: [ {
view: this.multisampledColorTexture.createView(),
resolveTarget: context.getCurrentTexture().createView(),
loadOp: colorLoadOp,
storeOp: "store",
clearValue: { r: 0, g: 0, b: 0, a: 1 }
} ],
depthStencilAttachment: {
view: this.multisampledDepthTexture.createView(),
depthLoadOp: depthLoadOp,
depthStoreOp: "store",
depthClearValue: 1.0
}
};
}
const passEncoder = commandEncoder.beginRenderPass( renderPassDescriptor );
this.encodeRender( passEncoder, x, y, z );
passEncoder.end();
this.device.queue.submit( [ commandEncoder.finish() ] );
}
async finalizeShader() {
//await this._createPipeline();
await this.createBindGroups();
}
async debugBuffer( bufferName, bufferType ) {
const buffer = this.getBuffer( bufferName );
const dataSize = buffer.size;
const debugReadBuffer = this.device.createBuffer( {
size: dataSize,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
} );
const debugEncoder = this.device.createCommandEncoder();
//console.log( "print buffer", buffer );
debugEncoder.copyBufferToBuffer(
buffer,
0,
debugReadBuffer,
0,
dataSize
);
this.device.queue.submit( [ debugEncoder.finish() ] );
await this.device.queue.onSubmittedWorkDone();
await debugReadBuffer.mapAsync( GPUMapMode.READ );
const copyArrayBuffer = debugReadBuffer.getMappedRange();
let result;
const bindingInfo = this.bindings.find( b => b.varName === bufferName );
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
//console.log("wgslType", bindingInfo, wgslType, bufferName);
// === STRUCT SUPPORT ========================================================
if (this.structs && this.structs[wgslType]) {
const structInfo = this.structs[wgslType];
const view = new DataView(copyArrayBuffer);
const resultObj = {};
for (const field of structInfo.fields) {
const off = field.offset;
switch (field.type) {
case "u32":
resultObj[field.name] = view.getUint32(off, true);
break;
case "i32":
resultObj[field.name] = view.getInt32(off, true);
break;
case "f32":
resultObj[field.name] = view.getFloat32(off, true);
break;
default:
console.warn("Unhandled struct field type:", field.type);
resultObj[field.name] = null;
break;
}
}
debugReadBuffer.unmap();
return resultObj;
}
switch ( wgslType ) {
case "array<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "array<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "vec3<f32>":
break;
case "vec4<f32>":
result = new Float32Array( copyArrayBuffer );
break;
case "u32":
case "vec2<u32>":
case "vec3<u32>":
case "vec4<u32>":
result = new Uint32Array( copyArrayBuffer );
break;
case "Tensor":
result = new Float32Array(copyArrayBuffer);
break;
case "array<vec4<f32>>":
case "array<Vector>":
case "array<array<f32, 64>>":
case "array<vec3<f32>>":
result = new Float32Array( copyArrayBuffer );
break;
case "vec3<i32>":
case "vec4<i32>":
result = new Int32Array( copyArrayBuffer );
break;
default:
console.error( "Unknown WGSL type in setVariable: " + wgslType, "Using Float32Array as fallback." );
result = new Float32Array( copyArrayBuffer );
}
// use wgslType here to determine appropriate TypedArray
}
const length = result.length;
var regularArray = new Array();
for (var i = 0; i < length; i++) {
regularArray.push(result[i]);
}
/*
console.log("Debugging array");
console.log("");
console.log("Shader: ", this.path);
console.log("Variable Name: ", bufferName);
console.log( "Buffer Size: ", this.getBuffer(bufferName).size );
console.log( "TypedArray Size: ", length );
*/
if ( bindingInfo ) {
const wgslType = bindingInfo.varType;
//console.log( "type: ", wgslType );
}
//console.log(regularArray);
debugReadBuffer.unmap();
return regularArray;
}
}