WebGPU Neural Network Trainer
A fully GPU-accelerated neural-network training pipeline implemented in WebGPU using custom WGSL compute shaders. This project trains a 2-layer MLP classifier on MNIST entirely inside the browser with:
- Forward pass (dense layer 1 + ReLU)
- Forward pass (dense layer 2 → logits)
- Softmax + cross-entropy loss
- Backpropagation for both dense layers
- Weight/bias gradient computation
- SGD optimizer update
- Live accuracy + loss visualization
- Browser-side image preview of predictions
No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.
Features
✔ Fully WebGPU-accelerated training
All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels.
✔ End-to-end MNIST pipeline
Images and labels are uploaded to GPU buffers and remain on GPU throughout training.
✔ Pure WGSL neural network
Six compute shaders:
- forward_dense_layer1.wgsl
- forward_dense_layer2.wgsl
- softmax_cross_entropy_backward.wgsl
- gradients_dense_layer2.wgsl
- gradients_dense_layer1.wgsl
- sgd_optimizer_update.wgsl
✔ Pure JS orchestration
A lightweight JS engine (training.js) handles:
- shader setup
- buffer binding
- dispatch sizes
- ordered forward/backward passes
- debug readbacks
- canvas preview of sample predictions
✔ Modular WebGPU runtime
The training pipeline uses a generic WebGPU shader wrapper (WebGpu.js) that automatically:
- parses WGSL bindings
- allocates buffers
- recreates bind groups
- handles resizing
- dispatches compute passes
- reads debug buffers
Live Results
A typical run with batch size 64:
Epoch 0: Loss = 2.3032 Accuracy = 9.40%
Epoch 1: Loss = 2.1380 Accuracy = 26.40%
Epoch 5: Loss = 1.8169 Accuracy = 36.90%
Epoch 10: Loss = 1.6810 Accuracy = 40.60%
Epoch 19: Loss = 1.6004 Accuracy = 41.60%
This matches what you'd expect from a small, unoptimized 128-unit MLP trained only on a single batch — validating that all gradients and updates propagate correctly.
Project Structure
WebGPU-Neural-Network-Trainer/
│
├── framework/
│ └── WebGpu.js # Automated WebGPU shader/buffer manager
│
├── shaders/neural/
│ ├── forward_dense_layer1.wgsl
│ ├── forward_dense_layer2.wgsl
│ ├── softmax_cross_entropy_backward.wgsl
│ ├── gradients_dense_layer2.wgsl
│ ├── gradients_dense_layer1.wgsl
│ └── sgd_optimizer_update.wgsl
│
├── mnist_loader.js # Loads MNIST images + labels into GPU buffers
│
├── training.js # Main training pipeline
│
├── server.js # Server for the demo -> visit localhost:2020
│
└── index.html # UI + canvas preview + training controls
How It Works
1. Load MNIST
Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
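For reference, a batch laid out this way might be declared in WGSL as in the sketch below; the binding indices and buffer names are illustrative assumptions, not the project's exact declarations.

```wgsl
// Hypothetical binding layout for one MNIST batch (names/bindings are illustrative).
@group(0) @binding(0) var<storage, read> images : array<f32>; // batchSize * 784 pixels, row-major, values in [0, 1]
@group(0) @binding(1) var<storage, read> labels : array<u32>; // batchSize class indices in 0..9
```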
2. Forward Pass
forward_dense_layer1.wgsl
- matrix multiply: X·W₁ + b₁
- activation: ReLU
forward_dense_layer2.wgsl
- matrix multiply: H·W₂ + b₂
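A rough sketch of a dense-layer forward kernel of this shape, with one invocation per output element (the Params struct, buffer names, and bindings are assumptions for illustration, not the project's actual forward_dense_layer1.wgsl):

```wgsl
struct Params {
  batchSize : u32,
  inputSize : u32,
  hiddenSize : u32,
}

@group(0) @binding(0) var<uniform> params : Params;
@group(0) @binding(1) var<storage, read> inputs  : array<f32>;       // batchSize x inputSize
@group(0) @binding(2) var<storage, read> weights : array<f32>;       // inputSize x hiddenSize
@group(0) @binding(3) var<storage, read> biases  : array<f32>;       // hiddenSize
@group(0) @binding(4) var<storage, read_write> hidden : array<f32>;  // batchSize x hiddenSize

// hidden[b, j] = relu(sum_k inputs[b, k] * weights[k, j] + biases[j])
// Dispatch with ceil(batchSize * hiddenSize / 64) workgroups.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.batchSize * params.hiddenSize) { return; }
  let b = idx / params.hiddenSize;
  let j = idx % params.hiddenSize;

  var sum = biases[j];
  for (var k = 0u; k < params.inputSize; k = k + 1u) {
    sum = sum + inputs[b * params.inputSize + k] * weights[k * params.hiddenSize + j];
  }
  hidden[idx] = max(sum, 0.0); // ReLU
}
```

The second layer follows the same pattern without the ReLU, writing raw logits of shape batchSize × numClasses.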
3. Loss + Gradient w.r.t. logits
softmax_cross_entropy_backward.wgsl
Produces:
- gradientLogits[b * numClasses + c]
- crossEntropyLoss[b]
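This step uses the standard identity that the gradient of the cross-entropy loss with respect to each logit is softmax(logits) minus the one-hot target. A per-sample sketch (buffer names, bindings, and the Params struct are assumptions):

```wgsl
struct Params {
  batchSize : u32,
  numClasses : u32,
}

@group(0) @binding(0) var<uniform> params : Params;
@group(0) @binding(1) var<storage, read> logits : array<f32>;                 // batchSize x numClasses
@group(0) @binding(2) var<storage, read> labels : array<u32>;                 // batchSize
@group(0) @binding(3) var<storage, read_write> gradientLogits : array<f32>;   // batchSize x numClasses
@group(0) @binding(4) var<storage, read_write> crossEntropyLoss : array<f32>; // batchSize

// One invocation per sample b.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let b = gid.x;
  if (b >= params.batchSize) { return; }
  let base = b * params.numClasses;

  // Numerically stable softmax: subtract the row max before exponentiating.
  var maxLogit = logits[base];
  for (var c = 1u; c < params.numClasses; c = c + 1u) {
    maxLogit = max(maxLogit, logits[base + c]);
  }
  var sumExp = 0.0;
  for (var c = 0u; c < params.numClasses; c = c + 1u) {
    sumExp = sumExp + exp(logits[base + c] - maxLogit);
  }

  let trueClass = labels[b];
  for (var c = 0u; c < params.numClasses; c = c + 1u) {
    let p = exp(logits[base + c] - maxLogit) / sumExp;
    let oneHot = select(0.0, 1.0, c == trueClass);
    gradientLogits[base + c] = p - oneHot; // dL/dlogit for this sample
    if (c == trueClass) {
      crossEntropyLoss[b] = -log(max(p, 1e-7)); // clamp to avoid log(0)
    }
  }
}
```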
4. Gradients for Dense Layer 2
gradients_dense_layer2.wgsl
Computes:
- dW₂
- dB₂
- dHidden (backprop into hidden layer)
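As a sketch, the W₂ gradient can be computed with one invocation per weight element, reducing over the batch (names and bindings below are illustrative assumptions):

```wgsl
struct Params {
  batchSize : u32,
  hiddenSize : u32,
  numClasses : u32,
}

@group(0) @binding(0) var<uniform> params : Params;
@group(0) @binding(1) var<storage, read> hidden : array<f32>;          // batchSize x hiddenSize
@group(0) @binding(2) var<storage, read> gradientLogits : array<f32>;  // batchSize x numClasses
@group(0) @binding(3) var<storage, read_write> dW2 : array<f32>;       // hiddenSize x numClasses

// dW2[j, c] = sum over the batch of hidden[b, j] * gradientLogits[b, c]
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.hiddenSize * params.numClasses) { return; }
  let j = idx / params.numClasses;
  let c = idx % params.numClasses;

  var sum = 0.0;
  for (var b = 0u; b < params.batchSize; b = b + 1u) {
    sum = sum + hidden[b * params.hiddenSize + j] * gradientLogits[b * params.numClasses + c];
  }
  dW2[idx] = sum;
}
```

dB₂[c] is the same reduction without the hidden factor, and dHidden[b, j] = Σ_c gradientLogits[b, c] · W₂[j, c].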
5. Gradients for Dense Layer 1
gradients_dense_layer1.wgsl
Computes:
- dW₁
- dB₁
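The new ingredient at layer 1 is the ReLU derivative: gradient flows back only where the forward activation was positive. A sketch of the bias-gradient reduction with that mask applied (names and bindings are assumptions):

```wgsl
struct Params {
  batchSize : u32,
  hiddenSize : u32,
}

@group(0) @binding(0) var<uniform> params : Params;
@group(0) @binding(1) var<storage, read> hidden  : array<f32>;   // post-ReLU activations, batchSize x hiddenSize
@group(0) @binding(2) var<storage, read> dHidden : array<f32>;   // gradient arriving from layer 2
@group(0) @binding(3) var<storage, read_write> dB1 : array<f32>; // hiddenSize

// dB1[j] = sum over the batch of dHidden[b, j], masked by the ReLU derivative.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let j = gid.x;
  if (j >= params.hiddenSize) { return; }

  var sum = 0.0;
  for (var b = 0u; b < params.batchSize; b = b + 1u) {
    let idx = b * params.hiddenSize + j;
    let g = select(0.0, dHidden[idx], hidden[idx] > 0.0); // ReLU mask
    sum = sum + g;
  }
  dB1[j] = sum;
}
```

dW₁ applies the same masked gradient, multiplied by the corresponding input pixel and reduced over the batch.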
6. SGD Update
sgd_optimizer_update.wgsl
Updates all weights and biases:
w -= learningRate * dw
b -= learningRate * db
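A minimal sketch of such an update kernel (binding layout and parameter names are assumptions); the same kernel can be dispatched once per weight or bias buffer, with paramCount set to that buffer's length:

```wgsl
struct Params {
  learningRate : f32,
  paramCount : u32,
}

@group(0) @binding(0) var<uniform> params : Params;
@group(0) @binding(1) var<storage, read_write> parameters : array<f32>; // weights or biases
@group(0) @binding(2) var<storage, read> gradients : array<f32>;        // matching gradient buffer

// Plain SGD: one invocation per parameter, p <- p - lr * dp.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let i = gid.x;
  if (i >= params.paramCount) { return; }
  parameters[i] = parameters[i] - params.learningRate * gradients[i];
}
```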
Running the Project
Requirements
- A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
- A local HTTP server (the demo will not run when opened directly from a file:// URL)
Start a server
Example:
node server.js
or
python3 -m http.server
or
npx http-server .
Then open the URL your server reports. With node server.js that is:
http://localhost:2020
Press Train — the network trains in real time, showing:
- Loss
- Accuracy
- Six random preview samples with predicted labels
License
MIT