Skip to content

Commit

Permalink
allow each InputSchema to exclude certain parts of itself from the co…
Browse files Browse the repository at this point in the history
…nfig
  • Loading branch information
Adam-D-Lewis committed Mar 9, 2024
1 parent bbdf6f2 commit b2891d1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/_nebari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def write_configuration(
"""Write the nebari configuration file to disk"""
with config_filename.open(mode) as f:
if isinstance(config, pydantic.BaseModel):
config_dict = config.dict()
config_dict = config.write_config()
rev_config_dict = {k: config_dict[k] for k in reversed(config_dict)}
yaml.dump(rev_config_dict, f)
else:
Expand Down
7 changes: 7 additions & 0 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,13 @@ class InputSchema(schema.Base):
azure: Optional[AzureProvider]
digital_ocean: Optional[DigitalOceanProvider]

def exclude_from_config(self):
exclude = set()
for provider in InputSchema.__fields__:
if getattr(self, provider) is None:
exclude.add(provider)
return exclude

@pydantic.root_validator(pre=True)
def check_provider(cls, values):
if "provider" in values:
Expand Down
26 changes: 22 additions & 4 deletions src/nebari/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,30 @@ def read_config(self, config_path: typing.Union[str, Path], **kwargs):
def ordered_stages(self):
return self.get_available_stages()

@property
def ordered_schemas(self):
return [schema.Main] + [_.input_schema for _ in self.ordered_stages if _.input_schema is not None]

@property
def config_schema(self):
classes = [schema.Main] + [
_.input_schema for _ in self.ordered_stages if _.input_schema is not None
]
return type("ConfigSchema", tuple(classes), {})
ordered_schemas = self.ordered_schemas

def write_config(self):
config_exclude = set()
for cls in self._ordered_schemas:
if hasattr(cls, "exclude_from_config"):
new_exclude = cls.exclude_from_config(self)
config_exclude = config_exclude.union(new_exclude)
return self.dict(exclude=config_exclude)


return type(
"ConfigSchema",
tuple(ordered_schemas),
{
"_ordered_schemas": ordered_schemas,
"write_config": write_config,
})


nebari_plugin_manager = NebariPluginManager()

0 comments on commit b2891d1

Please sign in to comment.