From cb7c3be6f0463ad4976b34273386d560916e0329 Mon Sep 17 00:00:00 2001 From: JiangXin Date: Mon, 26 Sep 2022 12:37:04 +0800 Subject: [PATCH] Generate client configuration by default (#386) --- .github/workflows/ai_flow_cd.yml | 2 +- ai_flow/common/configuration/configuration.py | 11 +++++++---- tests/common/configuration/test_configuration.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ai_flow_cd.yml b/.github/workflows/ai_flow_cd.yml index 982720b47..0ac571ce9 100644 --- a/.github/workflows/ai_flow_cd.yml +++ b/.github/workflows/ai_flow_cd.yml @@ -7,7 +7,7 @@ on: jobs: push_to_registry: name: Nightly build and push packages - # if: github.repository == 'flink-extended/ai-flow' # Don't do this in forks + if: github.repository == 'flink-extended/ai-flow' # Don't do this in forks runs-on: ubuntu-latest steps: - name: Check out the repo for nightly diff --git a/ai_flow/common/configuration/configuration.py b/ai_flow/common/configuration/configuration.py index d3580d286..d6ca57403 100644 --- a/ai_flow/common/configuration/configuration.py +++ b/ai_flow/common/configuration/configuration.py @@ -19,7 +19,7 @@ from copy import deepcopy from typing import Dict, Any -from .helpers import TRUTH_TEXT, FALSE_TEXT, get_aiflow_home, parameterized_config +from .helpers import TRUTH_TEXT, FALSE_TEXT, get_aiflow_home, parameterized_config, write_default_config from ai_flow.common.exception.exceptions import AIFlowConfigException from ai_flow.common.util.file_util.yaml_utils import load_yaml_string from ..env import expand_env_var @@ -112,10 +112,13 @@ def __str__(self) -> str: def get_client_configuration(): client_config_file_name = 'aiflow_client.yaml' - config_path = os.path.join(get_aiflow_home(), client_config_file_name) + home_dir = get_aiflow_home() + config_path = os.path.join(home_dir, client_config_file_name) if not os.path.isfile(config_path): - logger.warning("Client configuration file not found in {}, using default.".format(config_path)) - config_path = os.path.join(os.path.dirname(__file__), 'config_templates', client_config_file_name) + logger.warning("Client configuration file not found in {}, generating a default.".format(config_path)) + if not os.path.exists(home_dir): + os.makedirs(home_dir) + write_default_config('aiflow_client.yaml') config = Configuration(config_path) return config diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 2a5340cfa..60b7c72d6 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -124,8 +124,8 @@ def test_get_default_configuration(self): server_conf = get_server_configuration() self.assertEqual(server_conf.get_str('log_dir'), expand_env_var('~/aiflow/logs')) - write_default_config('aiflow_client.yaml') client_conf = get_client_configuration() + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, 'aiflow_client.yaml'))) self.assertEqual(client_conf.get_str('server_address'), expand_env_var('localhost:50051')) root_dir = expand_env_var('~/aiflow/blob') self.assertEqual(client_conf.get('blob_manager').get('blob_manager_config'), {'root_directory': root_dir})