# WebGPU Neural Network Trainer
A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**.
This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser with:
* Forward pass (dense layer 1 + ReLU)
* Forward pass (dense layer 2 → logits)
* Softmax + cross-entropy loss
* Backpropagation for both dense layers
* Weight/bias gradient computation
* SGD optimizer update
* Live accuracy + loss visualization
* Browser-side image preview of predictions
No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.
---
## Features
### ✔ Fully WebGPU-accelerated training
All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels.
### ✔ End-to-end MNIST pipeline
Images and labels are uploaded to GPU buffers and remain on GPU throughout training.
### ✔ Pure WGSL neural network
Six compute shaders:
1. `forward_dense_layer1.wgsl`
2. `forward_dense_layer2.wgsl`
3. `softmax_cross_entropy_backward.wgsl`
4. `gradients_dense_layer2.wgsl`
5. `gradients_dense_layer1.wgsl`
6. `sgd_optimizer_update.wgsl`
### ✔ Pure JS orchestration
A lightweight JS engine (`training2.js`) handles:
* shader setup
* buffer binding
* dispatch sizes
* ordered forward/backward passes
* debug readbacks
* canvas preview of sample predictions
### ✔ Modular WebGPU runtime
The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically:
* parses WGSL bindings
* allocates buffers
* recreates bind groups
* handles resizing
* dispatches compute passes
* reads debug buffers
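
For reference, the declarations the wrapper inspects look like the short sketch below; the buffer names and binding indices here are made up for illustration. Each `@group`/`@binding` declaration in a shader's source tells the runtime what to allocate and how to assemble the bind group.
```
// Illustrative only: buffer names and indices are not the repo's actual layout.
@group(0) @binding(0) var<storage, read> input : array<f32>;
@group(0) @binding(1) var<storage, read_write> output : array<f32>;
```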
---
## Live Results
A typical run with batch size **64**:
```
Epoch 0: Loss = 2.3032 Accuracy = 9.40%
Epoch 1: Loss = 2.1380 Accuracy = 26.40%
Epoch 5: Loss = 1.8169 Accuracy = 36.90%
Epoch 10: Loss = 1.6810 Accuracy = 40.60%
Epoch 19: Loss = 1.6004 Accuracy = 41.60%
```
This matches what you'd expect from a small, unoptimized 128-unit MLP trained only on a single batch — validating that all gradients and updates propagate correctly.
---
## Project Structure
```
WebGPU-Neural-Network-Trainer/
├── framework/
│   └── WebGpu.js # Automated WebGPU shader/buffer manager
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
├── mnist_loader.js # Loads MNIST images + labels into GPU buffers
├── training.js # Main training pipeline
├── server.js # Server for the demo -> visit localhost:2020
└── index.html # UI + canvas preview + training controls
```
---
## How It Works
### 1. Load MNIST
Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
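
A rough sketch of how the dataset could look from the WGSL side, assuming flattened 28×28 = 784 floats per image and one `u32` label per sample (buffer names and binding indices are illustrative):
```
// Illustrative GPU-side view of the dataset; the real names/bindings live in the project's shaders.
@group(0) @binding(0) var<storage, read> images : array<f32>;  // batchSize × 784, pixel i of sample b at images[b * 784u + i]
@group(0) @binding(1) var<storage, read> labels : array<u32>;  // batchSize class indices in 0..9
```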
### 2. Forward Pass
`forward_dense_layer1.wgsl`
* matrix multiply: X·W₁ + b₁
* activation: ReLU
`forward_dense_layer2.wgsl`
* matrix multiply: H·W₂ + b₂
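
As a rough illustration, a layer-1 kernel of this shape might assign one invocation to each output element `hidden[b][h]`; the buffer names, the `Params` uniform, and the row-major weight layout below are assumptions rather than the project's exact code. The layer-2 kernel has the same structure, with the hidden activations as input, W₂/b₂ as parameters, and no ReLU.
```
// Hypothetical layer-1 forward kernel: one invocation per output element hidden[b][h].
struct Params {
  batchSize : u32,
  inputSize : u32,   // 784 for MNIST
  hiddenSize : u32   // 128 in this project
}

@group(0) @binding(0) var<storage, read> images : array<f32>;        // batchSize × inputSize
@group(0) @binding(1) var<storage, read> weights1 : array<f32>;      // inputSize × hiddenSize, row-major
@group(0) @binding(2) var<storage, read> bias1 : array<f32>;         // hiddenSize
@group(0) @binding(3) var<storage, read_write> hidden : array<f32>;  // batchSize × hiddenSize
@group(0) @binding(4) var<uniform> params : Params;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.batchSize * params.hiddenSize) { return; }
  let b = idx / params.hiddenSize;   // sample index
  let h = idx % params.hiddenSize;   // hidden-unit index
  var acc = bias1[h];
  for (var i = 0u; i < params.inputSize; i = i + 1u) {
    acc = acc + images[b * params.inputSize + i] * weights1[i * params.hiddenSize + h];
  }
  hidden[idx] = max(acc, 0.0);       // ReLU
}
```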
### 3. Loss + Gradient w.r.t. logits
`softmax_cross_entropy_backward.wgsl`
Produces:
* `gradientLogits[b * numClasses + c]`
* `crossEntropyLoss[b]`
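
A minimal sketch of such a kernel, assuming one invocation per sample and illustrative names for the logits and labels buffers (the two outputs keep the names listed above). It uses the usual max-subtraction trick for numerical stability and the fact that the combined softmax + cross-entropy gradient is simply `softmax(logits) - onehot(label)`:
```
const numClasses : u32 = 10u;

@group(0) @binding(0) var<storage, read> logits : array<f32>;                 // batchSize × numClasses
@group(0) @binding(1) var<storage, read> labels : array<u32>;                 // batchSize
@group(0) @binding(2) var<storage, read_write> gradientLogits : array<f32>;   // batchSize × numClasses
@group(0) @binding(3) var<storage, read_write> crossEntropyLoss : array<f32>; // batchSize

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let b = gid.x;
  if (b >= arrayLength(&crossEntropyLoss)) { return; }
  // Numerically stable softmax: subtract the per-sample max before exponentiating.
  var maxLogit = logits[b * numClasses];
  for (var c = 1u; c < numClasses; c = c + 1u) {
    maxLogit = max(maxLogit, logits[b * numClasses + c]);
  }
  var sumExp = 0.0;
  for (var c = 0u; c < numClasses; c = c + 1u) {
    sumExp = sumExp + exp(logits[b * numClasses + c] - maxLogit);
  }
  let label = labels[b];
  for (var c = 0u; c < numClasses; c = c + 1u) {
    let p = exp(logits[b * numClasses + c] - maxLogit) / sumExp;
    // dLoss/dLogit = p - 1 for the true class, p otherwise.
    gradientLogits[b * numClasses + c] = p - select(0.0, 1.0, c == label);
    if (c == label) {
      crossEntropyLoss[b] = -log(p);
    }
  }
}
```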
### 4. Gradients for Dense Layer 2
`gradients_dense_layer2.wgsl`
Computes:
* dW₂
* dB₂
* dHidden (backprop into hidden layer)
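
As one way to picture this, `dW₂` can be computed with one invocation per weight that accumulates `hidden · gradientLogits` over the batch, as in the hedged sketch below (sizes hard-coded to this project's 128 hidden units, 10 classes, and batch size 64). `dB₂` is the same loop summing only `gradientLogits`, and `dHidden` is the product with W₂ transposed; the ReLU derivative is applied to `dHidden` either here or in the layer-1 kernel. Step 5 then repeats this pattern with the input images in place of the hidden activations.
```
// Hypothetical dW₂ kernel: one invocation per weight, accumulating over the batch.
const hiddenSize : u32 = 128u;
const numClasses : u32 = 10u;
const batchSize : u32 = 64u;

@group(0) @binding(0) var<storage, read> hidden : array<f32>;           // batchSize × hiddenSize
@group(0) @binding(1) var<storage, read> gradientLogits : array<f32>;   // batchSize × numClasses
@group(0) @binding(2) var<storage, read_write> dW2 : array<f32>;        // hiddenSize × numClasses

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let idx = gid.x;
  if (idx >= hiddenSize * numClasses) { return; }
  let h = idx / numClasses;
  let c = idx % numClasses;
  var grad = 0.0;
  for (var b = 0u; b < batchSize; b = b + 1u) {
    grad = grad + hidden[b * hiddenSize + h] * gradientLogits[b * numClasses + c];
  }
  dW2[idx] = grad / f32(batchSize);  // batch-mean scaling is an assumption; the real kernel may differ
}
```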
### 5. Gradients for Dense Layer 1
`gradients_dense_layer1.wgsl`
Computes:
* dW₁
* dB₁
### 6. SGD Update
`sgd_optimizer_update.wgsl`
Updates all weights and biases:
```
w -= learningRate * dw
b -= learningRate * db
```
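Because the update is purely element-wise, a single small kernel can serve every weight and bias buffer. A minimal sketch, with a hard-coded learning rate for illustration (the real rate is presumably configured from the training code):
```
const learningRate : f32 = 0.1;  // illustrative value only

@group(0) @binding(0) var<storage, read_write> weights : array<f32>;  // weights or biases
@group(0) @binding(1) var<storage, read> grads : array<f32>;          // matching gradients

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let i = gid.x;
  if (i >= arrayLength(&weights)) { return; }
  weights[i] = weights[i] - learningRate * grads[i];
}
```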
---
## Running the Project
### Requirements
* A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
* A local HTTP server (the demo cannot load its shaders and data from `file://` pages)
### Start a server
Example:
```
node server.js
```
or
```
python3 -m http.server
```
or
```
npx http-server .
```
Then open:
```
http://localhost:2020
```
Press **Train** — the network trains in real time, showing:
* Loss
* Accuracy
* Six random preview samples with predicted labels
---
## Why This Project Matters
This repository demonstrates:
* understanding of GPU buffer layouts
* compute pipeline dispatching
* custom WGSL kernels for ML
* manual backprop and gradient math
* GPU memory management and buffer resizing
* a full deep-learning training loop without any frameworks
It is strong evidence of skill in:
**AI systems engineering**,
**WebGPU compute design**,
**neural-network internals**,
and **performance-oriented GPU development**.
These skills are directly relevant for:
* AI inference runtimes
* GPU acceleration teams
* Web-based model execution
* systems-level optimization roles
---
## License
MIT