# WebGPU Neural Network Trainer

A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**.

This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser with:

* Forward pass (dense layer 1 + ReLU)
* Forward pass (dense layer 2 → logits)
* Softmax + cross-entropy loss
* Backpropagation for both dense layers
* Weight/bias gradient computation
* SGD optimizer update
* Live accuracy + loss visualization
* Browser-side image preview of predictions

No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.

---

## Features

### ✔ Fully WebGPU-accelerated training

All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels.

### ✔ End-to-end MNIST pipeline

Images and labels are uploaded to GPU buffers and remain on the GPU throughout training.

### ✔ Pure WGSL neural network

Six compute shaders:

1. `forward_dense_layer1.wgsl`
2. `forward_dense_layer2.wgsl`
3. `softmax_cross_entropy_backward.wgsl`
4. `gradients_dense_layer2.wgsl`
5. `gradients_dense_layer1.wgsl`
6. `sgd_optimizer_update.wgsl`

### ✔ Pure JS orchestration

A lightweight JS engine (`training.js`) handles:

* shader setup
* buffer binding
* dispatch sizes
* ordered forward/backward passes
* debug readbacks
* canvas preview of sample predictions

### ✔ Modular WebGPU runtime

The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically:

* parses WGSL bindings
* allocates buffers
* recreates bind groups
* handles resizing
* dispatches compute passes
* reads debug buffers

---

## Live Results

A typical run with batch size **64**:

```
Epoch 0:  Loss = 2.3032   Accuracy = 9.40%
Epoch 1:  Loss = 2.1380   Accuracy = 26.40%
Epoch 5:  Loss = 1.8169   Accuracy = 36.90%
Epoch 10: Loss = 1.6810   Accuracy = 40.60%
Epoch 19: Loss = 1.6004   Accuracy = 41.60%
```

This matches what you'd expect from a small, unoptimized 128-unit MLP trained on only a single batch — validating that all gradients and updates propagate correctly.

---

## Project Structure

```
WebGPU-Neural-Network-Trainer/
│
├── framework/
│   └── WebGpu.js          # Automated WebGPU shader/buffer manager
│
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
│
├── mnist_loader.js        # Loads MNIST images + labels into GPU buffers
├── training.js            # Main training pipeline
├── server.js              # Server for the demo -> visit localhost:3030
└── index.html             # UI + canvas preview + training controls
```

---

## How It Works

### 1. Load MNIST

Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
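
The exact preprocessing lives in `mnist_loader.js`; a typical conversion from raw MNIST bytes to the floats a storage buffer expects looks like this (a CPU sketch — the function name and 0–1 normalization are assumptions, not the loader's actual API):

```javascript
// Convert one MNIST image from raw bytes (0-255) to normalized floats,
// ready to be written into a GPU storage buffer.
// NOTE: illustrative sketch; the real loader's normalization may differ.
function imageToFloats(bytes) {
  const out = new Float32Array(bytes.length);
  for (let i = 0; i < bytes.length; i++) {
    out[i] = bytes[i] / 255; // scale each pixel into [0, 1]
  }
  return out;
}
```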

### 2. Forward Pass

`forward_dense_layer1.wgsl`

* matrix multiply: X·W₁ + b₁
* activation: ReLU

`forward_dense_layer2.wgsl`

* matrix multiply: H·W₂ + b₂
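
Both kernels compute a standard dense layer. A minimal CPU reference in plain JS, handy for sanity-checking shader output (the row-major weight layout and names are assumptions, not the shaders' actual bindings):

```javascript
// CPU reference for one dense layer: out[j] = Σ_i x[i] * W[i*outDim + j] + b[j]
// `relu` toggles the hidden-layer activation (layer 1 true, layer 2 false).
// NOTE: illustrative sketch; buffer layouts in the WGSL kernels may differ.
function denseForward(x, W, b, inDim, outDim, relu) {
  const out = new Float32Array(outDim);
  for (let j = 0; j < outDim; j++) {
    let acc = b[j];
    for (let i = 0; i < inDim; i++) {
      acc += x[i] * W[i * outDim + j]; // row-major weights
    }
    out[j] = relu ? Math.max(0, acc) : acc;
  }
  return out;
}
```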

### 3. Loss + Gradient w.r.t. logits

`softmax_cross_entropy_backward.wgsl`

Produces:

* `gradientLogits[b * numClasses + c]`
* `crossEntropyLoss[b]`
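
Fusing softmax with the loss backward pass works because the gradient of cross-entropy w.r.t. the logits reduces to `softmax(logits) - onehot(label)`. A per-sample CPU sketch (names are illustrative):

```javascript
// CPU reference for the fused softmax + cross-entropy backward pass.
// For one sample: loss = -log(softmax(logits)[label]),
//                 grad[c] = softmax(logits)[c] - (c === label ? 1 : 0).
function softmaxCrossEntropyBackward(logits, label) {
  const m = Math.max(...logits);              // subtract max for numerical stability
  const exps = logits.map(z => Math.exp(z - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  const probs = exps.map(e => e / sum);
  const grad = probs.map((p, c) => p - (c === label ? 1 : 0));
  return { loss: -Math.log(probs[label]), grad };
}
```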

### 4. Gradients for Dense Layer 2

`gradients_dense_layer2.wgsl`

Computes:

* dW₂
* dB₂
* dHidden (backprop into hidden layer)
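
These are the usual outer-product backprop formulas: dW₂ is the outer product of the hidden activations with the logit gradient, dB₂ is the logit gradient itself, and dHidden routes the gradient back through W₂. A single-sample CPU sketch (flattened row-major layout and names are assumptions):

```javascript
// CPU reference for layer-2 gradients, single sample:
//   dW2[i*C + c] = h[i] * dLogits[c]
//   dB2[c]       = dLogits[c]
//   dHidden[i]   = Σ_c dLogits[c] * W2[i*C + c]
function gradientsLayer2(h, dLogits, W2, hiddenDim, numClasses) {
  const dW2 = new Float32Array(hiddenDim * numClasses);
  const dHidden = new Float32Array(hiddenDim);
  for (let i = 0; i < hiddenDim; i++) {
    let acc = 0;
    for (let c = 0; c < numClasses; c++) {
      dW2[i * numClasses + c] = h[i] * dLogits[c]; // outer product
      acc += dLogits[c] * W2[i * numClasses + c];  // backprop through W2
    }
    dHidden[i] = acc;
  }
  return { dW2, dB2: Float32Array.from(dLogits), dHidden };
}
```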

### 5. Gradients for Dense Layer 1

`gradients_dense_layer1.wgsl`

Computes:

* dW₁
* dB₁
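
Layer 1 repeats the same pattern, with one extra step: the incoming hidden-layer gradient is masked by the ReLU derivative (1 where the activation is positive, else 0) before forming dW₁ and dB₁. A CPU sketch of just that masking step (names are illustrative):

```javascript
// CPU reference for the ReLU backward step used in the layer-1 kernel:
// gradients only flow through units that were active in the forward pass.
function reluBackward(dHidden, h) {
  return dHidden.map((g, i) => (h[i] > 0 ? g : 0));
}
```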

### 6. SGD Update

`sgd_optimizer_update.wgsl`

Updates all weights and biases:

```
w -= learningRate * dw
b -= learningRate * db
```

---

## Running the Project

### Requirements

* A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
* A local HTTP server (Chrome blocks `file://` GPU access)

### Start a server

Example:

```
node server.js
```

or

```
python3 -m http.server
```

or

```
npx http-server .
```

Then open:

```
http://localhost:3030
```

Press **Train** — the network trains in real time, showing:

* Loss
* Accuracy
* Six random preview samples with predicted labels

---

## Why This Project Matters

This repository demonstrates:

* understanding of GPU buffer layouts
* compute pipeline dispatching
* custom WGSL kernels for ML
* manual backprop and gradient math
* GPU memory management and buffer resizing
* a full deep-learning training loop without any frameworks

It is strong evidence of skill in:

**AI systems engineering**,
**WebGPU compute design**,
**neural-network internals**,
and **performance-oriented GPU development**.

These skills are directly relevant for:

* AI inference runtimes
* GPU acceleration teams
* Web-based model execution
* systems-level optimization roles

---

## License

MIT