Skip to content

Commit

Permalink
Generate client configuration by default (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangxin369 authored Sep 26, 2022
1 parent aa25d15 commit cb7c3be
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ai_flow_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
jobs:
push_to_registry:
name: Nightly build and push packages
# if: github.repository == 'flink-extended/ai-flow' # Don't do this in forks
if: github.repository == 'flink-extended/ai-flow' # Don't do this in forks
runs-on: ubuntu-latest
steps:
- name: Check out the repo for nightly
Expand Down
11 changes: 7 additions & 4 deletions ai_flow/common/configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from copy import deepcopy
from typing import Dict, Any

from .helpers import TRUTH_TEXT, FALSE_TEXT, get_aiflow_home, parameterized_config
from .helpers import TRUTH_TEXT, FALSE_TEXT, get_aiflow_home, parameterized_config, write_default_config
from ai_flow.common.exception.exceptions import AIFlowConfigException
from ai_flow.common.util.file_util.yaml_utils import load_yaml_string
from ..env import expand_env_var
Expand Down Expand Up @@ -112,10 +112,13 @@ def __str__(self) -> str:

def get_client_configuration():
client_config_file_name = 'aiflow_client.yaml'
config_path = os.path.join(get_aiflow_home(), client_config_file_name)
home_dir = get_aiflow_home()
config_path = os.path.join(home_dir, client_config_file_name)
if not os.path.isfile(config_path):
logger.warning("Client configuration file not found in {}, using default.".format(config_path))
config_path = os.path.join(os.path.dirname(__file__), 'config_templates', client_config_file_name)
logger.warning("Client configuration file not found in {}, generating a default.".format(config_path))
if not os.path.exists(home_dir):
os.makedirs(home_dir)
write_default_config('aiflow_client.yaml')
config = Configuration(config_path)
return config

Expand Down
2 changes: 1 addition & 1 deletion tests/common/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def test_get_default_configuration(self):
server_conf = get_server_configuration()
self.assertEqual(server_conf.get_str('log_dir'), expand_env_var('~/aiflow/logs'))

write_default_config('aiflow_client.yaml')
client_conf = get_client_configuration()
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, 'aiflow_client.yaml')))
self.assertEqual(client_conf.get_str('server_address'), expand_env_var('localhost:50051'))
root_dir = expand_env_var('~/aiflow/blob')
self.assertEqual(client_conf.get('blob_manager').get('blob_manager_config'), {'root_directory': root_dir})

0 comments on commit cb7c3be

Please sign in to comment.