Skip to content

Commit

Permalink
Defer model loading to parallel worker thread (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmervdl authored Jan 14, 2022
1 parent 71b84b7 commit 13c55e2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/translator/translation_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
srcIdx, trgIdx, shared_vcb);
}
}

for (size_t idx = 0; idx < replicas; idx++) {
loadBackend(idx);
}
}

void TranslationModel::loadBackend(size_t idx) {
Expand Down Expand Up @@ -172,6 +168,12 @@ Ptr<marian::data::CorpusBatch> TranslationModel::convertToMarianBatch(Batch &bat

void TranslationModel::translateBatch(size_t deviceId, Batch &batch) {
auto &backend = backend_[deviceId];

if (!backend.initialized) {
loadBackend(deviceId);
backend.initialized = true;
}

BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target());
Histories histories = search.search(backend.graph, convertToMarianBatch(batch));
batch.completeBatch(histories);
Expand Down
1 change: 1 addition & 0 deletions src/translator/translation_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class TranslationModel {

Graph graph;
ScorerEnsemble scorerEnsemble;
bool initialized{false};
};

// ShortlistGenerator is purely const, we don't need one per thread.
Expand Down

0 comments on commit 13c55e2

Please sign in to comment.