// Simple-Transformer-JS/real.js
import Vector from "./Vector.js"
import Matrix from "./Matrix.js"
// Deterministic Lehmer / Park-Miller PRNG (MINSTD): seed' = seed * 16807 mod (2^31 - 1).
// A fixed seed makes every run reproducible, unlike Math.random().
function createRandomNumberGenerator() {
    let seed = 12345;
    function generateNextRandomNumber() {
        seed = ( seed * 16807 ) % 2147483647;
        return ( seed - 1 ) / 2147483646;
    }
    return generateNextRandomNumber;
}
const randomNumberGenerator = createRandomNumberGenerator();
// Learnable token-embedding table: one trainable vector of size embeddingDimension per vocabulary entry.
class SimpleEmbedding {
    constructor( vocabularySize, embeddingDimension ) {
        this.vocabularySize = vocabularySize;
        this.embeddingDimension = embeddingDimension;
        this.embeddingVectors = new Array( vocabularySize );
        this.initializeEmbeddings();
    }
    initializeEmbeddings() {
        for ( let index = 0 ; index < this.vocabularySize ; index++ ) {
            const vectorInstance = new Vector( this.embeddingDimension );
            this.initializeVectorRandomly( vectorInstance );
            this.embeddingVectors[ index ] = vectorInstance;
        }
    }
    // Small uniform initialization in [-0.05, 0.05).
    initializeVectorRandomly( vectorInstance ) {
        for ( let elementIndex = 0 ; elementIndex < vectorInstance.data.length ; elementIndex++ ) {
            vectorInstance.data[ elementIndex ] = ( randomNumberGenerator() - 0.5 ) * 0.1;
        }
    }
    lookupEmbedding( tokenIdentifier ) {
        return this.embeddingVectors[ tokenIdentifier ];
    }
    // Plain SGD step on a single embedding row.
    updateEmbedding( tokenIdentifier, gradientVector, learningRate ) {
        const vectorInstance = this.embeddingVectors[ tokenIdentifier ];
        for ( let elementIndex = 0 ; elementIndex < this.embeddingDimension ; elementIndex++ ) {
            vectorInstance.data[ elementIndex ] -= learningRate * gradientVector.data[ elementIndex ];
        }
    }
}
// Sinusoidal positional encoding from "Attention Is All You Need":
// sin for even indices, cos for odd indices, with geometrically spaced frequencies.
class PositionalEncoding {
    static calculateValue( position, index ) {
        // Hardcoded: must match embeddingDimension below.
        const modelDimension = 8;
        const angleRate = 1 / Math.pow( 10000, 2 * Math.floor( index / 2 ) / modelDimension );
        if ( index % 2 === 0 ) {
            return Math.sin( position * angleRate );
        }
        return Math.cos( position * angleRate );
    }
}
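// Worked example with modelDimension = 8, at position 1:
//   index 0: angleRate = 1 / 10000^0    = 1   -> sin( 1 )   ~ 0.8415
//   index 1: angleRate = 1 / 10000^0    = 1   -> cos( 1 )   ~ 0.5403
//   index 2: angleRate = 1 / 10000^0.25 = 0.1 -> sin( 0.1 ) ~ 0.0998
// Low indices oscillate quickly across positions and high indices slowly,
// so every position gets a distinct pattern of values.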
// Cross-entropy loss for a single target token; the 1e-9 guards against log( 0 ).
function calculateCrossEntropyLoss( probabilities, targetIndex ) {
    return -Math.log( probabilities.data[ targetIndex ] + 1e-9 );
}
// Gradient of the loss with respect to the logits: p - oneHot( target ), derived below.
function calculateCrossEntropyGradient( probabilities, targetIndex ) {
    const gradientVector = new Vector( probabilities.data.length );
    for ( let index = 0 ; index < probabilities.data.length ; index++ ) {
        gradientVector.data[ index ] = probabilities.data[ index ];
    }
    gradientVector.data[ targetIndex ] -= 1;
    return gradientVector;
}
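// Why the gradient is simply p - oneHot( target ): with p = softmax( z ) and
// L = -log( p_t ), the chain rule gives
//   dL/dz_i = p_i - 1   for i = t
//   dL/dz_i = p_i       otherwise.
// The softmax and log derivatives cancel into this compact form, which is why
// no separate softmax backward step appears anywhere in this file.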
// Toy vocabulary: nine words, token ids 0..8.
const vocabulary = {
    The: 0,
    Cat: 1,
    Sat: 2,
    On: 3,
    Mat: 4,
    Bench: 5,
    Book: 6,
    Great: 7,
    Is: 8
};
const vocabularySize = Object.keys( vocabulary ).length;
const embeddingDimension = 8;
const embeddingsInstance = new SimpleEmbedding( vocabularySize, embeddingDimension );
// Single-head attention projections W_Q, W_K, W_V and the output projection W_out.
const matrixQuery = Matrix.random( embeddingDimension, embeddingDimension );
const matrixKey = Matrix.random( embeddingDimension, embeddingDimension );
const matrixValue = Matrix.random( embeddingDimension, embeddingDimension );
const matrixOutput = Matrix.random( vocabularySize, embeddingDimension );
// Adds the positional encoding in place (writing into each vector's .data).
// The vectors must be copies of the embedding rows (see forwardPass), otherwise
// the encoding would accumulate in the embedding table on every call.
function applyPositionalEncodingToInputEmbeddings( inputEmbeddingVectors ) {
    for ( let positionIndex = 0 ; positionIndex < inputEmbeddingVectors.length ; positionIndex++ ) {
        for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
            inputEmbeddingVectors[ positionIndex ].data[ dimensionIndex ] += PositionalEncoding.calculateValue( positionIndex, dimensionIndex );
        }
    }
}
// Scaled dot-product attention scores: score_i = ( q . k_i ) / sqrt( d ).
function computeAttentionScoresVector( queryVector, keyVectors ) {
    const scoresVector = new Vector( keyVectors.length );
    for ( let index = 0 ; index < keyVectors.length ; index++ ) {
        scoresVector.data[ index ] = Vector.dot( queryVector, keyVectors[ index ] ) / Math.sqrt( embeddingDimension );
    }
    return scoresVector;
}
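// Why the sqrt( d ) divisor: if query and key entries are roughly independent with
// unit variance, their dot product has variance d, so raw scores grow with the
// dimension. Dividing by sqrt( d ) keeps the score variance near 1, which stops
// the following softmax from saturating into near-one-hot weights.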
// Attention output: context[ d ] = sum_i weights[ i ] * value_i[ d ].
function computeWeightedSumVector( attentionWeightsVector, valueVectors ) {
    const weightedSumVector = new Vector( embeddingDimension );
    for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
        let sum = 0;
        for ( let index = 0 ; index < valueVectors.length ; index++ ) {
            sum += attentionWeightsVector.data[ index ] * valueVectors[ index ].data[ dimensionIndex ];
        }
        weightedSumVector.data[ dimensionIndex ] = sum;
    }
    return weightedSumVector;
}
// Output projection: logits = W_out * context, one logit per vocabulary entry.
function computeLogitsVector( weightedSumVector ) {
    const logitsVector = new Vector( vocabularySize );
    for ( let vocabIndex = 0 ; vocabIndex < vocabularySize ; vocabIndex++ ) {
        let sum = 0;
        for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
            sum += matrixOutput.data[ vocabIndex ][ dimensionIndex ] * weightedSumVector.data[ dimensionIndex ];
        }
        logitsVector.data[ vocabIndex ] = sum;
    }
    return logitsVector;
}
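// The forward pass below chains these pieces into one next-token prediction:
//   x_i           = embed( t_i ) + PE( i )
//   q_i, k_i, v_i = W_Q x_i, W_K x_i, W_V x_i
//   a             = softmax( q_last . k / sqrt( d ) )   (attend from the last token)
//   context       = sum_i a_i v_i
//   p             = softmax( W_out context )
// Only the last position's query is used, so each call predicts a single next token.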
function forwardPass( inputTokenIdentifiers ) {
    // Look up embeddings, copying each row so the positional encoding added
    // below cannot mutate the shared embedding table.
    const inputEmbeddingVectors = new Array( inputTokenIdentifiers.length );
    for ( let index = 0 ; index < inputTokenIdentifiers.length ; index++ ) {
        const embeddingVector = embeddingsInstance.lookupEmbedding( inputTokenIdentifiers[ index ] );
        const copiedVector = new Vector( embeddingDimension );
        for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
            copiedVector.data[ dimensionIndex ] = embeddingVector.data[ dimensionIndex ];
        }
        inputEmbeddingVectors[ index ] = copiedVector;
    }
    applyPositionalEncodingToInputEmbeddings( inputEmbeddingVectors );
    // Project every position into query, key and value space.
    const queryVectors = new Array( inputEmbeddingVectors.length );
    const keyVectors = new Array( inputEmbeddingVectors.length );
    const valueVectors = new Array( inputEmbeddingVectors.length );
    for ( let index = 0 ; index < inputEmbeddingVectors.length ; index++ ) {
        queryVectors[ index ] = Matrix.matVecMul( matrixQuery, inputEmbeddingVectors[ index ] );
        keyVectors[ index ] = Matrix.matVecMul( matrixKey, inputEmbeddingVectors[ index ] );
        valueVectors[ index ] = Matrix.matVecMul( matrixValue, inputEmbeddingVectors[ index ] );
    }
    // Attend from the last position only, then turn the context vector into
    // next-token probabilities.
    const lastQueryVector = queryVectors[ queryVectors.length - 1 ];
    const attentionScores = computeAttentionScoresVector( lastQueryVector, keyVectors );
    const attentionWeights = Vector.softmax( attentionScores );
    const weightedSumVector = computeWeightedSumVector( attentionWeights, valueVectors );
    const logitsVector = computeLogitsVector( weightedSumVector );
    const probabilities = Vector.softmax( logitsVector );
    return {
        probabilities: probabilities,
        attentionWeights: attentionWeights,
        weightedSumVector: weightedSumVector,
        queryVectors: queryVectors,
        keyVectors: keyVectors,
        valueVectors: valueVectors,
        inputEmbeddingVectors: inputEmbeddingVectors,
        lastQueryVector: lastQueryVector
    };
}
// SGD step on W_out: dLoss/dW_out[ v ][ d ] = gradLogits[ v ] * context[ d ].
function updateOutputLayerWeights( probabilities, targetIndex, weightedSumVector, learningRate ) {
    const gradientLossVector = calculateCrossEntropyGradient( probabilities, targetIndex );
    for ( let vocabIndex = 0 ; vocabIndex < vocabularySize ; vocabIndex++ ) {
        for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
            matrixOutput.data[ vocabIndex ][ dimensionIndex ] -= learningRate * gradientLossVector.data[ vocabIndex ] * weightedSumVector.data[ dimensionIndex ];
        }
    }
}
function trainModel( inputTokenIdentifiers, targetIndex, learningRate ) {
    const {
        probabilities,
        attentionWeights,
        weightedSumVector,
        valueVectors
    } = forwardPass( inputTokenIdentifiers );
    const lossValue = calculateCrossEntropyLoss( probabilities, targetIndex );
    // Gradient of the loss with respect to weightedSumVector: dLoss/dWeightedSum = W_out^T * gradLoss.
    // Computed before updating W_out so the backward pass uses the same weights as the forward pass.
    const gradientLossVector = calculateCrossEntropyGradient( probabilities, targetIndex );
    const gradientWeightedSumVector = new Vector( embeddingDimension );
    for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
        let sum = 0;
        for ( let vocabIndex = 0 ; vocabIndex < vocabularySize ; vocabIndex++ ) {
            sum += matrixOutput.data[ vocabIndex ][ dimensionIndex ] * gradientLossVector.data[ vocabIndex ];
        }
        gradientWeightedSumVector.data[ dimensionIndex ] = sum;
    }
    updateOutputLayerWeights( probabilities, targetIndex, weightedSumVector, learningRate );
    // Backpropagate to the input embeddings, weighted by the attention weights
    // already computed in the forward pass. This simplification skips the
    // W_V, W_Q and W_K paths; see the TODO below.
    for ( let index = 0 ; index < valueVectors.length ; index++ ) {
        const gradientVector = new Vector( embeddingDimension );
        for ( let dimensionIndex = 0 ; dimensionIndex < embeddingDimension ; dimensionIndex++ ) {
            // gradVec = gradWeightedSum * attentionWeights[ index ]
            gradientVector.data[ dimensionIndex ] = gradientWeightedSumVector.data[ dimensionIndex ] * attentionWeights.data[ index ];
        }
        // Update the embedding corresponding to inputTokenIdentifiers[ index ].
        embeddingsInstance.updateEmbedding( inputTokenIdentifiers[ index ], gradientVector, learningRate );
    }
    // TODO: implement updates for matrixQuery, matrixKey, matrixValue similarly
    // (a sketch for matrixValue follows below).
    return lossValue;
}
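// A minimal sketch of one of the missing updates above: the gradient step for
// matrixValue. Since weightedSum[ d ] = sum_i a_i * ( W_V x_i )[ d ], the chain rule gives
//   dLoss/dW_V[ d ][ e ] = gradWeightedSum[ d ] * sum_i a_i * x_i[ e ].
// Assumes the caller passes attentionWeights, inputEmbeddingVectors and
// gradientWeightedSumVector from the same forward/backward pass; it is not
// wired into trainModel here.
function updateValueMatrixWeights( attentionWeights, inputEmbeddingVectors, gradientWeightedSumVector, learningRate ) {
    for ( let rowIndex = 0 ; rowIndex < embeddingDimension ; rowIndex++ ) {
        for ( let columnIndex = 0 ; columnIndex < embeddingDimension ; columnIndex++ ) {
            // Attention-weighted sum of the input activations for this column.
            let weightedInputSum = 0;
            for ( let index = 0 ; index < inputEmbeddingVectors.length ; index++ ) {
                weightedInputSum += attentionWeights.data[ index ] * inputEmbeddingVectors[ index ].data[ columnIndex ];
            }
            matrixValue.data[ rowIndex ][ columnIndex ] -= learningRate * gradientWeightedSumVector.data[ rowIndex ] * weightedInputSum;
        }
    }
}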
// Teacher-forced training: every prefix of every sequence predicts its next token.
function trainOnMultipleSequences( sequenceArray, numberOfEpochs, learningRateValue ) {
    for ( let currentEpoch = 0 ; currentEpoch < numberOfEpochs ; currentEpoch++ ) {
        let totalLossValue = 0;
        let totalStepsCount = 0;
        for ( let sequenceIndex = 0 ; sequenceIndex < sequenceArray.length ; sequenceIndex++ ) {
            const currentSequence = sequenceArray[ sequenceIndex ];
            // For each token in the sequence except the first, train on the preceding prefix.
            for ( let tokenIndex = 1 ; tokenIndex < currentSequence.length ; tokenIndex++ ) {
                const inputTokens = currentSequence.slice( 0, tokenIndex );
                const targetTokenIndex = currentSequence[ tokenIndex ];
                const lossValue = trainModel( inputTokens, targetTokenIndex, learningRateValue );
                totalLossValue += lossValue;
                totalStepsCount++;
            }
        }
        if ( currentEpoch % 100 === 0 ) {
            writeToPage( ".training", "p", "Epoch " + currentEpoch + " Average Loss: " + ( totalLossValue / totalStepsCount ).toFixed( 4 ) );
            scrollToBottom( ".training" );
        }
    }
}
// Map a word to its token id, capitalizing the first letter to match the vocabulary keys.
function getTokenIdentifier( word ) {
    const capitalizedWord = word.charAt( 0 ).toUpperCase() + word.slice( 1 );
    if ( vocabulary[ capitalizedWord ] !== undefined ) {
        return vocabulary[ capitalizedWord ];
    }
    throw new Error( "Word not in vocabulary: " + word );
}
// Greedy decoding: run the forward pass and return the highest-probability word.
function predictNextWordGivenInput( inputWordArray ) {
    const tokenIdentifierArray = inputWordArray.map( getTokenIdentifier );
    const { probabilities } = forwardPass( tokenIdentifierArray );
    let maximumProbability = -Infinity;
    let maximumProbabilityIndex = -1;
    for ( let index = 0 ; index < probabilities.data.length ; index++ ) {
        if ( probabilities.data[ index ] > maximumProbability ) {
            maximumProbability = probabilities.data[ index ];
            maximumProbabilityIndex = index;
        }
    }
    return indexToWord[ maximumProbabilityIndex ];
}
// Duplicated sequences weight the data: "Bench" follows "On The" three times, "Mat" twice.
const trainingSequences = [
    [ vocabulary.The, vocabulary.Book, vocabulary.Is, vocabulary.Great ],
    [ vocabulary.The, vocabulary.Book, vocabulary.Is, vocabulary.Great ],
    [ vocabulary.The, vocabulary.Cat, vocabulary.Sat, vocabulary.On, vocabulary.The, vocabulary.Mat ],
    [ vocabulary.The, vocabulary.Cat, vocabulary.Sat, vocabulary.On, vocabulary.The, vocabulary.Bench ],
    [ vocabulary.The, vocabulary.Cat, vocabulary.Sat, vocabulary.On, vocabulary.The, vocabulary.Bench ],
    [ vocabulary.The, vocabulary.Cat, vocabulary.Sat, vocabulary.On, vocabulary.The, vocabulary.Bench ],
    [ vocabulary.The, vocabulary.Cat, vocabulary.Sat, vocabulary.On, vocabulary.The, vocabulary.Mat ]
];
// Create reverse mapping from token id to word.
const indexToWord = {};
for ( const word in vocabulary ) {
    if ( vocabulary.hasOwnProperty( word ) ) {
        indexToWord[ vocabulary[ word ] ] = word;
    }
}
writeToPage( ".trainingsData", "h2", "Training data" );
// Convert each training sequence back to words and write it to the page.
for ( const sequence of trainingSequences ) {
    const sequenceWords = sequence.map( index => indexToWord[ index ] );
    writeToPage( ".trainingsData", "p", sequenceWords.join( " " ) );
}
const totalEpochsCount = 6000;
const learningRateValue = 0.01;
writeToPage( ".training", "h2", "Training" );
trainOnMultipleSequences( trainingSequences, totalEpochsCount, learningRateValue );
// Keeps a scrollable log element pinned to its newest entry.
function scrollToBottom( selector ) {
    const element = document.querySelector( selector );
    element.scrollTop = element.scrollHeight;
}
// Appends a new element of the given type, with all remaining arguments joined as its text.
function writeToPage( selector, elementType, ...args ) {
    const element = document.createElement( elementType );
    element.innerText = args.join( " " );
    document.querySelector( selector ).appendChild( element );
}
writeToPage(".prediction", "h2", "Prediction");
writeToPage( ".prediction", "p", "Next word after 'The Book is':", predictNextWordGivenInput( [ "The", "Book", "Is" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The Cat Sat':", predictNextWordGivenInput( [ "The", "Cat", "Sat" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The Cat':", predictNextWordGivenInput( [ "The", "Cat" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The':", predictNextWordGivenInput( [ "The" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The Cat Sat On':", predictNextWordGivenInput( [ "The", "Cat", "Sat", "On" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The Cat Sat On the':", predictNextWordGivenInput( [ "The", "Cat", "Sat", "On", "The" ] ) );
writeToPage( ".prediction", "p", "Next word after 'The Book':", predictNextWordGivenInput( [ "The", "Book" ] ) );