# WebGPU Neural Network Trainer

A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**.

This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser with:

* Forward pass (dense layer 1 + ReLU)
* Forward pass (dense layer 2 → logits)
* Softmax + cross-entropy loss
* Backpropagation for both dense layers
* Weight/bias gradient computation
* SGD optimizer update
* Live accuracy + loss visualization
* Browser-side image preview of predictions

No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.

---

## Features

### ✔ Fully WebGPU-accelerated training

All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels.

### ✔ End-to-end MNIST pipeline

Images and labels are uploaded to GPU buffers and remain on the GPU throughout training.

### ✔ Pure WGSL neural network

Six compute shaders:

1. `forward_dense_layer1.wgsl`
2. `forward_dense_layer2.wgsl`
3. `softmax_cross_entropy_backward.wgsl`
4. `gradients_dense_layer2.wgsl`
5. `gradients_dense_layer1.wgsl`
6. `sgd_optimizer_update.wgsl`

### ✔ Pure JS orchestration

A lightweight JS engine (`training.js`) handles:

* shader setup
* buffer binding
* dispatch sizes
* ordered forward/backward passes
* debug readbacks
* canvas preview of sample predictions

### ✔ Modular WebGPU runtime

The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically:

* parses WGSL bindings
* allocates buffers
* recreates bind groups
* handles buffer resizing
* dispatches compute passes
* reads back debug buffers

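
The "dispatch sizes" bookkeeping mentioned above reduces to a ceiling division of the total invocation count by the WGSL workgroup size. A minimal sketch of that calculation (the 64-thread workgroup size and the helper name are assumptions for illustration, not read from the shaders):

```javascript
// Hypothetical helper mirroring what a dispatcher must compute:
// one workgroup covers WORKGROUP_SIZE invocations, so round up.
const WORKGROUP_SIZE = 64; // assumed @workgroup_size(64) in the WGSL

function workgroupCount(totalInvocations) {
  return Math.ceil(totalInvocations / WORKGROUP_SIZE);
}

// e.g. a 64-sample batch through a 128-unit hidden layer:
const groups = workgroupCount(64 * 128);
// pass.dispatchWorkgroups(groups) would then launch 128 workgroups
```

Rounding up matters: when the work count is not a multiple of the workgroup size, the kernel must bounds-check its global invocation index.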
---

## Live Results

A typical run with batch size **64**:

```
Epoch 0:  Loss = 2.3032   Accuracy = 9.40%
Epoch 1:  Loss = 2.1380   Accuracy = 26.40%
Epoch 5:  Loss = 1.8169   Accuracy = 36.90%
Epoch 10: Loss = 1.6810   Accuracy = 40.60%
Epoch 19: Loss = 1.6004   Accuracy = 41.60%
```

This matches what you'd expect from a small, unoptimized 128-unit MLP trained on only a single batch — validating that all gradients and updates propagate correctly.

---

## Project Structure

```
WebGPU-Neural-Network-Trainer/
│
├── framework/
│   └── WebGpu.js        # Automated WebGPU shader/buffer manager
│
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
│
├── mnist_loader.js      # Loads MNIST images + labels into GPU buffers
│
├── training.js          # Main training pipeline
│
├── server.js            # Server for the demo -> visit localhost:3030
│
└── index.html           # UI + canvas preview + training controls
```

---

## How It Works

### 1. Load MNIST

Images (28×28, stored as floats) and labels (uint32) are uploaded into storage buffers.

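
For reference, the CPU-side preparation before upload typically looks like this: pixel bytes are scaled to `[0, 1]` floats and labels widened to `Uint32Array` so they match the storage-buffer types. This is a sketch under those assumptions; the actual layout in `mnist_loader.js` may differ.

```javascript
// Sketch: convert raw MNIST bytes into the typed arrays that would
// back the storage buffers (28*28 floats per image, u32 labels).
function prepareBatch(pixelBytes, labelBytes) {
  const images = new Float32Array(pixelBytes.length);
  for (let i = 0; i < pixelBytes.length; i++) {
    images[i] = pixelBytes[i] / 255; // scale 0..255 -> 0..1
  }
  const labels = Uint32Array.from(labelBytes); // widen u8 -> u32
  return { images, labels };
}

// Tiny fake "batch": one 2-pixel image, label 7
const { images, labels } = prepareBatch(new Uint8Array([0, 255]), new Uint8Array([7]));
```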
### 2. Forward Pass

`forward_dense_layer1.wgsl`

* matrix multiply: X·W₁ + b₁
* activation: ReLU

`forward_dense_layer2.wgsl`

* matrix multiply: H·W₂ + b₂

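
Both shader passes compute standard dense-layer arithmetic. A single-sample CPU reference of the same math (row-major weight layout is an assumption; the WGSL buffers may be laid out differently):

```javascript
// Reference for H = relu(X·W1 + b1) and logits = H·W2 + b2, one sample.
// x: [inputSize], W: [inputSize * outSize] row-major, b: [outSize]
function denseForward(x, W, b, outSize, applyRelu) {
  const inputSize = x.length;
  const out = new Float32Array(outSize);
  for (let j = 0; j < outSize; j++) {
    let sum = b[j];
    for (let i = 0; i < inputSize; i++) {
      sum += x[i] * W[i * outSize + j];
    }
    // layer 1 applies ReLU; layer 2 emits raw logits
    out[j] = applyRelu ? Math.max(0, sum) : sum;
  }
  return out;
}
```

On the GPU, each invocation computes one `(sample, output unit)` pair of this double loop in parallel.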
### 3. Loss + Gradient w.r.t. logits

`softmax_cross_entropy_backward.wgsl`

Produces:

* `gradientLogits[b * numClasses + c]`
* `crossEntropyLoss[b]`

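
Fusing these two outputs works because of the classic identity: for softmax followed by cross-entropy, the gradient w.r.t. the logits is simply `softmax(z) - onehot(label)`. A CPU sketch of what one per-sample invocation computes:

```javascript
// For one sample: loss = -log(p[label]), dLogits = p - onehot(label),
// where p = softmax(logits). Subtracting the max keeps exp() stable.
function softmaxCrossEntropyBackward(logits, label) {
  const max = Math.max(...logits);
  const exps = logits.map(z => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  const p = exps.map(e => e / sum);
  const loss = -Math.log(p[label]);
  const dLogits = p.map((pi, c) => pi - (c === label ? 1 : 0));
  return { loss, dLogits };
}
```

Note that the gradient entries always sum to zero, a useful invariant when debugging readbacks.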
### 4. Gradients for Dense Layer 2

`gradients_dense_layer2.wgsl`

Computes:

* dW₂
* dB₂
* dHidden (backprop into hidden layer)

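
In matrix form these are dW₂ = Hᵀ·dLogits, dB₂ = Σ over the batch of dLogits, and dHidden = dLogits·W₂ᵀ. A single-sample CPU sketch (the row-major [hidden × classes] layout of W₂ is an assumption):

```javascript
// Single-sample backward through layer 2.
// h: [hidden], dLogits: [classes], W2: [hidden * classes] row-major.
function denseBackwardLayer2(h, dLogits, W2) {
  const hidden = h.length, classes = dLogits.length;
  const dW2 = new Float32Array(hidden * classes);
  const dB2 = Float32Array.from(dLogits);       // bias grad is just dLogits
  const dHidden = new Float32Array(hidden);
  for (let i = 0; i < hidden; i++) {
    for (let c = 0; c < classes; c++) {
      dW2[i * classes + c] = h[i] * dLogits[c]; // outer product h ⊗ dLogits
      dHidden[i] += dLogits[c] * W2[i * classes + c];
    }
  }
  return { dW2, dB2, dHidden };
}
```

With a full batch, the dW₂ and dB₂ terms are additionally summed over samples, which on the GPU requires either atomics or a reduction pass.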
### 5. Gradients for Dense Layer 1

`gradients_dense_layer1.wgsl`

Computes:

* dW₁
* dB₁

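
Layer 1's backward pass is the same outer-product pattern, with the incoming `dHidden` first masked by the ReLU derivative (zero wherever the unit did not fire). A single-sample sketch; masking on the post-ReLU activations `h > 0` is equivalent to checking the pre-activation sign:

```javascript
// Single-sample backward through layer 1.
// x: [inputSize], h: post-ReLU activations, dHidden: gradient from layer 2.
function denseBackwardLayer1(x, h, dHidden) {
  const inputSize = x.length, hidden = h.length;
  // ReLU derivative: 1 where the unit fired (h > 0), else 0
  const dPre = dHidden.map((g, j) => (h[j] > 0 ? g : 0));
  const dW1 = new Float32Array(inputSize * hidden);
  for (let i = 0; i < inputSize; i++) {
    for (let j = 0; j < hidden; j++) {
      dW1[i * hidden + j] = x[i] * dPre[j]; // outer product x ⊗ dPre
    }
  }
  return { dW1, dB1: Float32Array.from(dPre) };
}
```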
### 6. SGD Update

`sgd_optimizer_update.wgsl`

Updates all weights and biases:

```
w -= learningRate * dw
b -= learningRate * db
```

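
The same rule applied elementwise in JS (a reference sketch; the shader performs this once per parameter index, in parallel):

```javascript
// In-place SGD step: params[i] -= learningRate * grads[i]
function sgdUpdate(params, grads, learningRate) {
  for (let i = 0; i < params.length; i++) {
    params[i] -= learningRate * grads[i];
  }
}

const w = new Float32Array([0.5, -0.5]);
sgdUpdate(w, [1, -1], 0.1);
// w ≈ [0.4, -0.4] (up to float32 rounding)
```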
---

## Running the Project

### Requirements

* A browser with WebGPU support (Chrome ≥ 113, Edge ≥ 113, or Firefox Nightly)
* A local HTTP server (Chrome blocks WebGPU on `file://` pages)

### Start a server

Example:

```
node server.js
```

or

```
python3 -m http.server 3030
```

or

```
npx http-server . -p 3030
```

Then open:

```
http://localhost:3030
```

Press **Train** — the network trains in real time, showing:

* Loss
* Accuracy
* Six random preview samples with predicted labels

---

## Why This Project Matters

This repository demonstrates:

* understanding of GPU buffer layouts
* compute pipeline dispatching
* custom WGSL kernels for ML
* manual backprop and gradient math
* GPU memory management and buffer resizing
* a full deep-learning training loop without any frameworks

It is strong evidence of skill in **AI systems engineering**, **WebGPU compute design**, **neural-network internals**, and **performance-oriented GPU development**.

These skills are directly relevant for:

* AI inference runtimes
* GPU acceleration teams
* web-based model execution
* systems-level optimization roles

---

## License

MIT