diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 535182d214..b8f69b54e5 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -2,6 +2,7 @@ import importlib import logging import os +import sys from argparse import RawTextHelpFormatter import numpy as np @@ -18,7 +19,7 @@ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) # pylint: disable=bad-option-value parser = argparse.ArgumentParser( diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 1bdb8d733c..dc0ce5b18b 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,6 +1,7 @@ import argparse import logging import os +import sys from argparse import RawTextHelpFormatter import torch @@ -102,7 +103,7 @@ def compute_embeddings( if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser( description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index dc5423a691..acec91c369 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -5,6 +5,7 @@ import glob import logging import os +import sys import numpy as np from tqdm import tqdm @@ -18,7 +19,7 @@ def main(): """Run preprocessing process.""" - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stderr, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 711c8221db..701c7d8e82 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -1,5 +1,6 @@ import argparse import logging +import sys from argparse import RawTextHelpFormatter import torch @@ -53,7 +54,7 @@ def compute_encoder_accuracy(dataset_items, encoder_manager): if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser( description="""Compute the accuracy of the encoder.\n\n""" diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 86a4dce177..a04005ce39 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -4,6 +4,7 @@ import argparse import logging import os +import sys import numpy as np import torch @@ -273,7 +274,7 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 0519d43769..7a7fdf5dd4 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -2,6 +2,7 @@ import argparse import logging +import sys from argparse import RawTextHelpFormatter from TTS.config import load_config @@ -10,7 +11,7 @@ def main(): - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) # pylint: disable=bad-option-value parser = argparse.ArgumentParser( diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index d99acb9893..7c68fdb070 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -3,6 +3,7 @@ import argparse import logging import multiprocessing +import sys from argparse import RawTextHelpFormatter from tqdm.contrib.concurrent import process_map @@ -20,7 +21,7 @@ def compute_phonemes(item): def main(): - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) # pylint: disable=W0601 global c, phonemizer diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py index edab882db8..f9121d7f77 100755 --- a/TTS/bin/remove_silence_using_vad.py +++ b/TTS/bin/remove_silence_using_vad.py @@ -4,6 +4,7 @@ import multiprocessing import os import pathlib +import sys import torch from tqdm import tqdm @@ -77,7 +78,7 @@ def preprocess_audios(): if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser( description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end" diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 47b442e266..5d20db6a59 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -311,8 +311,9 @@ def parse_args() -> argparse.Namespace: def main() -> None: """Entry point for `tts` command line interface.""" - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) args = parse_args() + stream = sys.stderr if args.pipe_out else sys.stdout + setup_logger("TTS", level=logging.INFO, stream=stream, formatter=ConsoleFormatter()) pipe_out = sys.stdout if args.pipe_out else None diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index ba03c42b6d..84123d2db3 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -322,7 +322,7 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 6d6342a762..e93b1c9d24 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,5 +1,6 @@ import logging import os +import sys from dataclasses import dataclass, field from trainer import Trainer, TrainerArgs @@ -17,7 +18,7 @@ class TrainTTSArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) # init trainer args train_args = TrainTTSArgs() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 221ff4cff0..aa04177068 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,5 +1,6 @@ import logging import os +import sys from dataclasses import dataclass, field from trainer import Trainer, TrainerArgs @@ -18,7 +19,7 @@ class TrainVocoderArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) # init trainer args train_args = TrainVocoderArgs() diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index df2923952d..d05ae14b7f 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -2,6 +2,7 @@ import argparse import logging +import sys from itertools import product as cartesian_product import numpy as np @@ -17,7 +18,7 @@ from TTS.vocoder.models import setup_model if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py index da7522a512..37619ed0f8 100644 --- a/TTS/encoder/utils/prepare_voxceleb.py +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -216,7 +216,7 @@ def processor(directory, subset, force_process): if __name__ == "__main__": - setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) if len(sys.argv) != 4: print("Usage: python prepare_data.py save_directory user password") sys.exit() diff --git a/TTS/server/server.py b/TTS/server/server.py index f410fb7539..6a4642f9a2 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -25,7 +25,7 @@ from TTS.utils.synthesizer import Synthesizer logger = logging.getLogger(__name__) -setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) +setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter()) def create_argparser() -> argparse.ArgumentParser: diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index d7397f673d..54bb5ba825 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -2,9 +2,10 @@ import datetime import importlib import logging +import os import re from pathlib import Path -from typing import Any, Callable, Dict, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Optional, TextIO, TypeVar, Union import torch from packaging.version import Version @@ -107,25 +108,34 @@ def setup_logger( level: int = logging.INFO, *, formatter: Optional[logging.Formatter] = None, - screen: bool = False, - tofile: bool = False, - log_dir: str = "logs", + stream: Optional[TextIO] = None, + log_dir: Optional[Union[str, os.PathLike[Any]]] = None, log_name: str = "log", ) -> None: + """Set up a logger. + + Args: + logger_name: Name of the logger to set up + level: Logging level + formatter: Formatter for the logger + stream: Add a StreamHandler for the given stream, e.g. sys.stderr or sys.stdout + log_dir: Folder to write the log file (no file created if None) + log_name: Prefix of the log file name + """ lg = logging.getLogger(logger_name) if formatter is None: formatter = logging.Formatter( "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" ) lg.setLevel(level) - if tofile: + if log_dir is not None: Path(log_dir).mkdir(exist_ok=True, parents=True) log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" fh = logging.FileHandler(log_file, mode="w") fh.setFormatter(formatter) lg.addHandler(fh) - if screen: - sh = logging.StreamHandler() + if stream is not None: + sh = logging.StreamHandler(stream) sh.setFormatter(formatter) lg.addHandler(sh)