Skip to content

Commit

Permalink
0.9.38 fix create_grid_params
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Nov 28, 2023
1 parent 42ed342 commit a3c59ba
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions czsc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def freqs_sorted(freqs):
return _freqs_new


def create_grid_params(prefix: str, detail=False, **kwargs) -> dict:
def create_grid_params(prefix: str = "", multiply=3, **kwargs) -> dict:
"""创建 grid search 参数组合
:param prefix: 参数组前缀
:param detail: 是否使用参数值构建参数组的名称
:param multiply: 参数组合的位数,如果为 0,则使用 # 分隔参数
:param kwargs: 任意参数的候选序列,参数值推荐使用 iterable
:return: 参数组合字典
Expand All @@ -111,8 +111,8 @@ def create_grid_params(prefix: str, detail=False, **kwargs) -> dict:
>>>x = create_grid_params("test", x=2, y=('a', 'b'), detail=False)
>>>print(x)
Out[1]:
{'test@001': {'x': 2, 'y': 'a'},
'test@002': {'x': 2, 'y': 'b'}}
{'test001': {'x': 2, 'y': 'a'},
'test002': {'x': 2, 'y': 'b'}}
"""
from sklearn.model_selection import ParameterGrid

Expand All @@ -126,13 +126,12 @@ def create_grid_params(prefix: str, detail=False, **kwargs) -> dict:

params = {}
for i, row in enumerate(ParameterGrid(params_grid), 1):
if detail:
if multiply == 0:
key = "#".join([f"{k}={v}" for k, v in row.items()])
# params[f"{prefix}@{key}"] = row
else:
key = str(i).zfill(3)
key = str(i).zfill(multiply)

row['version'] = f"{prefix}@{key}"
row['version'] = f"{prefix}{key}"
params[f"{prefix}@{key}"] = row
return params

Expand Down

0 comments on commit a3c59ba

Please sign in to comment.