diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fc67a06..136164e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,7 +35,7 @@ jobs: strategy: matrix: group: [ 1, 2, 3 ] - tensorflow_version: [2.2, 2.3, 2.4, 2.5] + tensorflow_version: [2.1, 2.2, 2.3, 2.4, 2.5] steps: - uses: actions/checkout@v2 - name: Set up Python 3.8 diff --git a/kashgari/__init__.py b/kashgari/__init__.py index 2509f01..289d657 100644 --- a/kashgari/__init__.py +++ b/kashgari/__init__.py @@ -12,25 +12,40 @@ """ import os -from typing import Dict, Any +from distutils.version import LooseVersion +from typing import Any, Dict os.environ['TF_KERAS'] = '1' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' custom_objects: Dict[str, Any] = {} +def check_tfa_version(tf_version): + if LooseVersion(tf_version) < '2.2.0': + return '0.9.1' + elif LooseVersion(tf_version) < '2.3.0': + return '0.11.2' + else: + return '0.13.0' + + +def dependency_check() -> None: + import tensorflow as tf + tfa_version = check_tfa_version(tf_version=tf.__version__) + try: + import tensorflow_addons as tfa + except: + raise ImportError( + "Kashgari request tensorflow_addons, please install via the " + f"`$pip install tensorflow_addons=={tfa_version}`" + ) + +dependency_check() + +from kashgari import corpus, embeddings, layers, macros, processors, tasks, utils from kashgari.__version__ import __version__ from kashgari.macros import config -from kashgari import layers -from kashgari import corpus -from kashgari import embeddings -from kashgari import macros -from kashgari import processors -from kashgari import tasks -from kashgari import utils - -from kashgari.utils.dependency_check import dependency_check custom_objects = layers.resigter_custom_layers(custom_objects) -dependency_check() + diff --git a/kashgari/utils/dependency_check.py b/kashgari/utils/dependency_check.py deleted file mode 100644 index 1b69525..0000000 --- a/kashgari/utils/dependency_check.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Author : BrikerMan -# Site : https://eliyar.biz - -# Time : 2020/9/2 12:12 下午 -# File : dependency_check.py -# Project : Kashgari - -import tensorflow as tf - -from distutils.version import LooseVersion - - -def dependency_check() -> None: - if LooseVersion(tf.__version__) < '2.2.0': - try: - import tensorflow_addons as tfa - if LooseVersion(tfa.__version__) > '0.10.0': - raise ImportError("TF 2.1 required lower version of tensorflow_addons, " - "install using `$pip install tensorflow_addons<=0.10.0`") - except ImportError: - raise ImportError("TF 2.1 required lower version of tensorflow_addons, " - "install using `$pip install tensorflow_addons<=0.10.0`")