Skip to content

Commit

Permalink
add masks into jslBERT
Browse files Browse the repository at this point in the history
  • Loading branch information
kariminf committed Jan 19, 2023
1 parent 2922a5c commit 00fae8e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
50 changes: 32 additions & 18 deletions src/_jslml.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@ const Activation = {
return s.map(e => e/sum);
},
masked_softmax: (X, M) => {
const a = - 1000000000;
// const max = Math.max(...X);
const s = X.map((x, i) => Math.exp(x + (1-M[i])*a));
// const a = - 1000000000;
// // const max = Math.max(...X);
// const s = X.map((x, i) => Math.exp(x + (1-M[i])*a));
// const sum = s.reduce((c, e) => c + e);
// return s.map(e => e/sum);

const max = Math.max(...X);
const s = X.map((x, i) => M[i]? Math.exp(x-max): 0);
const sum = s.reduce((c, e) => c + e);
return s.map(e => e/sum);
}
Expand All @@ -45,6 +50,10 @@ const scaledot = (X, Y) => dot(X, Y)/Math.sqrt(X.length);

const zeros = N => new Array(N).fill(0);

const uniform = N => new Array(N).fill(1/N);

const full = (N, x) => new Array(N).fill(x);

/**
*
* @param {*} X
Expand Down Expand Up @@ -151,8 +160,11 @@ class JslBERTBlock {
this.ln2 = ln2;
}

predict(Q, K, V, Mq=null, Mv=null){ //vectors [[WORD-ENC], [], ...]

predict(Q, K, V, Mv, Mq){ //vectors [[WORD-ENC], [], ...]

const Mvdef = Array.isArray(Mv) && (Mv.length == V.length),
Mqdef = Array.isArray(Mq) && (Mq.length == Q.length);

let Q_enc_head = []; // h X n X d
let K_enc_head = []; // h X m X d
let V_enc_head = []; // h X m X d
Expand All @@ -171,24 +183,26 @@ class JslBERTBlock {

Q_enc_head.forEach((Qi, i) => { //over n target words
let headi = [];
// if (Mq[i]){
for(let h = 0; h < this.encoders.length; h++){//over h heads
let perc = [];
K_enc_head.forEach((Kj, j) => {//over m source words
// if (Mv[j]) perc.push(scaledot(Qi[h], Kj[h]));//* M[j]
// else perc.push(0);
perc.push(scaledot(Qi[h], Kj[h]));
}, this);
perc = Activation.softmax(perc, Mv);
if (!Mqdef || Mq[i]){
K_enc_head.forEach((Kj, j) => {//over m source words
if (Mvdef && !Mv[j]) perc.push(0)
else perc.push(scaledot(Qi[h], Kj[h]));
}, this);
if (Mvdef) perc = Activation.masked_softmax(perc, Mv);
else perc = Activation.softmax(perc);
}else{//if no attention for this query, we add uniform attention
perc = uniform(K_enc_head.length);
}

// console.log(perc);

const Ri = V_enc_head.reduce((C, Vj, j) => vplus2(C, svmul(perc[j], Vj[h])),
zeros(V_enc_head[0][0].length));

headi = headi.concat(Ri);
}
// }else{
// headi = zeros(V_enc_head[0][0].length * this.encoders.length);
// }

result.push(headi);

Expand Down Expand Up @@ -225,7 +239,7 @@ class JslBERT {
this.blocks = blocks;
}

predict(X, M=null){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]
predict(X, masked=false){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]

//incoding inputs
let X_enc = [];
Expand All @@ -239,8 +253,8 @@ class JslBERT {

// console.log("Word", X_enc);

//the mask, when the token is all zeros then 0; otherwise 1
//let M = X.map(x => vsum(x[0]) > 0? 1 : 0);
//the mask, when the token starts with 1 then 0; otherwise 1
const M = masked? X.map(x => 1 - x[0]): undefined;

this.blocks.forEach(block => {
X_enc = block.predict(X_enc, X_enc, X_enc, M, M);
Expand Down
10 changes: 5 additions & 5 deletions test/nodejs/test_embedding.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ const X = [
[[1,0,0,0,0,0,0], [0,0,0,0,0,1], [1, 0]]
]

// const Mask = [1,1,1,1,0,0]
const Mask = [1,1,1,1,0,0]

const wemb = brt.predict(X);
const wemb = brt.predict(X, Mask);

// console.log("EMB", wemb);
console.log("EMB", wemb);


// const p = new Perceptron([[1, 2, 3], [4, 5, 6]], [1, -1]);
Expand All @@ -106,5 +106,5 @@ const wemb = brt.predict(X);

// console.log(EngSem);
// console.log(EngSem.__word2BERTCodes("cat"));
console.log(EngSem.word_embedding("cat"));
console.log(EngSem.word_embedding("Dog"));
// console.log(EngSem.word_embedding("cat"));
// console.log(EngSem.word_embedding("Dog"));

0 comments on commit 00fae8e

Please sign in to comment.