Working, except in python
XapaJIaMnu committed Aug 2, 2023
1 parent 1a8b90c commit 1e80e79
Showing 6 changed files with 63 additions and 18 deletions.
10 changes: 8 additions & 2 deletions bindings/python/bergamot.cpp
@@ -1,3 +1,4 @@
// #define PYBIND11_DETAILED_ERROR_MESSAGES // Enables debugging
#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -29,6 +30,7 @@ using Alignment = std::vector<std::vector<float>>;
using Alignments = std::vector<Alignment>;

PYBIND11_MAKE_OPAQUE(std::vector<Response>);
PYBIND11_MAKE_OPAQUE(std::vector<size_t>);
PYBIND11_MAKE_OPAQUE(std::vector<std::string>);
PYBIND11_MAKE_OPAQUE(std::unordered_map<std::string, std::string>);
PYBIND11_MAKE_OPAQUE(Alignments);
@@ -212,22 +214,26 @@ PYBIND11_MODULE(_bergamot, m) {
.def("pivot", &ServicePyAdapter::pivot)
.def("setTerminology", &ServicePyAdapter::setTerminology);

py::bind_vector<std::vector<size_t>>(m, "VectorSizeT");
py::class_<Service::Config>(m, "ServiceConfig")
.def(py::init<>([](size_t numWorkers, size_t cacheSize, std::string logging, std::string pathToTerminologyFile,
.def(py::init<>([](size_t numWorkers, std::vector<size_t> gpuWorkers, size_t cacheSize, std::string logging, std::string pathToTerminologyFile,
bool terminologyForce, std::string terminologyForm) {
Service::Config config;
config.numWorkers = numWorkers;
config.gpuWorkers = gpuWorkers;
config.cacheSize = cacheSize;
config.logger.level = logging;
config.terminologyFile = pathToTerminologyFile;
config.terminologyForce = terminologyForce;
config.format = terminologyForm;
return config;
}),
py::arg("numWorkers") = 1, py::arg("cacheSize") = 0, py::arg("logLevel") = "off",
py::arg("numWorkers") = 1, py::arg("gpuWorkers") = std::vector<size_t>{0},
py::arg("cacheSize") = 0, py::arg("logLevel") = "off",
py::arg("pathToTerminologyFile") = "", py::arg("terminologyForce") = false,
py::arg("terminologyForm") = "%s <tag0> %s </tag0> ")
.def_readwrite("numWorkers", &Service::Config::numWorkers)
.def_readwrite("gpuWorkers", &Service::Config::gpuWorkers)
.def_readwrite("cacheSize", &Service::Config::cacheSize)
.def_readwrite("pathToTerminologyFile", &Service::Config::terminologyFile)
.def_readwrite("terminologyForce", &Service::Config::terminologyForce)
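Taken together, the new binding lets Python construct a GPU-backed service directly, passing the device list through the opaque VectorSizeT type bound above. A minimal sketch, assuming the compiled module is importable as bergamot (as translator.py uses it) and that device ids 0 and 1 exist on the host; numWorkers is 0 because GPU and CPU workers cannot be mixed (see the service.cpp change below):

import bergamot

# Illustrative only: device ids 0 and 1 are assumptions about the host machine.
gpu_ids = bergamot.VectorSizeT([0, 1])
config = bergamot.ServiceConfig(numWorkers=0, gpuWorkers=gpu_ids, cacheSize=0, logLevel="off")
service = bergamot.Service(config)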
31 changes: 26 additions & 5 deletions bindings/python/translator.py
@@ -9,6 +9,7 @@ class Translator:
Attributes:
_num_workers: Number of parallel CPU workers.
_gpu_workers: Indices of the GPU devices used. _num_workers must be set to zero!
_cache: Cache size. 0 to disable cache.
_logging: Log level: trace, debug, info, warn, err(or), critical, off. Default is off
_terminology: Path to a TSV terminology file
@@ -21,6 +22,7 @@ class Translator:
_service: The translation service
"""
_num_workers: int
_gpu_workers: List[int]
_cache: int
_logging: str
_terminology: str
@@ -32,26 +34,28 @@ class Translator:
_responseOpts: bergamot.ResponseOptions
_service: bergamot.Service

def __init__(self, model_config_path: str, num_workers: int=1, cache: int=0, \
def __init__(self, model_config_path: str, num_workers: int=1, gpu_workers: List[int]=[], cache: int=0, \
logging="off", terminology: str="", force_terminology: bool=False,\
terminology_form: str="%s <tag0> %s </tag0> "):
"""Initialises the translator class
:param model_config_path: Path to the configuration file for the translation model.
:param num_workers: Number of CPU workers.
:param gpu_workers: Indices of the GPU devices. num_workers must be zero if this is non-empty
:param cache: cache size. 0 means no cache.
:param logging: Log level: trace, debug, info, warn, err(or), critical, off.
:param terminology: Path to terminology file, TSV format
:param force_terminology: Force terminology to appear on the target side. May impact translation quality.
"""
self._num_workers = num_workers
self._gpu_workers = gpu_workers
self._cache = cache
self._logging = logging
self._terminology = terminology
self._force_terminology = force_terminology
self._terminology_form = terminology_form

self._config = bergamot.ServiceConfig(self._num_workers, self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._config = bergamot.ServiceConfig(self._num_workers, bergamot.VectorSizeT(self._gpu_workers), self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._service = bergamot.Service(self._config)
self._responseOpts = bergamot.ResponseOptions() # Default false for all, if we want to enable HTML later, from here
self._model = self._service.modelFromConfigPath(model_config_path)
@@ -64,7 +68,7 @@ def reset_terminology(self, terminology: str="", force_terminology: bool=False)
"""
self._terminology = terminology
self._force_terminology = force_terminology
self._config = bergamot.ServiceConfig(self._num_workers, self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._config = bergamot.ServiceConfig(self._num_workers, bergamot.VectorSizeT(self._gpu_workers), self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._service = bergamot.Service(self._config)

def reset_terminology(self, terminology: Dict[str,str], force_terminology: bool=False) -> None:
@@ -81,7 +85,16 @@ def reset_num_workers(self, num_workers) -> None:
:return: None
"""
self._num_workers = num_workers
self._config = bergamot.ServiceConfig(self._num_workers, self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._config = bergamot.ServiceConfig(self._num_workers, bergamot.VectorSizeT(self._gpu_workers), self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._service = bergamot.Service(self._config)

def reset_gpu_workers(self, gpu_workers: List[int]) -> None:
"""Resets the number of GPU workers
:param gpu_workers: Indices of the GPU devices to be used.
:return: None
"""
self._gpu_workers = gpu_workers
self._config = bergamot.ServiceConfig(self._num_workers, bergamot.VectorSizeT(self._gpu_workers), self._cache, self._logging, self._terminology, self._force_terminology, self._terminology_form)
self._service = bergamot.Service(self._config)

def translate(self, sentences: List[str]) -> List[str]:
@@ -98,6 +111,7 @@ def main():
parser = argparse.ArgumentParser(description="bergamot-translator interface")
parser.add_argument("--config", '-c', required=True, type=str, help='Model YML configuration input.')
parser.add_argument("--num-workers", '-n', type=int, default=1, help='Number of CPU workers.')
parser.add_argument("--num-gpus", "-g", type=int, action='append', nargs='+', default=None, help='List of GPUs to use.')
parser.add_argument("--logging", '-l', type=str, default="off", help='Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off. Default is off')
parser.add_argument("--cache-size", type=int, default=0, help='Cache size. 0 for caching is disabled')
parser.add_argument("--terminology-tsv", '-t', default="", type=str, help='Path to a terminology file TSV')
@@ -107,7 +121,14 @@ def main():
parser.add_argument("--batch", '-b', default=32, type=int, help="Number of lines to process in a batch")
args = parser.parse_args()

translator = Translator(args.config, args.num_workers, args.cache_size, args.logging, args.terminology_tsv, args.force_terminology, args.terminology_form)
if args.num_gpus is None:
num_gpus = []
else:
num_gpus = args.num_gpus
translator = Translator(args.config, args.num_workers, num_gpus, args.cache_size, args.logging, args.terminology_tsv, args.force_terminology, args.terminology_form)

if args.path_to_input is None:
infile = stdin
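At the wrapper level the same capability is the new gpu_workers argument. A usage sketch under stated assumptions: the import path and model path are hypothetical, and num_workers is 0 because the service refuses to mix CPU and GPU workers:

from translator import Translator  # import path assumed; module lives in bindings/python

# Two GPU replicas on devices 0 and 1; num_workers=0 switches off CPU workers.
t = Translator("path/to/model/config.yml", num_workers=0, gpu_workers=[0, 1])
print(t.translate(["Hello world!"]))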
13 changes: 10 additions & 3 deletions src/translator/service.cpp
@@ -158,8 +158,16 @@ AsyncService::AsyncService(const AsyncService::Config &config)
safeBatchingPool_(),
cache_(makeOptionalCache(config_.cacheSize, /*mutexBuckets=*/config_.numWorkers)),
logger_(config.logger) {
ABORT_IF(config_.numWorkers == 0, "Number of workers should be at least 1 in a threaded workflow");
workers_.reserve(config_.numWorkers);
if (!config_.gpuWorkers.empty()) {
ABORT_IF(config_.numWorkers != 0, "Unable to mix GPU and CPU workers.");
workers_.reserve(config_.gpuWorkers.size());
// @TODO: Hacky: the rest of the codebase uses numWorkers as the reference
// worker count, so mirror the GPU worker count into it here.
// Refactor to use gpuWorkers directly.
config_.numWorkers = config_.gpuWorkers.size();
} else {
ABORT_IF(config_.numWorkers == 0, "Number of workers should be at least 1 in a threaded workflow");
workers_.reserve(config_.numWorkers);
}
// Initiate terminology map if present
if (!config_.terminologyFile.empty()) {
// Create an input filestream
@@ -283,7 +291,6 @@ void AsyncService::translate(std::shared_ptr<TranslationModel> translationModel,
html->restore(response);
callback(std::move(response));
};

translateRaw(translationModel, std::move(source), internalCallback, responseOptions);
}

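The new constructor branch makes the two worker pools mutually exclusive and, as the in-code TODO notes, copies the GPU worker count back into numWorkers so downstream code keeps working. A Python sketch of that resolution rule, for illustration only:

def resolve_worker_count(num_workers: int, gpu_workers: list) -> int:
    # Mirrors the checks in AsyncService::AsyncService above.
    if gpu_workers:
        if num_workers != 0:
            raise ValueError("Unable to mix GPU and CPU workers.")
        return len(gpu_workers)  # one worker thread per listed GPU device
    if num_workers == 0:
        raise ValueError("Number of workers should be at least 1 in a threaded workflow.")
    return num_workers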
4 changes: 3 additions & 1 deletion src/translator/service.h
@@ -109,6 +109,7 @@ class BlockingService {
class AsyncService {
public:
struct Config {
std::vector<size_t> gpuWorkers; ///< Device ids of GPU workers. If non-empty, GPU workers are used instead of CPU workers.
size_t numWorkers{1}; ///< How many worker translation threads to spawn.
size_t cacheSize{0}; ///< Size in History items to be stored in the cache. Loosely corresponds to sentences to
/// cache in the real world. A value of 0 means no caching.
@@ -120,6 +121,7 @@ class AsyncService {
template <class App>
static void addOptions(App &app, Config &config) {
app.add_option("--cpu-threads", config.numWorkers, "Workers to form translation backend");
app.add_option("--gpu-workers", config.gpuWorkers, "GPU workers for the translation backend.");
app.add_option("--cache-size", config.cacheSize, "Number of entries to store in cache.");
app.add_option("--terminology-file", config.terminologyFile, "tsv, one term at a time terminology file.");
app.add_option(
@@ -138,7 +140,7 @@ class AsyncService {
/// backend needed based on worker threads set. See TranslationModel for documentation on other params.
Ptr<TranslationModel> createCompatibleModel(const TranslationModel::Config &config) {
// @TODO: Remove this dependency/coupling.
return New<TranslationModel>(config, /*replicas=*/config_.numWorkers);
return New<TranslationModel>(config, /*replicas=*/config_.numWorkers, config_.gpuWorkers);
}

/// With the supplied TranslationModel, translate an input. A Response is constructed with optional items set/unset
11 changes: 9 additions & 2 deletions src/translator/translation_model.cpp
@@ -16,13 +16,14 @@ namespace bergamot {
std::atomic<size_t> TranslationModel::modelCounter_ = 0;

TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory /*=MemoryBundle{}*/,
size_t replicas /*=1*/)
size_t replicas /*=1*/, std::vector<size_t> gpus /*={}*/)
: modelId_(modelCounter_++),
options_(options),
memory_(std::move(memory)),
vocabs_(options, std::move(memory_.vocabs)),
textProcessor_(options, vocabs_, std::move(memory_.ssplitPrefixFile)),
batchingPool_(options),
gpus_{gpus},
qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memory, options))) {
ABORT_IF(replicas == 0, "At least one replica needs to be created.");
backend_.resize(replicas);
@@ -53,7 +54,13 @@ void TranslationModel::loadBackend(size_t idx) {
auto &graph = backend_[idx].graph;
auto &scorerEnsemble = backend_[idx].scorerEnsemble;

marian::DeviceId device_(idx, DeviceType::cpu);
marian::DeviceId device_;

if (gpus_.empty()) {
device_ = marian::DeviceId(idx, DeviceType::cpu);
} else {
device_ = marian::DeviceId(gpus_[idx], DeviceType::gpu);
}
graph = New<ExpressionGraph>(/*inference=*/true); // set the graph to be inference only
auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(prec[0]));
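loadBackend now selects a device per replica: with no GPUs configured, replica idx maps to CPU device slot idx; otherwise replica idx runs on the user-supplied device id gpus[idx]. This assumes one replica per listed GPU, which holds because AsyncService passes numWorkers (set to the GPU count) as the replica count. The mapping as a Python sketch:

def device_for_replica(idx: int, gpus: list) -> tuple:
    # Mirrors the DeviceId choice in TranslationModel::loadBackend.
    if not gpus:
        return ("cpu", idx)      # CPU backend: replica index doubles as device slot
    return ("gpu", gpus[idx])    # GPU backend: replica idx pinned to device gpus[idx]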
12 changes: 7 additions & 5 deletions src/translator/translation_model.h
@@ -47,8 +47,8 @@ class TranslationModel {
/// operandi.
///
/// TODO(@jerinphilip): Clean this up.
TranslationModel(const std::string& config, MemoryBundle&& memory, size_t replicas = 1)
: TranslationModel(parseOptionsFromString(config, /*validate=*/false), std::move(memory), replicas){};
TranslationModel(const std::string& config, MemoryBundle&& memory, size_t replicas = 1, std::vector<size_t> gpus = {})
: TranslationModel(parseOptionsFromString(config, /*validate=*/false), std::move(memory), replicas, gpus){};

/// Construct TranslationModel from marian-options. If memory is empty, TranslationModel is initialized from
/// paths available in the options object, backed by filesystem. Otherwise, TranslationModel is initialized from the
@@ -57,10 +57,11 @@ class TranslationModel {
/// @param [in] options: Marian options object.
/// @param [in] memory: MemoryBundle object holding memory buffers containing parameters to build MarianBackend,
/// ShortlistGenerator, Vocabs and SentenceSplitter.
TranslationModel(const Config& options, MemoryBundle&& memory, size_t replicas = 1);
/// @param [in] gpus: Optional array of GPU ids
TranslationModel(const Config& options, MemoryBundle&& memory, size_t replicas = 1, std::vector<size_t> gpus = {});

TranslationModel(const Config& options, size_t replicas = 1)
: TranslationModel(options, getMemoryBundleFromConfig(options), replicas) {}
TranslationModel(const Config& options, size_t replicas = 1, std::vector<size_t> gpus = {})
: TranslationModel(options, getMemoryBundleFromConfig(options), replicas, gpus) {}

/// Make a Request to be translated by this TranslationModel instance.
/// @param [in] requestId: Unique identifier associated with this request, available from Service.
@@ -103,6 +104,7 @@ class TranslationModel {
MemoryBundle memory_;
Vocabs vocabs_;
TextProcessor textProcessor_;
std::vector<size_t> gpus_;

/// Maintains sentences from multiple requests bucketed by length and sorted by priority in each bucket.
BatchingPool batchingPool_;
