# WebGPU Neural Network Trainer A fully GPU-accelerated neural-network training pipeline implemented in **WebGPU** using custom **WGSL compute shaders**. This project trains a **2-layer MLP classifier** on MNIST entirely inside the browser with: * Forward pass (dense layer 1 + ReLU) * Forward pass (dense layer 2 → logits) * Softmax + cross-entropy loss * Backpropagation for both dense layers * Weight/bias gradient computation * SGD optimizer update * Live accuracy + loss visualization * Browser-side image preview of predictions No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders. --- ## Features ### ✔ Fully WebGPU-accelerated training All computation (forward pass, backward pass, loss, optimizer) is written in WGSL compute kernels. ### ✔ End-to-end MNIST pipeline Images and labels are uploaded to GPU buffers and remain on GPU throughout training. ### ✔ Pure WGSL neural network Six compute shaders: 1. `forward_dense_layer1.wgsl` 2. `forward_dense_layer2.wgsl` 3. `softmax_cross_entropy_backward.wgsl` 4. `gradients_dense_layer2.wgsl` 5. `gradients_dense_layer1.wgsl` 6. `sgd_optimizer_update.wgsl` ### ✔ Pure JS orchestration A lightweight JS engine (`training2.js`) handles: * shader setup * buffer binding * dispatch sizes * ordered forward/backward passes * debug readbacks * canvas preview of sample predictions ### ✔ Modular WebGPU runtime The training pipeline uses a generic WebGPU shader wrapper (`WebGpu.js`) that automatically: * parses WGSL bindings * allocates buffers * recreates bind groups * handles resizing * dispatches compute passes * reads debug buffers --- ## Live Results A typical run with batch size **64**: ``` Epoch 0: Loss = 2.3032 Accuracy = 9.40% Epoch 1: Loss = 2.1380 Accuracy = 26.40% Epoch 5: Loss = 1.8169 Accuracy = 36.90% Epoch 10: Loss = 1.6810 Accuracy = 40.60% Epoch 19: Loss = 1.6004 Accuracy = 41.60% ``` This matches what you'd expect from a small, unoptimized 128-unit MLP trained only on a single batch — validating that all gradients and updates propagate correctly. --- ## Project Structure ``` WebGPU-Neural-Network-Trainer/ │ ├── framework/ │ └── WebGpu.js # Automated WebGPU shader/buffer manager │ ├── shaders/neural/ │ ├── forward_dense_layer1.wgsl │ ├── forward_dense_layer2.wgsl │ ├── softmax_cross_entropy_backward.wgsl │ ├── gradients_dense_layer2.wgsl │ ├── gradients_dense_layer1.wgsl │ └── sgd_optimizer_update.wgsl │ ├── mnist_loader.js # Loads MNIST images + labels into GPU buffers │ ├── training.js # Main training pipeline │ ├── server.js # Server for the demo -> visit localhost:2020 │ └── index.html # UI + canvas preview + training controls ``` --- ## How It Works ### 1. Load MNIST Images (28×28 float) and labels (uint32) are uploaded into storage buffers. ### 2. Forward Pass `forward_dense_layer1.wgsl` * matrix multiply: X·W₁ + b₁ * activation: ReLU `forward_dense_layer2.wgsl` * matrix multiply: H·W₂ + b₂ ### 3. Loss + Gradient w.r.t. logits `softmax_cross_entropy_backward.wgsl` Produces: * gradientLogits[b * numClasses + c] * crossEntropyLoss[b] ### 4. Gradients for Dense Layer 2 `gradients_dense_layer2.wgsl` Computes: * dW₂ * dB₂ * dHidden (backprop into hidden layer) ### 5. Gradients for Dense Layer 1 `gradients_dense_layer1.wgsl` Computes: * dW₁ * dB₁ ### 6. SGD Update `sgd_optimizer_update.wgsl` Updates all weights and biases: ``` w -= learningRate * dw b -= learningRate * db ``` --- ## Running the Project ### Requirements * A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge) * Local HTTP server (Chrome blocks file:// GPU access) ### Start a server Example: ``` node server.js ``` or ``` python3 -m http.server ``` or ``` npx http-server . ``` Then open: ``` http://localhost:2020 ``` Press **Train** — the network trains in real time, showing: * Loss * Accuracy * Six random preview samples with predicted labels --- ## Why This Project Matters This repository demonstrates: * understanding of GPU buffer layouts * compute pipeline dispatching * custom WGSL kernels for ML * manual backprop and gradient math * GPU memory management and buffer resizing * a full deep-learning training loop without any frameworks It is strong evidence of skill in: **AI systems engineering**, **WebGPU compute design**, **neural-network internals**, and **performance-oriented GPU development**. These skills are directly relevant for: * AI inference runtimes * GPU acceleration teams * Web-based model execution * systems-level optimization roles --- ## License MIT