# WebGPU Neural Network Trainer
A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**.
This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser with:
* Forward pass (dense layer 1 + ReLU)
* Forward pass (dense layer 2 → logits)
* Softmax + cross-entropy loss
* Backpropagation for both dense layers
* Weight/bias gradient computation
* SGD optimizer update
* Live accuracy + loss visualization
* Browser-side image preview of predictions

No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.

---
## Features
### ✔ Fully WebGPU-accelerated training
All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels.
### ✔ End-to-end MNIST pipeline
Images and labels are uploaded to GPU buffers and remain on GPU throughout training.
### ✔ Pure WGSL neural network
Six compute shaders:
1. `forward_dense_layer1.wgsl`
2. `forward_dense_layer2.wgsl`
3. `softmax_cross_entropy_backward.wgsl`
4. `gradients_dense_layer2.wgsl`
5. `gradients_dense_layer1.wgsl`
6. `sgd_optimizer_update.wgsl`
### ✔ Pure JS orchestration
A lightweight JS engine (`training2.js`) handles:
* shader setup
* buffer binding
* dispatch sizes
* ordered forward/backward passes
* debug readbacks
* canvas preview of sample predictions
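
For orientation, this is roughly what one compute pass looks like when issued through the raw WebGPU API; in the project these calls go through its own runtime wrapper (`WebGpu.js`), and the buffer and entry-point names below are placeholders, not the project's actual code:

```
// Minimal sketch of a single raw-WebGPU compute dispatch (illustrative only).
const module = device.createShaderModule({ code: wgslSource });
const pipeline = device.createComputePipeline({
  layout: 'auto',
  compute: { module, entryPoint: 'main' },
});
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: [{ binding: 0, resource: { buffer: someStorageBuffer } }],
});

const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(Math.ceil(outputElementCount / 64)); // one invocation per output element
pass.end();
device.queue.submit([encoder.finish()]);
```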
### ✔ Modular WebGPU runtime
The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically:
* parses WGSL bindings
* allocates buffers
* recreates bind groups
* handles resizing
* dispatches compute passes
* reads debug buffers
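
The binding-parsing idea can be sketched as follows (this is not the actual `WebGpu.js` code, only an illustration of the technique): declarations such as `@group(0) @binding(2) var<storage, read_write> weights1 : array<f32>;` can be discovered with a regular expression and turned into buffer and bind-group descriptions automatically.

```
// Illustrative sketch only: scan WGSL source for @group/@binding declarations.
const bindingRe =
  /@group\((\d+)\)\s*@binding\((\d+)\)\s*var<([^>]+)>\s*(\w+)\s*:/g;

function parseBindings(wgslSource) {
  const bindings = [];
  for (const m of wgslSource.matchAll(bindingRe)) {
    bindings.push({
      group: Number(m[1]),
      binding: Number(m[2]),
      space: m[3].trim(), // e.g. "storage, read_write" or "uniform"
      name: m[4],
    });
  }
  return bindings;
}
```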
---
## Live Results
A typical run with batch size **64**:
```
Epoch  0: Loss = 2.3032   Accuracy =  9.40%
Epoch  1: Loss = 2.1380   Accuracy = 26.40%
Epoch  5: Loss = 1.8169   Accuracy = 36.90%
Epoch 10: Loss = 1.6810   Accuracy = 40.60%
Epoch 19: Loss = 1.6004   Accuracy = 41.60%
```
This matches what you'd expect from a small, unoptimized 128-unit MLP trained repeatedly on a single batch, and it confirms that gradients and parameter updates propagate correctly.

---
## Project Structure
```
WebGPU-Neural-Network-Trainer/
├── framework/
│   └── WebGpu.js        # Automated WebGPU shader/buffer manager
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
├── mnist_loader.js      # Loads MNIST images + labels into GPU buffers
├── training.js          # Main training pipeline
├── server.js            # Server for the demo -> visit localhost:2020
└── index.html           # UI + canvas preview + training controls
```
---
## How It Works
### 1. Load MNIST
Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
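
A buffer upload of this kind might look as follows (a minimal sketch; `images`, `labels`, and the buffer names are assumptions, not the actual `mnist_loader.js` code):

```
// Sketch: upload MNIST pixels (Float32Array, 784 floats per image) and
// labels (Uint32Array) into storage buffers. Names are illustrative.
const imageBuffer = device.createBuffer({
  size: images.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(imageBuffer, 0, images);

const labelBuffer = device.createBuffer({
  size: labels.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(labelBuffer, 0, labels);
```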
### 2. Forward Pass
`forward_dense_layer1.wgsl`

* matrix multiply: X·W₁ + b₁
* activation: ReLU

`forward_dense_layer2.wgsl`

* matrix multiply: H·W₂ + b₂
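
In plain JavaScript, the math these two kernels perform is roughly the following (a CPU reference sketch, not the WGSL itself; a flat, row-major array layout is assumed):

```
// CPU reference for the forward pass: hidden = relu(X·W1 + b1), logits = hidden·W2 + b2.
function forward(x, w1, b1, w2, b2, batch, inDim, hidden, numClasses) {
  const h = new Float32Array(batch * hidden);
  const logits = new Float32Array(batch * numClasses);
  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      let sum = b1[j];
      for (let i = 0; i < inDim; i++) sum += x[b * inDim + i] * w1[i * hidden + j];
      h[b * hidden + j] = Math.max(sum, 0); // ReLU
    }
    for (let c = 0; c < numClasses; c++) {
      let sum = b2[c];
      for (let j = 0; j < hidden; j++) sum += h[b * hidden + j] * w2[j * numClasses + c];
      logits[b * numClasses + c] = sum;
    }
  }
  return { h, logits };
}
```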
### 3. Loss + Gradient w.r.t. logits
`softmax_cross_entropy_backward.wgsl`
Produces:
* `gradientLogits[b * numClasses + c]`
* `crossEntropyLoss[b]`
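
Per sample, this fused step produces the cross-entropy loss and dL/dlogits = softmax(logits) - onehot(label). A CPU sketch of that math (not the actual shader; per-sample gradients, with any averaging over the batch assumed to happen elsewhere):

```
// CPU reference for softmax + cross-entropy backward.
function softmaxCrossEntropyBackward(logits, labels, batch, numClasses) {
  const gradLogits = new Float32Array(batch * numClasses);
  const loss = new Float32Array(batch);
  for (let b = 0; b < batch; b++) {
    const row = logits.subarray(b * numClasses, (b + 1) * numClasses);
    const max = Math.max(...row); // subtract max for numerical stability
    const exps = Array.from(row, (v) => Math.exp(v - max));
    const sum = exps.reduce((a, v) => a + v, 0);
    for (let c = 0; c < numClasses; c++) {
      const p = exps[c] / sum;
      gradLogits[b * numClasses + c] = p - (c === labels[b] ? 1 : 0);
      if (c === labels[b]) loss[b] = -Math.log(p);
    }
  }
  return { gradLogits, loss };
}
```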
### 4. Gradients for Dense Layer 2
`gradients_dense_layer2.wgsl`
Computes:
* dW₂
* dB₂
* dHidden (backprop into hidden layer)
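
In matrix terms: dW₂ = Hᵀ·dLogits, dB₂ is the column sum of dLogits, and dHidden = dLogits·W₂ᵀ. A CPU sketch of the same accumulation (not the actual shader; layout matches the forward sketch above):

```
// CPU reference for the layer-2 gradients.
function gradientsLayer2(h, gradLogits, w2, batch, hidden, numClasses) {
  const dW2 = new Float32Array(hidden * numClasses);
  const dB2 = new Float32Array(numClasses);
  const dHidden = new Float32Array(batch * hidden);
  for (let b = 0; b < batch; b++) {
    for (let c = 0; c < numClasses; c++) {
      const g = gradLogits[b * numClasses + c];
      dB2[c] += g;
      for (let j = 0; j < hidden; j++) {
        dW2[j * numClasses + c] += h[b * hidden + j] * g;     // dW2 = Hᵀ·dLogits
        dHidden[b * hidden + j] += g * w2[j * numClasses + c]; // dHidden = dLogits·W2ᵀ
      }
    }
  }
  return { dW2, dB2, dHidden };
}
```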
### 5. Gradients for Dense Layer 1
`gradients_dense_layer1.wgsl`
Computes:
* dW₁
* dB₁
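
Here dHidden is first masked by the ReLU derivative (zero wherever the hidden activation was clipped), then dW₁ = Xᵀ·dHidden and dB₁ is the column sum. A CPU sketch (not the actual shader):

```
// CPU reference for the layer-1 gradients, using the ReLU output h as the mask.
function gradientsLayer1(x, h, dHidden, batch, inDim, hidden) {
  const dW1 = new Float32Array(inDim * hidden);
  const dB1 = new Float32Array(hidden);
  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      const g = h[b * hidden + j] > 0 ? dHidden[b * hidden + j] : 0; // ReLU backward
      dB1[j] += g;
      for (let i = 0; i < inDim; i++) dW1[i * hidden + j] += x[b * inDim + i] * g;
    }
  }
  return { dW1, dB1 };
}
```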
### 6. SGD Update
`sgd_optimizer_update.wgsl`
Updates all weights and biases:
```
w -= learningRate * dw
b -= learningRate * db
```
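
Applied element-wise to every parameter buffer, this amounts to the following (a CPU sketch with an illustrative learning rate, not the actual kernel):

```
// CPU reference for the SGD kernel: element-wise update of one parameter buffer.
function sgdUpdate(params, grads, learningRate) {
  for (let i = 0; i < params.length; i++) {
    params[i] -= learningRate * grads[i];
  }
}

sgdUpdate(w1, dW1, 0.01); // learning-rate value is illustrative
sgdUpdate(b1, dB1, 0.01);
sgdUpdate(w2, dW2, 0.01);
sgdUpdate(b2, dB2, 0.01);
```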
---
## Running the Project
### Requirements
* A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
* Local HTTP server (Chrome blocks file:// GPU access)
### Start a server
Example:
```
node server.js
```
or
```
python3 -m http.server
```
or
```
npx http-server .
```
Then open (the bundled `server.js` listens on port **2020**; adjust the port if you use another server):
```
http://localhost:2020
```
Press **Train** — the network trains in real time, showing:
* Loss
* Accuracy
* Six random preview samples with predicted labels
---
## License
MIT