Skip to content

Commit

Permalink
fix another edge case, and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brianfitzgerald committed Aug 10, 2023
1 parent 25d92f9 commit 5b01ef3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/stability_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,23 +353,22 @@ def parse_models_from_prompts(prompts: Union[Any, List[Any]]) -> Tuple[List[Any]
if not prompts:
return [], []
prompts = prompts if isinstance(prompts, List) else [prompts]
pattern = re.compile(r"<([^:]+):([^>]+)>")
pattern = re.compile(r"<([^:>]+)(?::([^>]+))?>")
models = {}

def _process_prompt(prompt):
text = prompt.text if isinstance(prompt, generation.Prompt) else prompt
matches = pattern.findall(text)
for model, weight in matches:
weight_text = weight if weight else ""
# pass default TI tokens through unmodified
if model in ["s1", "s2", "s3"]:
continue
try:
weight = max(float(weight) if weight else 1.0, models.get(model, -math.inf))
if weight == 1.0:
weight = int(weight)
except ValueError as e:
raise ValueError(f'Invalid weight for model "{model}": "{weight}"') from e
text = text.replace(f'<{model}:{weight}>', f'<{model}>')
text = text.replace(f'<{model}:{weight_text}>', f'<{model}>')
models[model] = weight
return text

Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def test_color_match_from_string_invalid():
def test_parse_models_from_prompts():
assert parse_models_from_prompts(None) == ([], [])
assert parse_models_from_prompts([]) == ([], [])
assert parse_models_from_prompts("a <one:1>")[1] == [("one", 1.0)]
assert parse_models_from_prompts("a <one:1.0>")[1] == [("one", 1.0)]
assert parse_models_from_prompts("a simple prompt") == (["a simple prompt"], [])
assert parse_models_from_prompts("<weight-strip:0.25>") == (["<weight-strip>"], [("weight-strip", 0.25)])
assert parse_models_from_prompts("a <my-model>")[1] == [("my-model", 1.0)]
Expand Down

0 comments on commit 5b01ef3

Please sign in to comment.