Skip to content

Commit

Permalink
Semantic module
Browse files Browse the repository at this point in the history
  • Loading branch information
kariminf committed Jan 14, 2023
1 parent 7fbdebc commit 27eaf58
Show file tree
Hide file tree
Showing 8 changed files with 772 additions and 221 deletions.
96 changes: 68 additions & 28 deletions src/_jslml.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
* @module jslml
*/

import { NormalModule } from "webpack";


//==========================================
// ACTIVATION FUNCTIONS
Expand Down Expand Up @@ -32,10 +30,14 @@ const Activation = {
*/
/**
 * Dot product of two equal-length numeric vectors.
 * Seeds the reduction with 0 so empty vectors yield 0 instead of throwing
 * TypeError ("Reduce of empty array with no initial value").
 * @param {number[]} X
 * @param {number[]} Y assumed at least as long as X
 * @returns {number} sum of element-wise products
 */
const dot = (X, Y) => X.map((x, i) => x * Y[i]).reduce((e0, e1) => e0 + e1, 0);

const matmul = (X, Y) => X.map((x, i) => x * Y[i]);
// const matmul = (X, Y) => X.map((x, i) => x * Y[i]);

/** Scale every component of vector V by the scalar s. */
const svmul = (s, V) => V.map(v => s * v);

/**
 * Scaled dot product: dot(X, Y) divided by sqrt of the vector length
 * (the attention scaling factor).
 */
const scaledot = (X, Y) => {
  const scale = Math.sqrt(X.length);
  return dot(X, Y) / scale;
};

/** Create a dense length-N array filled with 0. */
const zeros = N => [...new Array(N)].map(() => 0);

/**
 * Element-wise sum of three equal-length vectors.
 * @param {number[]} X
 * @param {number[]} Y
 * @param {number[]} Z
 * @returns {number[]} X[i] + Y[i] + Z[i] for each index i
 */
const vplus3 = (X, Y, Z) => X.map((x, i) => x + Y[i] + Z[i]);

/** Element-wise sum of two equal-length vectors. */
const vplus2 = (X, Y) => {
  const out = [];
  for (let i = 0; i < X.length; i++) out.push(X[i] + Y[i]);
  return out;
};

// const vminus2 = (X, Y) => X.map((x, i) => x - Y[i]);

/**
 * Sum of a numeric vector.
 * Seeds the reduction with 0 so vsum([]) === 0 instead of throwing
 * TypeError on an empty array.
 * @param {number[]} X
 * @returns {number}
 */
const vsum = (X) => X.reduce((c, e) => c + e, 0);

const norm = (X) => {
// const mean2D = X => X.reduce((c, x) => vplus2(c, x), zeros(X[0].length)).map(x => x /X.length);

// const vminuspow2 = (X, Y) => X.map((x, i) => Math.pow(x - Y[i], 2));

// const sigma2D = (X, Mu, eps) => X.reduce((c, x) => vplus2(c, vminuspow2(x, Mu)), zeros(X[0].length))
// .map(x => Math.sqrt(eps + x/X.length));

// const norm = (X) => {
// const mean = vsum(X)/X.length;
// const sigma = Math.sqrt(X.reduce((c, e) => c + Math.pow(e - mean, 2), 0)/X.length);
// return X.map(e => (e - mean)/sigma);
// };

/**
 * Normalize a vector to zero mean / unit variance, then apply the learned
 * affine transform: beta + gamma * normalized (layer-norm per position).
 * Removes the unreachable first `return` left over in the diff, which
 * referenced an undefined variable `mu`.
 * @param {number} beta learned shift parameter
 * @param {number} gamma learned scale parameter
 * @param {number[]} X input vector
 * @returns {number[]} beta + gamma * (x - mean) / sigma, element-wise
 */
const norm_param = (X, beta, gamma) => {
  const mean = X.reduce((c, e) => c + e, 0) / X.length;
  const sigma = Math.sqrt(X.reduce((c, e) => c + Math.pow(e - mean, 2), 0) / X.length);
  return X.map(e => beta + (gamma * (e - mean) / sigma));
};

/** Transpose a 2D matrix: rows become columns. Assumes X is non-empty and rectangular. */
const transpose = (X) => {
  const nCols = X[0].length;
  const out = [];
  for (let c = 0; c < nCols; c++) out.push(X.map(row => row[c]));
  return out;
};
Expand All @@ -68,7 +85,7 @@ const transpose = (X) => X[0].map((col, i) => X.map(row => row[i]));

class Perceptron {

constructor(weights, bias, activate=Activation.sigmoid, cls_names=[], th=0.5){
constructor(weights, bias, activate=null, cls_names=[], th=0.5){
this.weights = weights;
this.bias = bias;
this.activate = activate;
Expand All @@ -85,13 +102,13 @@ class Perceptron {
let cls = 0;
if(this.muliclass){
cls = [];
for(let i = 0; i < this.w.length; i++) cls.push(dot(this.weights[i], x) + this.bias[i]);
for(let i = 0; i < this.weights.length; i++) cls.push(dot(this.weights[i], x) + this.bias[i]);
}
else { //binary
cls = dot(this.weights, x) + this.bias;
}

cls = this.activate(cls);
if (this.activate) cls = this.activate(cls);
return prob ? cls : this.get_class(cls);
}

Expand All @@ -102,15 +119,29 @@ class Perceptron {

}

/**
 * Per-position layer normalization with learned shift (beta) and scale (gamma).
 */
class LayerNormalization {
	/**
	 * @param {number[]} beta per-position shift parameters
	 * @param {number[]} gamma per-position scale parameters
	 * @param {number} epsilon numerical-stability constant
	 *        NOTE(review): stored but not used by this class's own code — confirm
	 *        whether norm_param is meant to consume it.
	 */
	constructor(beta, gamma, epsilon=0.001){
		this.beta = beta;
		this.gamma = gamma;
		this.epsilon = epsilon;
	}

	/**
	 * Normalize each row of X with that row's beta/gamma parameters.
	 * @param {number[][]} X matrix: one row vector per position
	 * @returns {number[][]} normalized matrix
	 */
	predict(X){
		const rows = [];
		X.forEach((x, i) => {
			rows.push(norm_param(x, this.beta[i], this.gamma[i]));
		}, this);
		return rows;
	}
}

class JslBERTBlock {
/**
*
* @param {Perceptron[[]]} encoders h head in each head 3 Perceptrons (Q, K, V)
*/
constructor(encoders, hp, ffp){
constructor(encoders, hp, ffp, ln1, ln2){
this.encoders = encoders;
this.hp = hp;
this.ffp = ffp;
this.ln1 = ln1;
this.ln2 = ln2;
}

predict(Q, K, V, M){ //vectors [[WORD-ENC], [], ...]
Expand All @@ -120,9 +151,9 @@ class JslBERTBlock {
let V_enc_head = []; // h X m X d

this.encoders.forEach(QKVenc => {
Q_enc_head.add(QKVenc[0].predict_all(Q));
K_enc_head.add(QKVenc[1].predict_all(K));
V_enc_head.add(QKVenc[2].predict_all(V));
Q_enc_head.push(QKVenc[0].predict_all(Q));
K_enc_head.push(QKVenc[1].predict_all(K));
V_enc_head.push(QKVenc[2].predict_all(V));
}, this);

Q_enc_head = transpose(Q_enc_head); //n X h X d
Expand All @@ -134,12 +165,19 @@ class JslBERTBlock {
Q_enc_head.forEach(Qi => { //over n target words
let headi = [];
for(let h = 0; h < this.encoders.length; h++){//over h heads
const perc = [];
let perc = [];
K_enc_head.forEach((Kj, j) => {//over m source words
perc.push(scaledot(Qi[h], Kj[h]) * M[j]);
perc.push(scaledot(Qi[h], Kj[h]));//* M[j]
}, this);
perc = softmax(perc);
let Ri = matmul(perc, Vj[h]);
perc = Activation.softmax(perc);

const Ri = V_enc_head.reduce((C, Vj, j) => vplus2(C, svmul(perc[j], Vj[h])),
zeros(V_enc_head[0][0].length));

// let Ri = zeros(V_enc_head[0][0].length);
// V_enc_head.forEach((Vj, j) => {//over m source words
// Ri = vplus2(Ri, svmul(perc[j], Vj[h]))
// }, this);
headi = headi.concat(Ri);
}
result.push(headi);
Expand All @@ -149,18 +187,20 @@ class JslBERTBlock {
result = this.hp.predict_all(result);

//add and norm
result = result.map((r, i) => norm(vplus2(r, V[i])));
result = result.map((r, i) => vplus2(r, V[i]));
result = this.ln1.predict(result);
//Feed-Forward
result = this.ffp.predict_all(result);
//add and norm
return result.map((r, i) => norm(vplus2(r, V[i])));
result = result.map((r, i) => vplus2(r, V[i]));
return this.ln2.predict(result);
}
}

class JslBERT {
/**
*
* @param {Perceptron[]} encoders encoders for inputs embedding
* @param {Perceptron[]} encoders encoders for inputs embedding (token, position, segment)
* @param {JslBERTBlock[]} blocks stacked blocks for incoding
*/
constructor(encoders, blocks){
Expand All @@ -172,19 +212,19 @@ class JslBERT {

//incoding inputs
let X_enc = [];
X.forEach((word, i) =>{
let tok_emb = this.encoders[0].predict(word[0]);
let pos_emb = this.encoders[1].predict(word[1]);
let seg_emb = this.encoders[2].predict(word[2]);
let word_emb = vplus3(tok_emb, pos_emb, seg_emb);
X_enc.push(word_emb);
X.forEach((x, i) =>{
let tok_emb = this.encoders[0].predict(x[0]);
let pos_emb = this.encoders[1].predict(x[1]);
let seg_emb = this.encoders[2].predict(x[2]);
let x_emb = vplus3(tok_emb, pos_emb, seg_emb);
X_enc.push(x_emb);
}, this);

//the mask, when the token is all zeros then 0; otherwise 1
let M = X.map(x => vsum(x[0]) > 0? 1 : 0);
//let M = X.map(x => vsum(x[0]) > 0? 1 : 0);

this.blocks.forEach(block => {
X_enc = block.predict(X_enc, X_enc, X_enc, M);
X_enc = block.predict(X_enc, X_enc, X_enc);
}, this);

return X_enc;
Expand Down Expand Up @@ -256,4 +296,4 @@ class BeamMEMM {
}


export {Activation, Perceptron, TagEncoder, BeamMEMM, JslBERTBlock, JslBERT};
export {Activation, LayerNormalization, Perceptron, TagEncoder, BeamMEMM, JslBERTBlock, JslBERT};
Loading

0 comments on commit 27eaf58

Please sign in to comment.