Skip to content

Commit

Permalink
add BERT like class
Browse files Browse the repository at this point in the history
  • Loading branch information
kariminf committed Jan 2, 2023
1 parent 2bbb60c commit fe377c7
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 31 deletions.
125 changes: 118 additions & 7 deletions src/_jslml.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
* @module jslml
*/

// NOTE(review): unused import — "webpack" is a build tool unrelated to this ML
// module and nothing below references NormalModule; this looks like an
// accidental editor auto-import and should be removed.
import { NormalModule } from "webpack";


//==========================================
// ACTIVATION FUNCTIONS
Expand All @@ -11,6 +13,8 @@
const Activation = {
tanh: (X) => X.map(x => 2/(1 + Math.exp(-2*x)) - 1),
sigmoid: (X) => X.map(x => 1/(1 + Math.exp(-x))),
linear: (X) => X,
relu: (X) => X.map(x => x < 0 ? 0 : x),
softmax: (X) => {
const max = Math.max(...X);
const s = X.map(x => Math.exp(x - max));
Expand All @@ -28,6 +32,10 @@ const Activation = {
*/
/**
 * Dot product of two equal-length vectors.
 * BUG FIX: the original `map(...).reduce(...)` had no initial value, so it
 * threw a TypeError on empty vectors; an empty dot product is 0.
 *
 * @param {number[]} X first vector
 * @param {number[]} Y second vector (same length as X)
 * @return {number} sum of pairwise products
 */
const dot = (X, Y) => X.reduce((acc, x, i) => acc + x * Y[i], 0);

/**
 * Element-wise (Hadamard) product of two equal-length vectors.
 * NOTE(review): despite the name this is NOT a matrix multiplication —
 * confirm that callers expect an element-wise product.
 *
 * @param {number[]} X first vector
 * @param {number[]} Y second vector (same length as X)
 * @return {number[]} vector of pairwise products
 */
const matmul = (X, Y) => X.map((value, idx) => value * Y[idx]);

/**
 * Scaled dot product: dot(X, Y) divided by the square root of the vector
 * length (the attention scaling factor).
 *
 * @param {number[]} X first vector
 * @param {number[]} Y second vector (same length as X)
 * @return {number} scaled dot product
 */
const scaledot = (X, Y) => {
  const scale = Math.sqrt(X.length);
  return dot(X, Y) / scale;
};

/**
 * Pairs up corresponding elements of two arrays.
 *
 * @param {Array} X first array (drives the output length)
 * @param {Array} Y second array (must be at least as long as X)
 * @return {Array[]} array of [X[i], Y[i]] pairs
 */
const zip = (X, Y) => X.map((element, index) => [element, Y[index]]);

/**
 * Returns the k entries with the highest values from a list of [key, value]
 * pairs, sorted by value in descending order.
 * BUG FIX: Array.prototype.sort() mutates in place; the original reordered
 * the caller's array as a side effect — sort a shallow copy instead.
 *
 * @param {Array[]} X list of [key, value] pairs
 * @param {number} k how many top entries to keep
 * @return {Array[]} the k pairs with the largest values
 */
const get_k_max = (X, k) => [...X].sort(([, v1], [, v2]) => v2 - v1).slice(0, k);

/**
 * Element-wise sum of three equal-length vectors.
 *
 * @param {number[]} X first vector
 * @param {number[]} Y second vector
 * @param {number[]} Z third vector
 * @return {number[]} X + Y + Z, element by element
 */
const vplus3 = (X, Y, Z) => X.map((value, idx) => value + Y[idx] + Z[idx]);

/**
 * Element-wise sum of two equal-length vectors.
 *
 * @param {number[]} X first vector
 * @param {number[]} Y second vector
 * @return {number[]} X + Y, element by element
 */
const vplus2 = (X, Y) => X.map((value, idx) => value + Y[idx]);

/**
 * Sum of a vector's elements.
 * BUG FIX: the original reduce() had no initial value and threw a TypeError
 * on empty vectors; an empty sum is 0.
 *
 * @param {number[]} X the vector
 * @return {number} sum of all elements (0 for [])
 */
const vsum = (X) => X.reduce((c, e) => c + e, 0);

/**
 * Standardizes a vector: (x - mean) / population standard deviation.
 * BUG FIX: the original returned `(e - mu)/sigma` but the mean was bound to
 * `mean`, so every call threw a ReferenceError on `mu`.
 * NOTE(review): a constant vector has sigma === 0 and still yields
 * NaN/Infinity; layer-norm implementations usually add a small epsilon —
 * confirm the intended behavior.
 *
 * @param {number[]} X the vector (must be non-empty)
 * @return {number[]} the standardized vector
 */
const norm = (X) => {
  const mean = vsum(X)/X.length;
  const sigma = Math.sqrt(X.reduce((c, e) => c + Math.pow(e - mean, 2), 0)/X.length);
  return X.map(e => (e - mean)/sigma);
};

/**
 * Transposes a rectangular 2D array (rows become columns).
 *
 * @param {Array[]} X non-empty 2D array
 * @return {Array[]} the transposed array
 */
const transpose = (X) => X[0].map((_, colIdx) => X.map((row) => row[colIdx]));


//==========================================
// NEURON API
// NEURAL API
//==========================================

class Neuron {
class Perceptron {

constructor(w, b, activate=Activation.sigmoid, cls_names=[], th=0.5){
this.w = w;
Expand Down Expand Up @@ -80,17 +102,106 @@ class Neuron {

}

class JslBERTBlock {
  /**
   * One BERT/Transformer encoder block: multi-head scaled dot-product
   * attention followed by a position-wise feed-forward layer, each with a
   * residual connection and layer normalization ("add & norm").
   *
   * @param {Perceptron[][]} encoders one entry per attention head; each entry
   *        holds 3 Perceptrons projecting the input to Q, K and V respectively
   * @param {Perceptron} hp  projection applied to the concatenated heads
   * @param {Perceptron} ffp position-wise feed-forward perceptron
   */
  constructor(encoders, hp, ffp){
    this.encoders = encoders;
    this.hp = hp;
    this.ffp = ffp;
  }

  /**
   * Applies the block to encoded word vectors.
   *
   * @param {number[][]} Q n target word encodings [[WORD-ENC], [], ...]
   * @param {number[][]} K m source word encodings
   * @param {number[][]} V m source word encodings
   * @param {number[]}   M source mask: 1 for a real token, 0 for padding
   * @return {number[][]} n re-encoded word vectors
   */
  predict(Q, K, V, M){
    let Q_enc_head = []; // h X n X d
    let K_enc_head = []; // h X m X d
    let V_enc_head = []; // h X m X d

    // BUG FIX: Array has no .add() (original threw a TypeError) — use .push().
    this.encoders.forEach(QKVenc => {
      Q_enc_head.push(QKVenc[0].predict_all(Q));
      K_enc_head.push(QKVenc[1].predict_all(K));
      V_enc_head.push(QKVenc[2].predict_all(V));
    });

    Q_enc_head = transpose(Q_enc_head); // n X h X d
    K_enc_head = transpose(K_enc_head); // m X h X d
    V_enc_head = transpose(V_enc_head); // m X h X d

    const nbHeads = this.encoders.length;
    let result = [];

    Q_enc_head.forEach(Qi => { // over the n target words
      let headi = [];
      for(let h = 0; h < nbHeads; h++){ // over the h heads
        // Attention scores against the m source words.
        // BUG FIX: masked positions now score -Infinity so softmax assigns
        // them weight exactly 0; the original multiplied the score by M[j],
        // which still left masked tokens with weight exp(0)/Z. Also, `perc`
        // was a const later reassigned, and bare `softmax` is not in scope
        // (it lives on Activation).
        const scores = K_enc_head.map((Kj, j) =>
          M[j] ? scaledot(Qi[h], Kj[h]) : -Infinity);
        const weights = Activation.softmax(scores);
        // BUG FIX: the original referenced an undefined `Vj` here; the
        // context vector is the weighted average of this head's value vectors.
        const d = V_enc_head[0][h].length;
        const Ri = new Array(d).fill(0);
        V_enc_head.forEach((Vj, j) => {
          for(let c = 0; c < d; c++) Ri[c] += weights[j] * Vj[h][c];
        });
        headi = headi.concat(Ri); // concatenate the heads
      }
      result.push(headi);
    });

    // project the concatenated heads back to the model dimension
    result = this.hp.predict_all(result);
    // add & norm (residual around the attention sub-layer)
    const attnOut = result.map((r, i) => norm(vplus2(r, V[i])));
    // feed-forward
    result = this.ffp.predict_all(attnOut);
    // add & norm — BUG FIX: the FFN residual is its own input (the attention
    // output), not the block input V, per the standard Transformer block.
    return result.map((r, i) => norm(vplus2(r, attnOut[i])));
  }
}

class JslBERT {
  /**
   * A minimal BERT-like encoder: input embedding (token + position + segment)
   * followed by a stack of self-attention encoder blocks.
   *
   * @param {Perceptron[]} encoders 3 embedding perceptrons, in order:
   *        [token, position, segment]
   * @param {JslBERTBlock[]} blocks stacked blocks for encoding
   */
  constructor(encoders, blocks){
    this.encoders = encoders;
    this.blocks = blocks;
  }

  /**
   * Encodes a tokenized input sequence.
   *
   * @param {number[][][]} X one entry per word:
   *        [[TOKEN], [POSITION], [SEGMENT]]
   * @return {number[][]} contextual encoding of each word
   */
  predict(X){
    // input embedding: sum of the token, position and segment embeddings
    let X_enc = X.map(word => vplus3(
      this.encoders[0].predict(word[0]),
      this.encoders[1].predict(word[1]),
      this.encoders[2].predict(word[2])
    ));

    // Mask: 0 when the token vector is all zeros (padding), otherwise 1.
    // BUG FIX: the original tested vsum(x[0]) > 0, which misclassifies any
    // token vector whose components sum to 0 despite being non-zero;
    // testing "has any non-zero component" matches the stated intent.
    const M = X.map(x => x[0].some(v => v !== 0) ? 1 : 0);

    // self-attention: Q = K = V = the current encoding
    this.blocks.forEach(block => {
      X_enc = block.predict(X_enc, X_enc, X_enc, M);
    });

    return X_enc;
  }
}


//==========================================
// SEQUENCE TAGGING API
//==========================================

/** Identity embedding: a stand-in neuron whose prediction is its input. */
class EmptyNeuron {
  /**
   * @param {*} x any value
   * @return {*} the same value, unchanged
   */
  predict(x){
    return x;
  }
}
// class EmptyNeuron{
// predict(x){return x;}
// }

class TagEncoder{
constructor(tag_list, embedding=new EmptyNeuron()){
constructor(tag_list, embedding={predict: x => x}){//new EmptyNeuron()
this.tag_list = tag_list;
this.embedding = embedding;
}
Expand Down Expand Up @@ -145,4 +256,4 @@ class BeamMEMM {
}


export default MaxEnt;
export {Activation, Perceptron, TagEncoder, BeamMEMM, JslBERTBlock, JslBERT};
2 changes: 1 addition & 1 deletion src/eng/eng.morpho.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ function __verbTypes(verb) {


EngMorpho._nStem("porter", "English Porter stemmer", __porterStemmer);
EngMorpho._nStem("lancaster", "English Lnacaster stemmer", __lancasterStemmer);
EngMorpho._nStem("lancaster", "English Lancaster stemmer", __lancasterStemmer);

EngMorpho._nConv("sing2pl", "Singular noun to Plural", __singular2plural);

Expand Down
26 changes: 26 additions & 0 deletions src/eng/eng.syntax.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import Syntax from "../syntax.mjs";
import {Activation, Perceptron, BeamMEMM} from "../_jslml.mjs";

class EngSyntax extends Syntax {
  // model parameters — empty placeholders until trained weights are provided
  static list_tags = [];
  static w_memm = [];
  static b_memm = [];
  // BUG FIX: the original read bare `w_memm`/`b_memm`, which are not in
  // scope (ReferenceError when the class is evaluated); static fields must
  // be accessed via `this` inside static initializers.
  static maxent = new Perceptron(this.w_memm, this.b_memm, Activation.softmax, this.list_tags);
  static memm = new BeamMEMM(5, this.maxent);

  /**
   * Encoding all sentence words
   *
   * @protected
   * @final
   * @static
   * @param {String[]} sentence list of words in sentence
   * @return {float[[]]} encoding of each word
   */
  static _words_encode(sentence){
    // TODO: placeholder — returns an empty encoding until features are defined
    return [[]];
  }

}

export default EngSyntax;
16 changes: 0 additions & 16 deletions src/jslingua.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,6 @@ class JsLingua {
return service[langCode];
}

/**
* Get an object of a service class for a given language and service name.<br>
* For example: JsLingua.nserv("Info", "ara") Gives an object of the class AraInfo
*
* @public
* @static
* @param {String} serviceID The name of the service (the super-classe): "Info", "Lang", etc.
* @param {String} langCode The language ISO639-2 code: "ara", "jpn", "eng", etc.
* @return {Class} The class that affords the service
*/
static nserv(serviceID, langCode) {
let Cls = this.gserv(serviceID, langCode);
if (Cls === null) return null;
return new Cls();
}

/**
* Returns the version of JsLingua
*
Expand Down
4 changes: 2 additions & 2 deletions src/morpho.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ class Morpho {
* person: "first", // Morpho.Feature.Person.F
* number: "singular" // Morpho.Feature.Number.S
* };
* var I = goptname("Pronoun", opts);
* var I = gconjoptname("Pronoun", opts);
* // In English, it will give: "I"
* // In Arabic, it will give: "أنا"
*
Expand All @@ -591,7 +591,7 @@ class Morpho {
* @param {Object} opts The parameters
* @return {String} The label of this parameter in the current language
*/
static goptname(optLabel, opts){
static gconjoptname(optLabel, opts){
switch (optLabel) {
case "Pronoun": return this._gPpName(opts);
case "Negation": return this._gNegName(opts);
Expand Down
12 changes: 7 additions & 5 deletions src/syntax.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

import JslNode from "_jslgraph.mjs";
//import {Activation, Neuron, TagEncoder, BeamMEMM} from "_jslml.mjs";

/**
* Syntactic functions such as PoS tagging, etc.
Expand Down Expand Up @@ -122,10 +123,10 @@ class Syntax {
* @static
* @param {int} i current position
* @param {String[]} sentence list of words in sentence
* @return {float[]} encoding of the current word
* @return {float[[]]} encoding of each word
*/
static _word_encode(i, sentence){
return [];
static _words_encode(sentence){
return [[]];
}

/**
Expand Down Expand Up @@ -170,9 +171,10 @@ class Syntax {
* @return {String[]} list of tags of these words
*/
static pos_tag(sentence){
this.memm.init(this._word_encode(0, sentence));
let encoded = this._words_encode(sentence);
this.memm.init(encoded[0]);
for(let i = 1; i < sentence.length; i++){
this.memm.step(this._word_encode(i, sentence));
this.memm.step(encoded[i]);
}
return this.memm.final();
}
Expand Down

0 comments on commit fe377c7

Please sign in to comment.