Skip to content

Commit

Permalink
new weights with masks over padding
Browse files Browse the repository at this point in the history
  • Loading branch information
kariminf committed Jan 24, 2023
1 parent 00fae8e commit db87ec3
Show file tree
Hide file tree
Showing 3 changed files with 668 additions and 657 deletions.
12 changes: 8 additions & 4 deletions src/_jslml.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class JslBERTBlock {
}, this);
if (Mvdef) perc = Activation.masked_softmax(perc, Mv);
else perc = Activation.softmax(perc);
}else{//if no attention for this query, we add uniform attention
}else{//if no attention for this query, we add a uniform attention
perc = uniform(K_enc_head.length);
}

Expand Down Expand Up @@ -234,12 +234,13 @@ class JslBERT {
* @param {Perceptron[]} encoders encoders for inputs embedding (token, position, segment)
* @param {JslBERTBlock[]} blocks stacked blocks for incoding
*/
constructor(encoders, blocks){
constructor(encoders, blocks, masked=false){
this.encoders = encoders;
this.blocks = blocks;
this.masked = masked;
}

predict(X, masked=false){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]
predict(X){// X = [[[TOKEN], [POSITION], [SEGMENT]], [[], [], []], ...]

//incoding inputs
let X_enc = [];
Expand All @@ -254,7 +255,9 @@ class JslBERT {
// console.log("Word", X_enc);

//the mask, when the token starts with 1 then 0; otherwise 1
const M = masked? X.map(x => 1 - x[0]): undefined;
const M = this.masked? X.map(x => 1 - x[0][0]): undefined;

// console.log("Mask", M);

this.blocks.forEach(block => {
X_enc = block.predict(X_enc, X_enc, X_enc, M, M);
Expand All @@ -274,6 +277,7 @@ class JslBERT {
// }

class TagEncoder{

constructor(tag_list, embedding={predict: x => x}){//new EmptyNeuron()
this.tag_list = tag_list;
this.embedding = embedding;
Expand Down
Loading

0 comments on commit db87ec3

Please sign in to comment.