Commit
Python implementation of jslBERT for training purposes
kariminf committed Jan 4, 2023
1 parent fe377c7 commit 7fbdebc
Showing 6 changed files with 242 additions and 8 deletions.
49 changes: 49 additions & 0 deletions .gitignore
@@ -55,3 +55,52 @@ docs.json

#ignore text files in helpers folder
test/helpers/*.txt

# ==============================
# PYTHON
# ==============================

*.py[cod]

# C extensions
*.so

# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
lib
lib64
__pycache__

# Installer logs
pip-log.txt

# Unit test / coverage reports
.coverage
.tox
nosetests.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# System
.directory
.cache

#directories
test_ress/
.pytest_cache/
5 changes: 4 additions & 1 deletion package.json
@@ -1,6 +1,6 @@
 {
   "name": "jslingua",
-  "version": "0.13.0",
+  "version": "0.14.0",
   "description": "Language processing modules",
   "main": "dist/jslingua.js",
   "devDependencies": {
@@ -36,6 +36,9 @@
   "stemmer",
   "toolkit",
   "library",
+  "syntax",
+  "PoS tagging",
+  "parsing",
   "front-end",
   "back-end",
   "pronunciation",
12 changes: 6 additions & 6 deletions src/_jslml.mjs
@@ -68,11 +68,11 @@ const transpose = (X) => X[0].map((col, i) => X.map(row => row[i]));

 class Perceptron {

-  constructor(w, b, activate=Activation.sigmoid, cls_names=[], th=0.5){
-    this.w = w;
-    this.b = b;
+  constructor(weights, bias, activate=Activation.sigmoid, cls_names=[], th=0.5){
+    this.weights = weights;
+    this.bias = bias;
     this.activate = activate;
-    this.muliclass = Array.isArray(b);
+    this.muliclass = Array.isArray(bias);
     if (!cls_names || !cls_names.length) this.cls_names = ["Neg", "Pos"];
     else this.cls_names = cls_names;
   }
@@ -85,10 +85,10 @@ class Perceptron {
     let cls = 0;
     if(this.muliclass){
       cls = [];
-      for(let i = 0; i < this.w.length; i++) cls.push(dot(this.w[i], x) + this.b[i]);
+      for(let i = 0; i < this.weights.length; i++) cls.push(dot(this.weights[i], x) + this.bias[i]);
     }
     else { //binary
-      cls = dot(this.w, x) + this.b;
+      cls = dot(this.weights, x) + this.bias;
     }

     cls = this.activate(cls);
2 changes: 1 addition & 1 deletion src/jslingua.mjs
@@ -8,7 +8,7 @@
  * @hideconstructor
  */
 class JsLingua {
-  static version = "0.13.0";
+  static version = "0.14.0";
   static rtls = ["ara", "heb", "aze", "div", "kur", "per", "fas", "urd"];
   static services = {};

140 changes: 140 additions & 0 deletions supp/ML/jslbert.py
@@ -0,0 +1,140 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Abdelkrime Aries <[email protected]>
#
# ---- AUTHORS ----
# 2023 Abdelkrime Aries <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow.keras.layers import Layer, LayerNormalization, Dense, MultiHeadAttention


class MaskedLoss(tf.keras.losses.Loss):
    def __init__(self):
        super().__init__(name="masked_loss")
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits = False,
            reduction = 'none'
        )

    def __call__(self, y_true, y_pred):
        # Calculate the loss for each item in the batch.
        loss = self.loss(y_true, y_pred)

        # Mask off the losses on padding (token id 0).
        mask = tf.cast(y_true != 0, tf.float32)
        loss *= mask

        # Return the total.
        return tf.reduce_sum(loss)
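
A quick sanity check of the padding behaviour (editorial sketch with toy values, not part of the commit): positions where y_true is 0 contribute nothing to the total.

# Editorial sketch (toy values): PAD positions (id 0) add nothing to the loss.
loss_fn = MaskedLoss()
y_true = tf.constant([[5, 7, 0, 0]])                # two real tokens, two PAD positions
y_pred = tf.random.uniform((1, 4, 10))              # fake scores over a 10-word vocabulary
y_pred /= tf.reduce_sum(y_pred, -1, keepdims=True)  # normalize to probabilities
print(loss_fn(y_true, y_pred))                      # scalar: summed loss from positions 0-1 only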

class JslBERTBlock(Layer):
    def __init__(self, d_mdl, h):
        super(JslBERTBlock, self).__init__()
        self.lma = MultiHeadAttention(h, key_dim=d_mdl)
        self.addnorm1 = LayerNormalization()
        self.addnorm2 = LayerNormalization()
        self.ffp = Dense(d_mdl, name="block_out")

    def call(self, Q, K, V, M):
        # Self-attention (Keras positional order: query, value, key, attention_mask),
        # then the two add-&-norm steps with their residual connections.
        att = self.lma(Q, K, V, M)
        out = self.addnorm1(Q + att)
        return self.addnorm2(out + self.ffp(out))

class JslBERT(tf.keras.Model):
    def __init__(self, blocks_nbr, d_model, heads_nbr, vocab_size, max_length, mask_rate=0.2):
        super(JslBERT, self).__init__()
        self.tokEmb = Dense(d_model, name="Tok_embedding")
        self.posEmb = Dense(d_model, name="Pos_embedding")
        self.segEmb = Dense(d_model, name="Seg_embedding")
        self.blocks = []
        for i in range(blocks_nbr):
            self.blocks.append(JslBERTBlock(d_model, heads_nbr))
        self.cls = Dense(1, activation="sigmoid", name="Is_next")
        self.tok = Dense(vocab_size, activation="softmax", name="Token")

        self.vocab_size = vocab_size
        self.max_length = max_length
        self.mask_rate = mask_rate

        # Special token ids.
        self.PAD = 0
        self.UNK = 1
        self.CLS = 2
        self.SEP = 3
        self.MASK = 4

        # Tokens that must never be replaced by [MASK].
        self.no_masking = tf.constant([self.PAD, self.CLS, self.SEP])

        self.cls_loss = tf.keras.metrics.binary_crossentropy
        self.tok_loss = MaskedLoss()

    def _mask(self, X):
        # Tile X so each token can be compared against every id in no_masking,
        # then flag the tokens that are allowed to be masked.
        tile_shape = tf.concat([tf.ones(tf.shape(tf.shape(X)), dtype=tf.int32), tf.shape(self.no_masking)], axis=0)
        X_tile = tf.tile(tf.expand_dims(X, -1), tile_shape)
        Mask_mask = tf.reduce_any(tf.equal(X_tile, self.no_masking), -1)
        Mask_mask = tf.logical_not(Mask_mask)

        # Keep only a random subset of those positions (mask_rate of them on average).
        Mask_mask = tf.logical_and(Mask_mask, tf.random.uniform(tf.shape(X)) <= self.mask_rate)

        return tf.where(Mask_mask, self.MASK, X)
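
To see what _mask produces, a hypothetical run with toy ids (editorial, not part of the commit): CLS (2), SEP (3) and PAD (0) always survive, while each remaining token has a mask_rate chance of becoming MASK (4).

# Editorial illustration with made-up sizes and ids.
model = JslBERT(blocks_nbr=1, d_model=16, heads_nbr=2, vocab_size=100, max_length=8)
toks = tf.constant([[2, 17, 23, 3, 42, 99, 0, 0]])  # [CLS] w w [SEP] w w [PAD] [PAD]
print(model._mask(toks))                            # e.g. [[2, 17, 4, 3, 42, 99, 0, 0]]; varies per run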

    def train_step(self, data):
        # X stacks three aligned rows per sample: X[:,0] = token ids, X[:,1] = positions, X[:,2] = segments.
        X, Y = data
        X = tf.cast(X, tf.int32)

        with tf.GradientTape() as tape:
            logits = self.encode(X, train=True)

            # Next-sentence prediction from the [CLS] position.
            cls_logits = self.cls(logits[:, 0, :])
            cls_loss = tf.reduce_sum(self.cls_loss(Y, cls_logits))

            # Token prediction on the remaining positions.
            tok_logits = self.tok(logits[:, 1:, :])
            tok_loss = self.tok_loss(X[:, 0, 1:], tok_logits)

            loss = cls_loss + tok_loss

        variables = self.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        return {"cls_loss": cls_loss, "tok_loss": tok_loss}

    def encode(self, X, train=False):
        # Key-padding mask: True on real tokens, False on PAD;
        # MultiHeadAttention broadcasts it over the query positions.
        Mask = X[:, 0, :] != self.PAD

        # During training, randomly replace some input tokens by [MASK].
        if train:
            Tok = tf.one_hot(self._mask(X[:, 0, :]), self.vocab_size, axis=-1)
        else:
            Tok = tf.one_hot(X[:, 0, :], self.vocab_size, axis=-1)
        Pos = tf.one_hot(X[:, 1, :], self.max_length, axis=-1)
        Seg = tf.one_hot(X[:, 2, :], 2, axis=-1)

        Tokemb = self.tokEmb(Tok)
        Posemb = self.posEmb(Pos)
        Segemb = self.segEmb(Seg)

        # The block input is the sum of token, position and segment embeddings.
        res = Tokemb + Posemb + Segemb

        for block in self.blocks:
            res = block(res, res, res, Mask)

        return res
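
A minimal training sketch (editorial, not from the commit). The (batch, 3, max_length) input layout is inferred from encode(): row 0 holds token ids, row 1 positions, row 2 segment ids; Y holds the is-next labels. All sizes and data below are made up.

import numpy as np

BATCH, MAX_LEN, VOCAB = 32, 16, 1000
tok = np.random.randint(5, VOCAB, size=(BATCH, MAX_LEN))  # ordinary-word ids start after the specials
pos = np.tile(np.arange(MAX_LEN), (BATCH, 1))             # positions 0..MAX_LEN-1
seg = np.zeros((BATCH, MAX_LEN), dtype=int)               # single-segment toy examples
X = np.stack([tok, pos, seg], axis=1)                     # shape (BATCH, 3, MAX_LEN)
Y = np.random.randint(0, 2, size=(BATCH, 1))              # fake is-next labels

model = JslBERT(blocks_nbr=2, d_model=64, heads_nbr=4, vocab_size=VOCAB, max_length=MAX_LEN)
model.compile(optimizer="adam")
model.fit(X, Y, batch_size=8, epochs=1)                   # the custom train_step reports both losses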
42 changes: 42 additions & 0 deletions supp/ML/jslpreprocess.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Abdelkrime Aries <[email protected]>
#
# ---- AUTHORS ----
# 2023 Abdelkrime Aries <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

# One CoNLL-U token line: ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC.
ud_conllu_pattern = r"(\d+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)\t([^\t]+)"

def parse_ud_conllu_file(url):
    sents = []
    words = []
    with open(url, "r", encoding="utf8") as f:
        for line in f:
            m = re.match(ud_conllu_pattern, line)
            if m:
                # A token id of 1 starts a new sentence.
                if m.group(1) == "1":
                    words = []
                    sents.append(words)
                words.append(m.group(2))

    return sents


if __name__ == "__main__":
    s = parse_ud_conllu_file("/home/kariminf/Research/UD.en/en_partut-ud-dev.conllu")
    print(s)
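
For reference, a round trip on a toy CoNLL-U fragment (editorial sketch, not part of the commit): the parser keeps only the FORM column, grouped into sentences whenever a token id restarts at 1.

# Editorial sketch with made-up data.
import tempfile

toy = (
    "1\tThe\tthe\tDET\tDT\t_\t2\tdet\t_\t_\n"
    "2\tcat\tcat\tNOUN\tNN\t_\t3\tnsubj\t_\t_\n"
    "3\tsleeps\tsleep\tVERB\tVBZ\t_\t0\troot\t_\t_\n"
    "\n"
    "1\tDogs\tdog\tNOUN\tNNS\t_\t2\tnsubj\t_\t_\n"
    "2\tbark\tbark\tVERB\tVBP\t_\t0\troot\t_\t_\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".conllu", delete=False, encoding="utf8") as f:
    f.write(toy)
print(parse_ud_conllu_file(f.name))  # [['The', 'cat', 'sleeps'], ['Dogs', 'bark']]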
