diff --git a/wenet/LLM/__init__.py b/wenet/LLM/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/__init__.py b/wenet/__init__.py
index 820ad3180b..86fe62c04f 100644
--- a/wenet/__init__.py
+++ b/wenet/__init__.py
@@ -1 +1 @@
-from wenet.cli.model import load_model  # noqa
+from wenet.cli.model import load_model, load_model_pt  # noqa
diff --git a/wenet/bin/__init__.py b/wenet/bin/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/cli/model.py b/wenet/cli/model.py
index bb24bdb337..ad0165aca1 100644
--- a/wenet/cli/model.py
+++ b/wenet/cli/model.py
@@ -17,11 +17,14 @@
 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
+import yaml
+from types import SimpleNamespace
 
 from wenet.cli.hub import Hub
 from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
                                    gen_timestamps_from_peak)
 from wenet.utils.file_utils import read_symbol_table
+from wenet.utils.init_model import init_model
 from wenet.transformer.search import (attention_rescoring,
                                       ctc_prefix_beam_search, DecodeResult)
 from wenet.utils.context_graph import ContextGraph
@@ -174,3 +177,37 @@ def load_model(language: str = None,
     model.device = torch.device(device)
     model.model.to(device)
     return model
+
+
+# Load the pytorch pt model, which keeps all the details that the jit export
+# drops, so the pt model can be used as a third-party pytorch nn.Module for
+# training.
+def load_model_pt(model_dir):
+    """ The following files are expected in `model_dir`:
+    * final.pt, required
+    * train.yaml, required
+    * units.txt, required
+    * global_cmvn, optional
+    """
+    # Check required files
+    required_files = ['train.yaml', 'final.pt', 'units.txt']
+    for file in required_files:
+        file_path = os.path.join(model_dir, file)
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(
+                f"Required file {file} not found in {model_dir}")
+    # Read the config and override some entries
+    config_file = os.path.join(model_dir, 'train.yaml')
+    with open(config_file, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+    token_file = os.path.join(model_dir, 'units.txt')
+    configs['tokenizer_conf']['symbol_table_path'] = token_file
+    cmvn_file = os.path.join(model_dir, 'global_cmvn')
+    if os.path.exists(cmvn_file):
+        configs['cmvn_conf']['cmvn_file'] = cmvn_file
+    # Read the model; init_model reads `args` via attribute access,
+    # so wrap the checkpoint path in a namespace rather than a dict
+    pt_file = os.path.join(model_dir, 'final.pt')
+    args = SimpleNamespace(checkpoint=pt_file)
+    model, configs = init_model(args, configs)
+    return model
diff --git a/wenet/ctl_model/__init__.py b/wenet/ctl_model/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/dataset/deprecated/__init__.py b/wenet/dataset/deprecated/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/e_branchformer/__init__.py b/wenet/e_branchformer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/finetune/__init__.py b/wenet/finetune/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/ssl/__init__.py b/wenet/ssl/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/ssl/bestrq/__init__.py b/wenet/ssl/bestrq/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/ssl/w2vbert/__init__.py b/wenet/ssl/w2vbert/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/ssl/wav2vec2/__init__.py b/wenet/ssl/wav2vec2/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/transducer/search/__init__.py b/wenet/transducer/search/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index 3a553994b4..cbcf2f528b 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -213,12 +213,13 @@ def init_model(args, configs):
     if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
         load_checkpoint(model, args.lora_ckpt_path)
 
-    print(configs)
     # Try to tie some weights
     if hasattr(model, 'tie_or_clone_weights'):
         if not hasattr(args, 'jit'):
-            args.jit = True  # i.e. export onnx/jit/ipex
-            model.tie_or_clone_weights(args.jit)
+            jit = True  # i.e. export onnx/jit/ipex
+        else:
+            jit = False
+        model.tie_or_clone_weights(jit)
 
     if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
         mark_only_lora_as_trainable(model, bias='lora_only')
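
A minimal usage sketch of the new `load_model_pt` entry point, illustrating the comment above it: unlike the jit export, the pt checkpoint yields a regular nn.Module that can be dropped into ordinary PyTorch training. The `model_dir` path is hypothetical and the optimizer choice is only illustrative:

import torch

from wenet.cli.model import load_model_pt

# Hypothetical checkpoint directory; it must contain final.pt, train.yaml and
# units.txt (plus an optional global_cmvn), per load_model_pt's docstring.
model = load_model_pt('exp/my_wenet_model')
assert isinstance(model, torch.nn.Module)

# Since the full nn.Module is recovered, standard PyTorch training machinery
# applies, e.g. fine-tuning all parameters with AdamW.
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)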