Skip to content

Commit

Permalink
prevent runtime override of model
Browse files Browse the repository at this point in the history
  • Loading branch information
Douglas Reid committed Nov 17, 2023
1 parent 8d6e5f9 commit 6896d61
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ def run(
return InvocableResponse(data=RawBlockAndTagPluginOutput(blocks=generated_blocks))

def _inputs_from_config_and_runtime_params(self, options: Optional[dict]) -> dict:

if options is not None and "model" in options:
raise SteamshipError(
"Model may not be overridden in runtime options. "
"Please configure 'model' when creating a plugin instance."
)

temp_config = DallEPlugin.DallEPluginConfig(**self.config.dict())
temp_config.extend_with_dict(options, overwrite=True)
validated_config = DallEPlugin.DallEPluginConfig(**temp_config.dict())
Expand Down
7 changes: 7 additions & 0 deletions test/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ def test_runtime_config_validation_returns_values():
assert config_dict.get("size") == ImageSizeEnum.large
assert config_dict.get("n") == 1
assert config_dict.get("style")


def test_no_override_of_model_at_runtime():
plugin = DallEPlugin(config={"model": "dall-e-2"})

with pytest.raises(SteamshipError):
plugin._inputs_from_config_and_runtime_params(options={"model": "dall-e-3"})

0 comments on commit 6896d61

Please sign in to comment.