Commit 2922a5c
fix some issues with jslBERT + final weights
kariminf committed Jan 18, 2023
1 parent 27eaf58 · commit 2922a5c
Showing 3 changed files with 678 additions and 550 deletions.
src/_jslml.mjs: 47 changes (33 additions, 14 deletions)
@@ -18,6 +18,13 @@ const Activation = {
         const s = X.map(x => Math.exp(x - max));
         const sum = s.reduce((c, e) => c + e);
         return s.map(e => e/sum);
+    },
+    masked_softmax: (X, M) => {
+        const a = -1000000000;
+        // const max = Math.max(...X);
+        const s = X.map((x, i) => Math.exp(x + (1 - M[i])*a));
+        const sum = s.reduce((c, e) => c + e);
+        return s.map(e => e/sum);
     }
 };
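
For context (not part of the commit): masked_softmax biases each masked position (M[i] = 0) with a large negative constant, so its exponential underflows to zero and it receives no attention weight. A minimal sketch of the expected behavior, mirroring the function added above:

```js
// Sketch, not part of the commit: expected behavior of the added masked_softmax.
const masked_softmax = (X, M) => {
    const a = -1000000000;                                   // large negative bias
    const s = X.map((x, i) => Math.exp(x + (1 - M[i]) * a)); // exp(-1e9) underflows to 0
    const sum = s.reduce((c, e) => c + e);
    return s.map(e => e / sum);
};

console.log(masked_softmax([2, 1, 3], [1, 1, 0]));
// ≈ [0.731, 0.269, 0]: the masked third score gets zero weight
```

One caveat: unlike the softmax above it, the added function comments out the max-subtraction stabilizer, so extremely large logits could still overflow.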

@@ -70,10 +77,10 @@ const vsum = (X) => X.reduce((c, e) => c + e);
 //     return X.map(e => (e - mean)/sigma);
 // };

-const norm_param = (X, beta, gamma) => {
+const norm_param = (X, beta, gamma, epsilon=0.00001) => {
     const mean = vsum(X)/X.length;
-    const sigma = Math.sqrt(X.reduce((c, e) => c + Math.pow(e - mean, 2), 0)/X.length);
-    return X.map(e => beta + (gamma * (e - mean)/sigma));
+    const sigma = Math.sqrt(epsilon + X.reduce((c, e) => c + Math.pow(e - mean, 2), 0)/X.length);
+    return X.map((e, i) => beta[i] + (gamma[i] * (e - mean)/sigma));
 };

 const transpose = (X) => X[0].map((col, i) => X.map(row => row[i]));
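
For context (not part of the commit): the reworked norm_param treats beta and gamma as per-feature vectors instead of scalars, and the epsilon under the square root keeps sigma strictly positive even for constant inputs. A quick sketch:

```js
// Sketch, not part of the commit: norm_param as changed above.
const vsum = (X) => X.reduce((c, e) => c + e);
const norm_param = (X, beta, gamma, epsilon=0.00001) => {
    const mean = vsum(X)/X.length;
    const sigma = Math.sqrt(epsilon + X.reduce((c, e) => c + Math.pow(e - mean, 2), 0)/X.length);
    return X.map((e, i) => beta[i] + (gamma[i] * (e - mean)/sigma));
};

console.log(norm_param([1, 2, 3], [0, 0, 0], [1, 1, 1])); // ≈ [-1.22, 0, 1.22]
console.log(norm_param([5, 5, 5], [0, 0, 0], [1, 1, 1])); // [0, 0, 0], no division by zero
```
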
@@ -127,7 +134,7 @@ class LayerNormalization {
     }

     predict(X){// [[...], [...], ...]
-        return X.map((x, i) => norm_param(x, this.beta[i], this.gamma[i]), this);
+        return X.map((x, i) => norm_param(x, this.beta, this.gamma, this.epsilon), this);
     }
 }
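
For context (not part of the commit): the old call handed norm_param the scalars this.beta[i] and this.gamma[i], one pair per token; the fix passes the full per-feature vectors (plus the new epsilon) to every row, which matches how layer normalization is defined. Each token vector is standardized independently but scaled and shifted by the same learned parameters:

```js
// Sketch, not part of the commit, reusing the norm_param sketch above:
// every row now shares the same per-feature beta/gamma.
const ln = { beta: [0, 0], gamma: [1, 1], epsilon: 0.00001 };
const X = [[1, 3], [2, 6]];
console.log(X.map(x => norm_param(x, ln.beta, ln.gamma, ln.epsilon)));
// ≈ [[-1, 1], [-1, 1]]: each row standardized, then shared scale/shift applied
```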

@@ -144,7 +151,7 @@ class JslBERTBlock {
         this.ln2 = ln2;
     }

-    predict(Q, K, V, M){ //vectors [[WORD-ENC], [], ...]
+    predict(Q, K, V, Mq=null, Mv=null){ //vectors [[WORD-ENC], [], ...]

         let Q_enc_head = []; // h X n X d
         let K_enc_head = []; // h X m X d
@@ -162,37 +169,47 @@

         let result = [];

-        Q_enc_head.forEach(Qi => { //over n target words
+        Q_enc_head.forEach((Qi, i) => { //over n target words
             let headi = [];
+            // if (Mq[i]){
             for(let h = 0; h < this.encoders.length; h++){//over h heads
                 let perc = [];
                 K_enc_head.forEach((Kj, j) => {//over m source words
-                    perc.push(scaledot(Qi[h], Kj[h]));//* M[j]
+                    // if (Mv[j]) perc.push(scaledot(Qi[h], Kj[h]));//* M[j]
+                    // else perc.push(0);
+                    perc.push(scaledot(Qi[h], Kj[h]));
                 }, this);
-                perc = Activation.softmax(perc);
+                perc = Activation.softmax(perc, Mv);

                 const Ri = V_enc_head.reduce((C, Vj, j) => vplus2(C, svmul(perc[j], Vj[h])),
                                              zeros(V_enc_head[0][0].length));

                 // let Ri = zeros(V_enc_head[0][0].length);
                 // V_enc_head.forEach((Vj, j) => {//over m source words
                 //     Ri = vplus2(Ri, svmul(perc[j], Vj[h]))
                 // }, this);
                 headi = headi.concat(Ri);
             }
+            // }else{
+            //     headi = zeros(V_enc_head[0][0].length * this.encoders.length);
+            // }

             result.push(headi);

         }, this);

         result = this.hp.predict_all(result);

+        // console.log("V", V);

+        // console.log("lma", result);

         //add and norm
         result = result.map((r, i) => vplus2(r, V[i]));
         result = this.ln1.predict(result);
+        // console.log("addnorm1", result);
         //Feed-Forward
         result = this.ffp.predict_all(result);
+        // console.log("FFP", result);
         //add and norm
         result = result.map((r, i) => vplus2(r, V[i]));
+        // console.log("ADD", result);
         return this.ln2.predict(result);
     }
 }
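
For context (not part of the commit): two observations on this hunk. First, Activation.softmax takes a single argument, so the Mv passed to it here is ignored; the commented-out branches and the new masked_softmax suggest the masking path was still being wired in. Second, each head computes standard scaled dot-product attention: scores between one query and all keys, a softmax, then a weighted sum of the value vectors. A self-contained sketch of that core; the scaledot/svmul/vplus2/zeros helpers are hypothetical reconstructions matching how they are used above:

```js
// Sketch, not part of the commit: one attention head for a single query.
const scaledot = (q, k) => q.reduce((c, e, i) => c + e * k[i], 0) / Math.sqrt(q.length);
const svmul = (s, V) => V.map(v => s * v);          // scalar * vector
const vplus2 = (A, B) => A.map((a, i) => a + B[i]); // elementwise sum
const zeros = (n) => new Array(n).fill(0);

const q = [1, 0];             // one query vector (d = 2)
const K = [[1, 0], [0, 1]];   // two source key vectors
const V = [[10, 0], [0, 10]]; // the matching value vectors

let perc = K.map(k => scaledot(q, k)); // scaled dot-product scores
const max = Math.max(...perc);         // stabilized softmax
const s = perc.map(x => Math.exp(x - max));
const sum = s.reduce((c, e) => c + e);
perc = s.map(e => e / sum);

const Ri = V.reduce((C, Vj, j) => vplus2(C, svmul(perc[j], Vj)), zeros(2));
console.log(Ri); // ≈ [6.7, 3.3]: most weight on the first source position
```

With Mv = [1, 0] and masked_softmax in place of softmax, the second position would get zero weight and Ri would collapse to V[0].
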
@@ -208,7 +225,7 @@ class JslBERT {
         this.blocks = blocks;
     }

-    predict(X){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]
+    predict(X, M=null){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]

         //incoding inputs
         let X_enc = [];
@@ -220,11 +237,13 @@
             X_enc.push(x_emb);
         }, this);

+        // console.log("Word", X_enc);

         //the mask, when the token is all zeros then 0; otherwise 1
         //let M = X.map(x => vsum(x[0]) > 0? 1 : 0);

         this.blocks.forEach(block => {
-            X_enc = block.predict(X_enc, X_enc, X_enc);
+            X_enc = block.predict(X_enc, X_enc, X_enc, M, M);
         }, this);

         return X_enc;
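
For context (not part of the commit): after this change the padding mask comes from the caller instead of being recomputed inside predict; the commented-out line records the convention (1 for a real token, 0 when the token embedding is all zeros). A hedged usage sketch, where model stands for a hypothetical, already-constructed JslBERT instance:

```js
// Sketch, not part of the commit: building the mask that predict(X, M) now expects.
const vsum = (X) => X.reduce((c, e) => c + e);

const X = [
    [[5, 2, 0], [1, 0, 0], [0, 1, 0]], // [[TOKEN], [POSITION], [SEGMENT]]
    [[0, 0, 0], [0, 1, 0], [0, 1, 0]], // padding: all-zero token vector
];
const M = X.map(x => vsum(x[0]) > 0 ? 1 : 0);
console.log(M); // [1, 0]
// model.predict(X, M);
```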