# WebGPU Neural Network Trainer

A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**.

This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser, with:

* Forward pass (dense layer 1 + ReLU)
* Forward pass (dense layer 2 → logits)
* Softmax + cross-entropy loss
* Backpropagation for both dense layers
* Weight/bias gradient computation
* SGD optimizer update
* Live accuracy + loss visualization
* Browser-side image preview of predictions

No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.

---

## Features

### ✔ Fully WebGPU-accelerated training

All computation (forward pass, backward pass, loss, optimizer) is implemented as WGSL compute kernels.

### ✔ End-to-end MNIST pipeline

Images and labels are uploaded to GPU buffers and remain on the GPU throughout training.

### ✔ Pure WGSL neural network

Six compute shaders:

1. `forward_dense_layer1.wgsl`
2. `forward_dense_layer2.wgsl`
3. `softmax_cross_entropy_backward.wgsl`
4. `gradients_dense_layer2.wgsl`
5. `gradients_dense_layer1.wgsl`
6. `sgd_optimizer_update.wgsl`

### ✔ Pure JS orchestration

A lightweight JS engine (`training.js`) handles all of the following (a sketch of one training step appears after the list):

* shader setup
* buffer binding
* dispatch sizes
* ordered forward/backward passes
* debug readbacks
* canvas preview of sample predictions
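
Below is a minimal sketch of what one training step can look like with the raw WebGPU API. The kernel descriptors (`pipeline`, `bindGroup`, `invocations`) are placeholders for objects the engine sets up elsewhere; the actual wiring in `training.js` may differ.

```js
// Illustrative sketch: encode one forward/backward/update step.
// `device`, the compute pipelines, and the bind groups are assumed
// to have been created during setup (names here are hypothetical).
function encodeTrainingStep(device, kernels, batchSize) {
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();

  // Kernels run in dependency order: forward, loss/backward, gradients, SGD update.
  for (const k of kernels) {
    pass.setPipeline(k.pipeline);      // GPUComputePipeline built from one WGSL kernel
    pass.setBindGroup(0, k.bindGroup); // buffers bound to the kernel's @binding slots
    pass.dispatchWorkgroups(Math.ceil(k.invocations(batchSize) / 64));
  }

  pass.end();
  device.queue.submit([encoder.finish()]);
}
```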

### ✔ Modular WebGPU runtime

The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically:

* parses WGSL bindings
* allocates buffers
* recreates bind groups
* handles resizing
* dispatches compute passes
* reads debug buffers

---

## Live Results

A typical run with batch size **64**:

```
Epoch 0:  Loss = 2.3032   Accuracy = 9.40%
Epoch 1:  Loss = 2.1380   Accuracy = 26.40%
Epoch 5:  Loss = 1.8169   Accuracy = 36.90%
Epoch 10: Loss = 1.6810   Accuracy = 40.60%
Epoch 19: Loss = 1.6004   Accuracy = 41.60%
```

This matches what you'd expect from a small, unoptimized 128-unit MLP trained only on a single batch — validating that all gradients and updates propagate correctly.

---

## Project Structure

```
WebGPU-Neural-Network-Trainer/
│
├── framework/
│   └── WebGpu.js                # Automated WebGPU shader/buffer manager
│
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
│
├── mnist_loader.js              # Loads MNIST images + labels into GPU buffers
│
├── training.js                  # Main training pipeline
│
├── server.js                    # Server for the demo -> visit localhost:2020
│
└── index.html                   # UI + canvas preview + training controls
```

---

## How It Works

### 1. Load MNIST

Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
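
A minimal sketch of that upload with the standard WebGPU buffer API (buffer and variable names are illustrative; the actual loader lives in `mnist_loader.js`):

```js
// Sketch: upload MNIST images and labels into GPU storage buffers.
// `device` is a GPUDevice; `images` is a Float32Array (numImages * 784),
// `labels` is a Uint32Array (numImages). Names are illustrative.
function uploadMnist(device, images, labels) {
  const imageBuffer = device.createBuffer({
    size: images.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const labelBuffer = device.createBuffer({
    size: labels.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });

  // Copy the CPU-side arrays into the GPU buffers; from here on,
  // training reads them directly on the GPU.
  device.queue.writeBuffer(imageBuffer, 0, images);
  device.queue.writeBuffer(labelBuffer, 0, labels);

  return { imageBuffer, labelBuffer };
}
```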

### 2. Forward Pass

`forward_dense_layer1.wgsl`

* matrix multiply: X·W₁ + b₁
* activation: ReLU

`forward_dense_layer2.wgsl`

* matrix multiply: H·W₂ + b₂
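
For reference, the same math on the CPU looks roughly like this in plain JavaScript (flat row-major arrays; names are illustrative, and the WGSL kernels parallelize this work per output element):

```js
// CPU reference for the forward pass (row-major flat arrays).
// x: [batch, inDim], w1: [inDim, hidden], b1: [hidden],
// w2: [hidden, numClasses], b2: [numClasses].
function forward(x, w1, b1, w2, b2, batch, inDim, hidden, numClasses) {
  const h = new Float32Array(batch * hidden);
  const logits = new Float32Array(batch * numClasses);

  // Layer 1: H = ReLU(X · W1 + b1)
  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      let sum = b1[j];
      for (let i = 0; i < inDim; i++) sum += x[b * inDim + i] * w1[i * hidden + j];
      h[b * hidden + j] = Math.max(sum, 0); // ReLU
    }
  }

  // Layer 2: logits = H · W2 + b2
  for (let b = 0; b < batch; b++) {
    for (let c = 0; c < numClasses; c++) {
      let sum = b2[c];
      for (let j = 0; j < hidden; j++) sum += h[b * hidden + j] * w2[j * numClasses + c];
      logits[b * numClasses + c] = sum;
    }
  }

  return { h, logits };
}
```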

### 3. Loss + Gradient w.r.t. logits

`softmax_cross_entropy_backward.wgsl`

Produces:

* `gradientLogits[b * numClasses + c]`
* `crossEntropyLoss[b]`
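
A CPU sketch of the same computation (illustrative JavaScript; the kernel produces the same two outputs per batch element):

```js
// CPU reference: softmax, cross-entropy loss, and gradient w.r.t. logits
// for one batch element b. logits: [batch, numClasses], label: correct class index.
function softmaxCrossEntropyBackward(logits, label, b, numClasses, gradientLogits, crossEntropyLoss) {
  const offset = b * numClasses;

  // Numerically stable softmax.
  let maxLogit = -Infinity;
  for (let c = 0; c < numClasses; c++) maxLogit = Math.max(maxLogit, logits[offset + c]);
  let sumExp = 0;
  for (let c = 0; c < numClasses; c++) sumExp += Math.exp(logits[offset + c] - maxLogit);

  for (let c = 0; c < numClasses; c++) {
    const p = Math.exp(logits[offset + c] - maxLogit) / sumExp;
    // dL/dlogit = softmax(logit) - onehot(label)
    gradientLogits[offset + c] = p - (c === label ? 1 : 0);
    if (c === label) crossEntropyLoss[b] = -Math.log(p);
  }
}
```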

### 4. Gradients for Dense Layer 2

`gradients_dense_layer2.wgsl`

Computes:

* dW₂
* dB₂
* dHidden (backprop into hidden layer)
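
Sketched on the CPU (illustrative JavaScript; the gradients below are summed over the batch, and any 1/batchSize scaling may instead live in the real kernels or the optimizer):

```js
// CPU reference for the layer-2 gradients.
// h: [batch, hidden], w2: [hidden, numClasses], gradLogits: [batch, numClasses].
function gradientsLayer2(h, w2, gradLogits, batch, hidden, numClasses) {
  const dW2 = new Float32Array(hidden * numClasses);
  const dB2 = new Float32Array(numClasses);
  const dHidden = new Float32Array(batch * hidden);

  for (let b = 0; b < batch; b++) {
    for (let c = 0; c < numClasses; c++) {
      const g = gradLogits[b * numClasses + c];
      dB2[c] += g; // dB2 = sum over batch of gradLogits
      for (let j = 0; j < hidden; j++) {
        dW2[j * numClasses + c] += h[b * hidden + j] * g;      // dW2 = Hᵀ · gradLogits
        dHidden[b * hidden + j] += w2[j * numClasses + c] * g; // dHidden = gradLogits · W2ᵀ
      }
    }
  }

  return { dW2, dB2, dHidden };
}
```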

### 5. Gradients for Dense Layer 1

`gradients_dense_layer1.wgsl`

Computes:

* dW₁
* dB₁
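
In the same style, an illustrative JavaScript sketch (the ReLU derivative masks `dHidden` before it reaches W₁ and b₁; array names and layouts are assumptions):

```js
// CPU reference for the layer-1 gradients.
// x: [batch, inDim], h: [batch, hidden] (post-ReLU activations), dHidden: [batch, hidden].
function gradientsLayer1(x, h, dHidden, batch, inDim, hidden) {
  const dW1 = new Float32Array(inDim * hidden);
  const dB1 = new Float32Array(hidden);

  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      // ReLU backward: gradient only flows where the activation was positive.
      const g = h[b * hidden + j] > 0 ? dHidden[b * hidden + j] : 0;
      dB1[j] += g;
      for (let i = 0; i < inDim; i++) {
        dW1[i * hidden + j] += x[b * inDim + i] * g; // dW1 = Xᵀ · (masked dHidden)
      }
    }
  }

  return { dW1, dB1 };
}
```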

### 6. SGD Update

`sgd_optimizer_update.wgsl`

Updates all weights and biases:

```
w -= learningRate * dw
b -= learningRate * db
```

---

## Running the Project

### Requirements

* A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
* A local HTTP server (the demo will not load correctly from `file://`)

### Start a server

Example:

```
node server.js
```

or

```
python3 -m http.server
```

or

```
npx http-server .
```

Then open (port `2020` is used by `server.js`; the other servers use their own default ports):

```
http://localhost:2020
```

Press **Train** — the network trains in real time, showing:

* Loss
* Accuracy
* Six random preview samples with predicted labels

---

## Why This Project Matters

This repository demonstrates:

* understanding of GPU buffer layouts
* compute pipeline dispatching
* custom WGSL kernels for ML
* manual backprop and gradient math
* GPU memory management and buffer resizing
* a full deep-learning training loop without any frameworks

It is strong evidence of skill in **AI systems engineering**, **WebGPU compute design**, **neural-network internals**, and **performance-oriented GPU development**.

These skills are directly relevant for:

* AI inference runtimes
* GPU acceleration teams
* Web-based model execution
* systems-level optimization roles

---

## License

MIT