WebGPU Neural Network Trainer

A fully GPU-accelerated neural-network training pipeline implemented in WebGPU using custom WGSL compute shaders. This project trains a 2-layer MLP classifier on MNIST entirely inside the browser with:

  • Forward pass (dense layer 1 + ReLU)
  • Forward pass (dense layer 2 → logits)
  • Softmax + cross-entropy loss
  • Backpropagation for both dense layers
  • Weight/bias gradient computation
  • SGD optimizer update
  • Live accuracy + loss visualization
  • Browser-side image preview of predictions

No ML frameworks and no libraries — every compute pass, buffer, math operation, and gradient update runs through raw WebGPU shaders.


Features

✔ Fully WebGPU-accelerated training

All computation (forward pass, backward pass, loss, optimizer update) is implemented as WGSL compute kernels.

✔ End-to-end MNIST pipeline

Images and labels are uploaded to GPU buffers and remain on GPU throughout training.

✔ Pure WGSL neural network

Six compute shaders:

  1. forward_dense_layer1.wgsl
  2. forward_dense_layer2.wgsl
  3. softmax_cross_entropy_backward.wgsl
  4. gradients_dense_layer2.wgsl
  5. gradients_dense_layer1.wgsl
  6. sgd_optimizer_update.wgsl

✔ Pure JS orchestration

A lightweight JS engine (training.js) handles all of the following (sketched after this list):

  • shader setup
  • buffer binding
  • dispatch sizes
  • ordered forward/backward passes
  • debug readbacks
  • canvas preview of sample predictions
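
As a rough illustration of what "ordered forward/backward passes" means on the JS side, one training step can be encoded as six compute dispatches in sequence. This is a minimal sketch using the standard WebGPU API, not the actual code in training.js or WebGpu.js; the names (pipelines, bindGroups), the workgroup size of 64, and the dispatch sizes are assumptions.

// Illustrative only: encode one training step as six ordered compute passes.
// Pipelines and bind groups are assumed to be created elsewhere.
function encodeTrainingStep(device, pipelines, bindGroups, sizes) {
  const { batchSize, inputSize, hiddenSize, numClasses } = sizes;
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();

  // Helper: bind a named pipeline and dispatch enough 64-wide workgroups
  // to cover `threads` output elements (64 is an assumed workgroup size).
  const dispatch = (name, threads) => {
    pass.setPipeline(pipelines[name]);
    pass.setBindGroup(0, bindGroups[name]);
    pass.dispatchWorkgroups(Math.ceil(threads / 64));
  };

  dispatch('forward_dense_layer1', batchSize * hiddenSize);          // X·W1 + b1, ReLU
  dispatch('forward_dense_layer2', batchSize * numClasses);          // H·W2 + b2
  dispatch('softmax_cross_entropy_backward', batchSize);             // loss + dLogits
  dispatch('gradients_dense_layer2', hiddenSize * numClasses);       // dW2, dB2, dHidden
  dispatch('gradients_dense_layer1', inputSize * hiddenSize);        // dW1, dB1
  dispatch('sgd_optimizer_update',
           inputSize * hiddenSize + hiddenSize * numClasses);        // weight update

  pass.end();
  device.queue.submit([encoder.finish()]);
}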

✔ Modular WebGPU runtime

The training pipeline uses a generic WebGPU shader wrapper (WebGpu.js) that automatically:

  • parses WGSL bindings
  • allocates buffers
  • recreates bind groups
  • handles resizing
  • dispatches compute passes
  • reads debug buffers
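
To give an idea of what "parses WGSL bindings" can mean in practice, here is a minimal sketch of scanning shader source for @group/@binding declarations so buffers and bind groups can be derived from them automatically. This illustrates the idea only; it is not the API or implementation of WebGpu.js.

// Illustrative only: extract @group/@binding storage/uniform declarations
// from WGSL source text.
function parseBindings(wgslSource) {
  const bindings = [];
  const re = /@group\((\d+)\)\s*@binding\((\d+)\)\s*var<(storage[^>]*|uniform)>\s*(\w+)/g;
  let m;
  while ((m = re.exec(wgslSource)) !== null) {
    bindings.push({
      group: Number(m[1]),
      binding: Number(m[2]),
      space: m[3].trim(),   // e.g. "storage, read_write" or "uniform"
      name: m[4],           // variable name declared in the shader
    });
  }
  return bindings;
}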

Live Results

A typical run with batch size 64:

Epoch 0: Loss = 2.3032   Accuracy = 9.40%
Epoch 1: Loss = 2.1380   Accuracy = 26.40%
Epoch 5: Loss = 1.8169   Accuracy = 36.90%
Epoch 10: Loss = 1.6810  Accuracy = 40.60%
Epoch 19: Loss = 1.6004  Accuracy = 41.60%

This matches what you'd expect from a small, unoptimized 128-unit MLP trained only on a single batch — validating that all gradients and updates propagate correctly.


Project Structure

WebGPU-Neural-Network-Trainer/
│
├── framework/
│   └── WebGpu.js                     # Automated WebGPU shader/buffer manager
│
├── shaders/neural/
│   ├── forward_dense_layer1.wgsl
│   ├── forward_dense_layer2.wgsl
│   ├── softmax_cross_entropy_backward.wgsl
│   ├── gradients_dense_layer2.wgsl
│   ├── gradients_dense_layer1.wgsl
│   └── sgd_optimizer_update.wgsl
│
├── mnist_loader.js            # Loads MNIST images + labels into GPU buffers
│
├── training.js                # Main training pipeline
│
├── server.js                  # Server for the demo -> visit localhost:3030
│
└── index.html                 # UI + canvas preview + training controls

How It Works

1. Load MNIST

Images (28×28 float) and labels (uint32) are uploaded into storage buffers.
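
A hedged sketch of that upload step with the standard WebGPU API (the function and buffer names are placeholders, not the ones used in mnist_loader.js):

// Illustrative upload of MNIST data into GPU storage buffers.
// `images` is a Float32Array (numImages * 784 values in [0, 1]),
// `labels` is a Uint32Array (one class index per image).
function uploadMnist(device, images, labels) {
  const imageBuffer = device.createBuffer({
    size: images.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const labelBuffer = device.createBuffer({
    size: labels.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(imageBuffer, 0, images);
  device.queue.writeBuffer(labelBuffer, 0, labels);
  return { imageBuffer, labelBuffer };
}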

2. Forward Pass

forward_dense_layer1.wgsl

  • matrix multiply: X·W₁ + b₁
  • activation: ReLU

forward_dense_layer2.wgsl

  • matrix multiply: H·W₂ + b₂
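
For reference, the same math written as a CPU-side JS loop (a sketch only; row-major layouts are an assumption, and the WGSL kernels compute this in parallel, one thread per output element):

// CPU reference for X·W1 + b1 followed by ReLU, then H·W2 + b2.
// x: [batch, inputSize], w1: [inputSize, hidden], w2: [hidden, numClasses], row-major.
function forward(x, w1, b1, w2, b2, batch, inputSize, hidden, numClasses) {
  const h = new Float32Array(batch * hidden);
  const logits = new Float32Array(batch * numClasses);
  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      let sum = b1[j];
      for (let i = 0; i < inputSize; i++) sum += x[b * inputSize + i] * w1[i * hidden + j];
      h[b * hidden + j] = Math.max(0, sum);        // ReLU
    }
    for (let c = 0; c < numClasses; c++) {
      let sum = b2[c];
      for (let j = 0; j < hidden; j++) sum += h[b * hidden + j] * w2[j * numClasses + c];
      logits[b * numClasses + c] = sum;            // raw logits, no activation
    }
  }
  return { h, logits };
}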

3. Loss + Gradient w.r.t. logits

softmax_cross_entropy_backward.wgsl produces:

  • gradientLogits[b * numClasses + c]
  • crossEntropyLoss[b]
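
The math this kernel implements, shown as a CPU-side JS reference (a sketch; the indexing matches the buffer names above, and any batch averaging is left to later stages):

// Softmax over each row of logits, per-sample cross-entropy loss, and the
// combined gradient dL/dlogits = softmax(logits) - onehot(label).
function softmaxCrossEntropyBackward(logits, labels, batch, numClasses) {
  const gradientLogits = new Float32Array(batch * numClasses);
  const crossEntropyLoss = new Float32Array(batch);
  for (let b = 0; b < batch; b++) {
    const row = b * numClasses;
    let maxLogit = -Infinity;
    for (let c = 0; c < numClasses; c++) maxLogit = Math.max(maxLogit, logits[row + c]);
    let sumExp = 0;
    for (let c = 0; c < numClasses; c++) sumExp += Math.exp(logits[row + c] - maxLogit);
    for (let c = 0; c < numClasses; c++) {
      const p = Math.exp(logits[row + c] - maxLogit) / sumExp;   // softmax probability
      gradientLogits[row + c] = p - (c === labels[b] ? 1 : 0);
      if (c === labels[b]) crossEntropyLoss[b] = -Math.log(p);
    }
  }
  return { gradientLogits, crossEntropyLoss };
}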

4. Gradients for Dense Layer 2

gradients_dense_layer2.wgsl computes:

  • dW₂
  • dB₂
  • dHidden (backprop into hidden layer)
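
As a CPU-side sketch of the same gradients (row-major layouts as above, an assumption rather than the kernel's actual memory layout): dW₂ = Hᵀ·dLogits, dB₂ = column sums of dLogits, dHidden = dLogits·W₂ᵀ.

// CPU reference for the layer-2 gradients.
function gradientsLayer2(h, w2, gradientLogits, batch, hidden, numClasses) {
  const dW2 = new Float32Array(hidden * numClasses);
  const dB2 = new Float32Array(numClasses);
  const dHidden = new Float32Array(batch * hidden);
  for (let b = 0; b < batch; b++) {
    for (let c = 0; c < numClasses; c++) {
      const g = gradientLogits[b * numClasses + c];
      dB2[c] += g;                                           // bias gradient
      for (let j = 0; j < hidden; j++) {
        dW2[j * numClasses + c] += h[b * hidden + j] * g;    // weight gradient
        dHidden[b * hidden + j] += g * w2[j * numClasses + c]; // backprop into hidden
      }
    }
  }
  return { dW2, dB2, dHidden };
}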

5. Gradients for Dense Layer 1

gradients_dense_layer1.wgsl computes:

  • dW₁
  • dB₁
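
A CPU-side sketch of this step as well. The ReLU derivative has to mask dHidden somewhere along the way; whether that happens in this kernel or the previous one is an implementation detail of the shaders, so the placement here is an assumption.

// CPU reference for the layer-1 gradients. dHidden is zeroed wherever the
// post-ReLU activation is 0 (equivalently, where the pre-activation was <= 0).
function gradientsLayer1(x, h, dHidden, batch, inputSize, hidden) {
  const dW1 = new Float32Array(inputSize * hidden);
  const dB1 = new Float32Array(hidden);
  for (let b = 0; b < batch; b++) {
    for (let j = 0; j < hidden; j++) {
      const g = h[b * hidden + j] > 0 ? dHidden[b * hidden + j] : 0;  // ReLU mask
      dB1[j] += g;
      for (let i = 0; i < inputSize; i++) {
        dW1[i * hidden + j] += x[b * inputSize + i] * g;
      }
    }
  }
  return { dW1, dB1 };
}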

6. SGD Update

sgd_optimizer_update.wgsl updates all weights and biases:

w -= learningRate * dw
b -= learningRate * db
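
In JS terms (a sketch; the kernel applies the same element-wise rule with one GPU thread per parameter):

// Plain SGD: every parameter moves against its gradient.
function sgdUpdate(params, grads, learningRate) {
  for (let i = 0; i < params.length; i++) {
    params[i] -= learningRate * grads[i];
  }
}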

Running the Project

Requirements

  • A browser with WebGPU support (Chrome ≥ 113, Firefox Nightly, Edge)
  • Local HTTP server (Chrome blocks file:// GPU access)

Start a server

Example:

node server.js

or

python3 -m http.server

or

npx http-server .

Then open:

http://localhost:3030 (the port used by the bundled server.js; python3 -m http.server and npx http-server default to ports 8000 and 8080 instead)

Press Train — the network trains in real time, showing:

  • Loss
  • Accuracy
  • Six random preview samples with predicted labels

Why This Project Matters

This repository demonstrates:

  • understanding of GPU buffer layouts
  • compute pipeline dispatching
  • custom WGSL kernels for ML
  • manual backprop and gradient math
  • GPU memory management and buffer resizing
  • a full deep-learning training loop without any frameworks

It is strong evidence of skill in:

AI systems engineering, WebGPU compute design, neural-network internals, and performance-oriented GPU development.

These skills are directly relevant for:

  • AI inference runtimes
  • GPU acceleration teams
  • Web-based model execution
  • systems-level optimization roles

License

MIT

Description
Designed and implemented a GPU-accelerated MNIST training engine fully in WebGPU. Built complete forward and backward passes using 6 custom WGSL compute shaders:

  • dense layer forward kernels
  • ReLU activation
  • softmax + cross-entropy backward
  • dense layer gradient kernels
  • SGD optimizer

Implemented GPU memory layouts, buffer management, shader orchestration, and dispatch scaling. Achieved ~40% training accuracy after 20 epochs on a batch of 64 in the browser. Demonstrates the ability to build GPU compute pipelines and AI runtimes from low-level operations.