Test basic AI model #17

Open
wants to merge 1 commit into main
200 changes: 200 additions & 0 deletions src/fastspell/ai.py
@@ -0,0 +1,200 @@
import os
import io
import sys
import fasttext
import hunspell
import logging
import urllib.request
import pathlib
import timeit
import argparse
import traceback
from sklearn.feature_extraction import DictVectorizer
import pycountry
import xgboost
from sklearn.model_selection import GridSearchCV

try:
from . import __version__
from .util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
from .fastspell import FastSpell
except ImportError:
from fastspell import __version__
from util import logging_setup, remove_unwanted_words, get_hash, check_dir, load_config
from fastspell import FastSpell

class FastSpellAI(FastSpell):
def __init__(self, lang, *args, **kwargs):
super().__init__(lang, *args, **kwargs)

ft_download_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
ft_model_path = "lid.176.bin"
if not os.path.exists(ft_model_path):
urllib.request.urlretrieve(ft_download_url, ft_model_path)
ft_model = fasttext.load_model(ft_model_path)

ft_prefix = "__label__"

fsobj = FastSpellAI("en")

languages = [label[len(ft_prefix):] for label in fsobj.model.get_labels()]

unsupported = []
hunspell_objs = {}
for language in languages:
try:
search_l = None
if language in fsobj.hunspell_codes:
search_l = fsobj.hunspell_codes[language]
elif f"{language}_lat" in fsobj.hunspell_codes:
search_l = fsobj.hunspell_codes[f"{language}_lat"]
elif f"{language}_cyr" in fsobj.hunspell_codes:
search_l = fsobj.hunspell_codes[f"{language}_cyr"]
else:
search_l = language
hunspell_objs[language] = fsobj.search_hunspell_dict(search_l)
except Exception:
unsupported.append(language)

print(len(languages))
print(len(unsupported))
print(unsupported)

prediction = fsobj.model.predict("Ciao, mondo!".lower(), k=3)
print(prediction)
print(prediction[0])
print(prediction[0][0])
print(prediction[0][0][len(ft_prefix):])

sentences = []
labels = []
count = 0
with open("../sentences.csv", "r") as f:
for l in f:
number, language, text = l.split("\t")

if language != "ita":
continue

lang = pycountry.languages.get(alpha_3=language)

text = text.replace("\n", " ").strip()
prediction = fsobj.model.predict(text.lower(), k=3)

# print(prediction)

lang0 = prediction[0][0][len(ft_prefix):]
lang0_prob = prediction[1][0]
if len(prediction[0]) >= 2:
lang1 = prediction[0][1][len(ft_prefix):]
lang1_prob = prediction[1][1]
else:
# If there's only one option... Not much to do.
continue
if len(prediction[0]) >= 3:
lang2 = prediction[0][2][len(ft_prefix):]
lang2_prob = prediction[1][2]
else:
lang2 = None
lang2_prob = 0.0

label = None
if lang0 == lang.alpha_2:
label = 0
elif lang1 == lang.alpha_2:
label = 1
elif lang2 == lang.alpha_2:
label = 2

if label is None:
continue

# print(lang0)

raw_tokens = text.strip().split(" ")
if lang0 in hunspell_objs:
tokens = remove_unwanted_words(raw_tokens, lang0)
correct = 0
for token in tokens:
try:
if hunspell_objs[lang0].spell(token):
correct += 1
except UnicodeEncodeError as ex:
pass
lang0_dic_tokens = correct / len(tokens) if tokens else None
else:
lang0_dic_tokens = None

if lang1 in hunspell_objs:
tokens = remove_unwanted_words(raw_tokens, lang1)
correct = 0
for token in tokens:
try:
if hunspell_objs[lang1].spell(token):
correct += 1
except UnicodeEncodeError as ex:
pass
lang1_dic_tokens = correct / len(tokens) if tokens else None
else:
lang1_dic_tokens = None

if lang2 in hunspell_objs:
tokens = remove_unwanted_words(raw_tokens, lang2)

correct = 0
for token in tokens:
try:
if hunspell_objs[lang2].spell(token):
correct += 1
except UnicodeEncodeError as ex:
pass

lang2_dic_tokens = correct / len(tokens) if tokens else None
else:
lang2_dic_tokens = None

sentences.append({
"fastText_lang0": lang0_prob,
"fastText_lang1": lang1_prob,
"fastText_lang2": lang2_prob,
"lang0_dic_tokens": lang0_dic_tokens,
"lang1_dic_tokens": lang1_dic_tokens,
"lang2_dic_tokens": lang2_dic_tokens,
Comment on lines +160 to +165
Contributor Author:
These are the input features for the model.
})
labels.append(label)
Contributor Author:
The label is 0 if the language to choose is the first of the three candidates, 1 if it's the second, 2 if it's the third.

Collaborator:
Oh sorry, I did not see this. Maybe this changes my proposal a little bit.
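To make the training data concrete, here is a minimal sketch of one row produced by the loop above (all values invented for illustration):

# Hypothetical example row: fastText ranked the correct language first
# with probability 0.92, and 95% of the tokens passed its Hunspell check.
example_features = {
    "fastText_lang0": 0.92,      # probability of fastText's first guess
    "fastText_lang1": 0.05,      # probability of the second guess
    "fastText_lang2": 0.01,      # probability of the third guess
    "lang0_dic_tokens": 0.95,    # share of tokens accepted by lang0's dictionary
    "lang1_dic_tokens": 0.40,
    "lang2_dic_tokens": 0.10,
}
example_label = 0  # the correct language is the first candidate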


# count += 1
# if count == 7:
# break

print(len(sentences))

dict_vectorizer = DictVectorizer()
X = dict_vectorizer.fit_transform(sentences)

xgb_model = xgboost.XGBClassifier(n_jobs=10)

clf = GridSearchCV(
xgb_model,
{"max_depth": [1, 2, 4, 6], "n_estimators": [25, 50, 100, 200]},
Contributor Author:
We could presumably introduce another hyperparameter, a "threshold" above which we don't need to go look at the dictionary (e.g. if fastText is 99% sure it's Italian, we don't need to check further). This could lower the cost of the approach.

Collaborator:
That's right. I've always been interested in whether there is a correlation between confidence and precision, but I never had time to check whether the number of false positives or false negatives is very small when confidence is high. So we could add this exception and speed things up.
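A minimal sketch of that shortcut, assuming a hand-picked threshold (the value and the helper name below are illustrative, not part of this PR):

CONFIDENCE_THRESHOLD = 0.99  # hypothetical cut-off; would itself need tuning

def choose_lang(fsobj, text):
    # Ask fastText for its top three guesses, as in the training loop above.
    ft_labels, probs = fsobj.model.predict(text.lower(), k=3)
    if probs[0] >= CONFIDENCE_THRESHOLD:
        # fastText is very confident: skip the Hunspell features entirely.
        return ft_labels[0][len(ft_prefix):]
    # Otherwise build the feature dict and defer to the XGBoost model.
    ...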

verbose=1,
n_jobs=1,
)
clf.fit(X, labels)
print(clf.best_score_)
print(clf.best_params_)
print(clf.best_estimator_)

clf.best_estimator_.save_model("model.ubj")

X_try = dict_vectorizer.transform([sentences[0]])
classes = clf.best_estimator_.predict(X_try)
if classes[0] == 0:
print("Lang0 chosen")
elif classes[0] == 1:
print("Lang1 chosen")
elif classes[0] == 2:
print("Lang2 chosen")
2 changes: 2 additions & 0 deletions src/fastspell/config/hunspell.yaml
@@ -18,6 +18,7 @@ hunspell_codes:
da: da_DK
de: de_DE
en: en_GB
el: el_GR
Collaborator:
This will be cleaned up, right? We prefer to add new languages to the default config only if they come along with the corresponding fastspell-dictionaries update that adds them.

Contributor Author:
Yeah, this was just for testing.

es: es_ES
et: et_ET
fa: fa_IR
@@ -75,3 +76,4 @@ hunspell_codes:
ur: ur_PK
uz: uz_UZ
yi: yi
vi: vi_VN
43 changes: 39 additions & 4 deletions src/fastspell/fastspell.py
@@ -24,6 +24,9 @@
HBS_LANGS = ('hbs', 'sh', 'bs', 'sr', 'hr', 'me')


# logger = logging.getLogger()
# logger.setLevel(logging.DEBUG)

def initialization():
parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__)
parser.add_argument('lang', type=str)
@@ -94,6 +97,8 @@ def download_fasttext(self):

def search_hunspell_dict(self, lang_code):
''' Search in the paths for a hunspell dictionary and load it '''
hunspell_obj = None

for p in self.hunspell_paths:
if os.path.exists(f"{p}/{lang_code}.dic") and os.path.exists(f"{p}/{lang_code}.aff"):
try:
@@ -105,10 +110,36 @@ def search_hunspell_dict(self, lang_code):
logging.error("Failed building Hunspell object for " + lang_code)
logging.error("Aborting.")
exit(1)
else:
raise RuntimeError(f"It does not exist any valid dictionary directory"
f"for {lang_code} in the paths {self.hunspell_paths}."

if hunspell_obj is None:
Contributor Author:
This change is unrelated; I needed it to load as many dictionaries as possible from my system (I was testing with Italian, which is not available in fastspell-dictionaries).

Contributor Author:
I guess this change could be useful in its own right, so users of fastspell can more easily load more dictionaries.

Contributor Author:
Let me know if you want to land this change (after cleaning it up, of course) and I'll open a PR for it.

Collaborator:
I'm not sure I'm understanding this change correctly, but there's already a whole path search for possible dictionary candidates other than fastspell-dictionaries if a user wants to use the system's. It's explained in the documentation, and here is the code.

Contributor Author:
This was to allow loading a dictionary more easily: e.g. if "it" is the language and there's "it_IT.dic" on the system, it will assume that's the one to load (even though there might be others, like "it_CH.dic").
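Restated as a standalone sketch (simplified; the helper name is made up, and the real implementation is in the diff below), the resolution order is:

import os

def pick_hunspell_dic(lang_code, path):
    # Dictionary files in this directory whose names start with the code.
    files = [f for f in os.listdir(path) if f.startswith(lang_code)]
    if f"{lang_code}.dic" in files and f"{lang_code}.aff" in files:
        return lang_code                      # exact match, e.g. "it"
    upper = f"{lang_code}_{lang_code.upper()}"
    if f"{upper}.dic" in files and f"{upper}.aff" in files:
        return upper                          # e.g. "it_IT"
    if len(files) == 2:
        return files[0][:-4]                  # lone remaining .dic/.aff pair
    return None

# With it_IT.dic, it_IT.aff, it_CH.dic and it_CH.aff on disk,
# pick_hunspell_dic("it", p) returns "it_IT" even though it_CH also exists.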

for p in self.hunspell_paths:
if not os.path.exists(p):
continue

potential_files = [f for f in os.listdir(p) if f.startswith(lang_code)]
if f"{lang_code}.dic" in potential_files and f"{lang_code}.aff" in potential_files:
dic = lang_code
elif f"{lang_code}_{lang_code.upper()}.dic" in potential_files and f"{lang_code}_{lang_code.upper()}.aff" in potential_files:
dic = f"{lang_code}_{lang_code.upper()}"
elif len(potential_files) == 2:
dic = potential_files[0][:-4]
else:
continue

try:
hunspell_obj = hunspell.Hunspell(dic, hunspell_data_dir=p)
logging.debug(f"Loaded hunspell obj for '{lang_code}' in path: {p + '/' + dic}")
break
except Exception:
logging.error("Failed building Hunspell object for " + dic)
logging.error("Aborting.")
exit(1)

if hunspell_obj is None:
raise RuntimeError(f"It does not exist any valid dictionary directory "
f"for {lang_code} in the paths {self.hunspell_paths}. "
f"Please, execute 'fastspell-download'.")

return hunspell_obj


@@ -127,7 +158,7 @@ def load_hunspell_dicts(self):
self.similar = []
for sim_entry in self.similar_langs:
if sim_entry.split('_')[0] == self.lang:
self.similar.append(self.similar_langs[sim_entry])
self.similar.append(self.similar_langs[sim_entry] + [self.lang])
marco-c marked this conversation as resolved.

logging.debug(f"Similar lists for '{self.lang}': {self.similar}")
self.hunspell_objs = {}
@@ -208,6 +239,9 @@ def getlang(self, sent):

#TODO: Confidence score?

logging.debug(prediction)
logging.debug(self.similar)

if self.similar == [] or prediction not in self.hunspell_objs:
#Non mistakeable language: just return FastText prediction
refined_prediction = prediction
@@ -218,6 +252,7 @@
for sim_list in self.similar:
if prediction in sim_list or f'{prediction}_{script}' in sim_list:
current_similar = sim_list
logging.debug(current_similar)

spellchecked = {}
for l in current_similar: