Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Adding python callback during training
Browse files Browse the repository at this point in the history
  • Loading branch information
PreetGits authored and nehagaur01 committed May 12, 2023
1 parent 3697152 commit adc0b8d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/fasttext_module/fasttext/FastText.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def train_supervised(*kargs, **kwargs):
'model': "supervised"
})

callback = kwargs.pop("callback", None)
arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
Expand All @@ -525,7 +526,10 @@ def train_supervised(*kargs, **kwargs):
supervised_default)
a = _build_args(args, manually_set_args)
ft = _FastText(args=a)
fasttext.train(ft.f, a)
if callback:
fasttext.train_with_callback(ft.f, a, callback)
else:
fasttext.train(ft.f, a)
ft.set_args(ft.f.getArgs())
return ft

Expand All @@ -544,13 +548,18 @@ def train_unsupervised(*kargs, **kwargs):
dataset pulled by the example script word-vector-example.sh, which is
part of the fastText repository.
"""
callback = kwargs.pop("callback", None)
arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
args, manually_set_args = read_args(kargs, kwargs, arg_names,
unsupervised_default)
a = _build_args(args, manually_set_args)
ft = _FastText(args=a)
if callback:
fasttext.train_with_callback(ft.f, a, callback)
else:
fasttext.train(ft.f, a)
fasttext.train(ft.f, a)
ft.set_args(ft.f.getArgs())
return ft
Expand Down
8 changes: 8 additions & 0 deletions python/fasttext_module/fasttext/pybind/fasttext_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <fasttext.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <real.h>
#include <vector.h>
Expand Down Expand Up @@ -166,6 +167,13 @@ PYBIND11_MODULE(fasttext_pybind, m) {
}
},
py::call_guard<py::gil_scoped_release>());

m.def(
"train_with_callback",
[](fasttext::FastText& ft, fasttext::Args& a, fasttext::FastText::TrainCallback& c) {
ft.train(a, c);
},
py::call_guard<py::gil_scoped_release>());

py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
.def(py::init<ssize_t>())
Expand Down

0 comments on commit adc0b8d

Please sign in to comment.