From 71b914f07aa8a1bb2cb2424ad393946879db33c2 Mon Sep 17 00:00:00 2001 From: T5X Team Date: Mon, 15 Apr 2024 11:45:43 -0700 Subject: [PATCH] Added MAXO-GEN models, namely, SAC, AGIQA-1k, AGIQA-3k and AIGCIQA-2023. PiperOrigin-RevId: 625032373 --- t5x/export_lib.py | 1 + t5x/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/t5x/export_lib.py b/t5x/export_lib.py index bb4d2720e..f0448192c 100644 --- a/t5x/export_lib.py +++ b/t5x/export_lib.py @@ -1471,6 +1471,7 @@ def save( train_state_initializer = get_train_state_initializer( model, partitioner, task_feature_lengths, batch_size, trailing_shapes ) + utils.import_module('pmmx.projects.maxo.export_task') output_features = _standardize_output_features( mixture_or_task_name, output_features diff --git a/t5x/utils.py b/t5x/utils.py index 5c55945cd..0070c1005 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -39,6 +39,7 @@ import flax.core from flax.core import scope as flax_scope from flax.linen import partitioning as flax_partitioning +import gin import jax from jax.experimental import multihost_utils import jax.numpy as jnp @@ -55,7 +56,6 @@ from tensorflow.io import gfile import typing_extensions - FLAGS = flags.FLAGS # Remove _ShardedDeviceArray when users of t5x have their types updated @@ -1818,7 +1818,8 @@ def import_module(module: str): """Imports the given module at runtime.""" logging.info('Importing %s.', module) try: - importlib.import_module(module) + with gin.unlock_config(): + importlib.import_module(module) except RuntimeError as e: if ( str(e)