diff --git a/README.md b/README.md
index 4767e7d..a11c542 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,9 @@ pip install torcharc
## Usage
1. specify model architecture in a YAML spec file, e.g. at `spec_filepath = "./example/spec/basic/mlp.yaml"`
-2. `import torcharc`
-3. (optional) if you have custom torch.nn.Module, e.g. `MyModule`, register it with `torcharc.register_nn(MyModule)`
-4. build with: `model = torcharc.build(spec_filepath)`
+2. `import torcharc`
+   1. (optional) if you have a custom `torch.nn.Module`, e.g. `MyModule`, register it with `torcharc.register_nn(MyModule)`
+3. build with: `model = torcharc.build(spec_filepath)`
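+
+Putting it together (a minimal sketch; `MyModule` is a hypothetical stand-in for your own module):
+
+```python
+import torch.nn as nn
+
+import torcharc
+
+
+class MyModule(nn.Module):  # hypothetical custom module
+    def forward(self, x):
+        return x
+
+
+torcharc.register_nn(MyModule)  # optional: make MyModule available to specs
+model = torcharc.build(spec_filepath)
+```
+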
The returned model is a PyTorch `nn.Module`, fully-compatible with `torch.compile`, and mostly compatible with PyTorch JIT script and trace.
@@ -277,6 +277,198 @@ GraphModule(
---
+### Example: MLP (Compact)
+
+Use a compact spec that expands into a Sequential spec; this is useful for architecture search.
+
+spec file
+
+File: [torcharc/example/spec/compact/mlp.yaml](torcharc/example/spec/compact/mlp.yaml)
+
+```yaml
+# modules:
+# mlp:
+# Sequential:
+# - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 32
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 16
+# - ReLU:
+
+# the above is equivalent to the compact spec below
+
+modules:
+ mlp:
+ compact:
+ layer:
+ type: LazyLinear
+ keys: [out_features]
+ args: [64, 64, 32, 16]
+ postlayer:
+ - ReLU:
+
+graph:
+ input: x
+ modules:
+ mlp: [x]
+ output: mlp
+```
+
+
+
+```python
+model = torcharc.build(torcharc.SPEC_DIR / "compact" / "mlp.yaml")
+
+# Run the model and check the output shape
+x = torch.randn(4, 128)
+y = model(x)
+assert y.shape == (4, 16)
+
+model
+```
+
+
+```
+GraphModule(
+ (mlp): Sequential(
+ (0): Linear(in_features=128, out_features=64, bias=True)
+ (1): ReLU()
+ (2): Linear(in_features=64, out_features=64, bias=True)
+ (3): ReLU()
+ (4): Linear(in_features=64, out_features=32, bias=True)
+ (5): ReLU()
+ (6): Linear(in_features=32, out_features=16, bias=True)
+ (7): ReLU()
+ )
+)
+```
+
+![](images/mlp_compact.png)
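+
+Because `args` is plain data, a compact spec is easy to mutate programmatically for architecture search. A minimal sketch (the candidate widths are hypothetical; each variant is written to a temp YAML file so that only `torcharc.build(filepath)` is assumed):
+
+```python
+import copy
+import tempfile
+
+import torch
+import yaml
+
+import torcharc
+
+base = yaml.safe_load((torcharc.SPEC_DIR / "compact" / "mlp.yaml").read_text())
+for widths in ([64, 16], [128, 64, 32, 16]):  # hypothetical candidates
+    spec = copy.deepcopy(base)
+    # swap out the layer sizes that the compact spec expands into
+    spec["modules"]["mlp"]["compact"]["layer"]["args"] = list(widths)
+    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
+        yaml.safe_dump(spec, f)
+    model = torcharc.build(f.name)
+    y = model(torch.randn(4, 128))
+    assert y.shape == (4, widths[-1])
+```
+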
+
+
+
+---
+
+### Example: Conv Classifier (Compact)
+
+Use a compact spec that expands into a Sequential spec; this is useful for architecture search.
+
+spec file
+
+File: [torcharc/example/spec/compact/conv_classifier.yaml](torcharc/example/spec/compact/conv_classifier.yaml)
+
+```yaml
+# modules:
+# conv:
+# Sequential:
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 16
+# kernel_size: 2
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 32
+# kernel_size: 3
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 64
+# kernel_size: 4
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# classifier:
+# Sequential:
+# - Flatten:
+# - LazyLinear:
+# out_features: 10
+
+# the above is equivalent to the compact spec below
+
+modules:
+ conv:
+ compact:
+ prelayer:
+ - LazyBatchNorm2d:
+ layer:
+ type: LazyConv2d
+ keys: [out_channels, kernel_size]
+ args: [[16, 2], [32, 3], [64, 4]]
+ postlayer:
+ - ReLU:
+ - Dropout:
+ p: 0.1
+ classifier:
+ Sequential:
+ - Flatten:
+ - LazyLinear:
+ out_features: 10
+
+graph:
+ input: image
+ modules:
+ conv: [image]
+ classifier: [conv]
+ output: classifier
+```
+
+
+
+```python
+model = torcharc.build(torcharc.SPEC_DIR / "compact" / "conv_classifier.yaml")
+
+# Run the model and check the output shape
+x = torch.randn(4, 1, 28, 28)
+y = model(x)
+assert y.shape == (4, 10)
+
+model
+```
+
+
+```
+GraphModule(
+ (conv): Sequential(
+ (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (1): Conv2d(1, 16, kernel_size=(2, 2), stride=(1, 1))
+ (2): ReLU()
+ (3): Dropout(p=0.1, inplace=False)
+ (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (5): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
+ (6): ReLU()
+ (7): Dropout(p=0.1, inplace=False)
+ (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (9): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1))
+ (10): ReLU()
+ (11): Dropout(p=0.1, inplace=False)
+ )
+ (classifier): Sequential(
+ (0): Flatten(start_dim=1, end_dim=-1)
+ (1): Linear(in_features=30976, out_features=10, bias=True)
+ )
+)
+```
+
+![](images/conv_classifier_compact.png)
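+
+Since the spec uses lazy (`Lazy*`) modules, channel and feature sizes are inferred on the first forward pass, so the same spec also handles other input shapes, e.g. 3-channel 32x32 images as in the test suite:
+
+```python
+model = torcharc.build(torcharc.SPEC_DIR / "compact" / "conv_classifier.yaml")
+y = model(torch.randn(4, 3, 32, 32))  # shapes are re-inferred for this input
+assert y.shape == (4, 10)
+```
+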
+
+
+
+---
+
### Example: Reuse syntax: Stereo Conv
spec file
diff --git a/images/conv_classifier_compact.png b/images/conv_classifier_compact.png
new file mode 100644
index 0000000..5751859
Binary files /dev/null and b/images/conv_classifier_compact.png differ
diff --git a/images/film.png b/images/film.png
index 62fea3c..7b81ef9 100644
Binary files a/images/film.png and b/images/film.png differ
diff --git a/images/mlp_compact.png b/images/mlp_compact.png
new file mode 100644
index 0000000..5044153
Binary files /dev/null and b/images/mlp_compact.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 45067bd..1cc3036 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "torcharc"
-version = "2.0.1"
+version = "2.1.0"
description = "Build PyTorch models by specifying architectures."
readme = "README.md"
requires-python = ">=3.12"
diff --git a/test/example/spec/test_compact_spec.py b/test/example/spec/test_compact_spec.py
new file mode 100644
index 0000000..a6cbb2f
--- /dev/null
+++ b/test/example/spec/test_compact_spec.py
@@ -0,0 +1,34 @@
+import pytest
+import torch
+
+import torcharc
+
+B = 4 # batch size
+
+
+@pytest.mark.parametrize(
+ "spec_file, input_shape, output_shape",
+ [
+ ("mlp.yaml", (B, 128), (B, 16)),
+ ("mlp_classifier.yaml", (B, 128), (B, 10)),
+ ("conv.yaml", (B, 3, 32, 32), (B, 64, 26, 26)),
+ ("conv_classifier.yaml", (B, 3, 32, 32), (B, 10)),
+ ],
+)
+def test_model(spec_file, input_shape, output_shape):
+ # Build the model using torcharc
+ model = torcharc.build(torcharc.SPEC_DIR / "compact" / spec_file)
+ assert isinstance(model, torch.nn.Module)
+
+ # Run the model and check the output shape
+ x = torch.randn(*input_shape)
+ y = model(x)
+ assert y.shape == output_shape
+
+ # Test compatibility with compile, script and trace
+ compiled_model = torch.compile(model)
+ assert compiled_model(x).shape == y.shape
+ scripted_model = torch.jit.script(model)
+ assert scripted_model(x).shape == y.shape
+ traced_model = torch.jit.trace(model, (x))
+ assert traced_model(x).shape == y.shape
diff --git a/test/validator/test_modules.py b/test/validator/test_modules.py
index fa077c1..75ca88c 100644
--- a/test/validator/test_modules.py
+++ b/test/validator/test_modules.py
@@ -1,7 +1,7 @@
import pytest
import torch
-from torcharc.validator.modules import ModuleSpec, NNSpec, SequentialSpec
+from torcharc.validator.modules import CompactSpec, ModuleSpec, NNSpec, SequentialSpec
@pytest.mark.parametrize(
@@ -61,6 +61,61 @@ def test_invalid_sequential_spec(spec_dict):
SequentialSpec(**spec_dict).build()
+@pytest.mark.parametrize(
+ "spec_dict",
+ [
+ {
+ "compact": {
+ "layer": {
+ "type": "LazyLinear",
+ "keys": ["out_features"],
+ "args": [64, 64, 32, 16],
+ },
+ "postlayer": [{"ReLU": {}}],
+ }
+ },
+ {
+ "compact": {
+ "prelayer": [{"LazyBatchNorm2d": {}}],
+ "layer": {
+ "type": "LazyConv2d",
+ "keys": ["out_channels", "kernel_size"],
+ "args": [[16, 2], [32, 3], [64, 4]],
+ },
+ "postlayer": [{"ReLU": {}}, {"Dropout": {"p": 0.1}}],
+ }
+ },
+ ],
+)
+def test_compact_spec(spec_dict):
+ module = CompactSpec(**spec_dict).build()
+ assert isinstance(module, torch.nn.Module)
+
+
+@pytest.mark.parametrize(
+ "spec_dict",
+ [
+ # multi-key
+ {
+ "compact": {
+ "layer": {
+ "type": "LazyLinear",
+ "keys": ["out_features"],
+ "args": [64, 64, 32, 16],
+ },
+ "postlayer": [{"ReLU": {}}],
+ },
+ "ReLU": {},
+ },
+ # non-compact
+ {"LazyLinear": {"out_features": 64}},
+ ],
+)
+def test_invalid_compact_spec(spec_dict):
+ with pytest.raises(Exception):
+ CompactSpec(**spec_dict).build()
+
+
@pytest.mark.parametrize(
"spec_dict",
[
diff --git a/torcharc/example/spec/compact/conv.yaml b/torcharc/example/spec/compact/conv.yaml
new file mode 100644
index 0000000..fdfa93f
--- /dev/null
+++ b/torcharc/example/spec/compact/conv.yaml
@@ -0,0 +1,46 @@
+# modules:
+# conv:
+# Sequential:
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 16
+# kernel_size: 2
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 32
+# kernel_size: 3
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 64
+# kernel_size: 4
+# - ReLU:
+# - Dropout:
+# p: 0.1
+
+# the above is equivalent to the compact spec below
+
+modules:
+ conv:
+ compact:
+ prelayer:
+ - LazyBatchNorm2d:
+ layer:
+ type: LazyConv2d
+ keys: [out_channels, kernel_size]
+ args: [[16, 2], [32, 3], [64, 4]]
+ postlayer:
+ - ReLU:
+ - Dropout:
+ p: 0.1
+
+graph:
+ input: image
+ modules:
+ conv: [image]
+ output: conv
diff --git a/torcharc/example/spec/compact/conv_classifier.yaml b/torcharc/example/spec/compact/conv_classifier.yaml
new file mode 100644
index 0000000..c07dacd
--- /dev/null
+++ b/torcharc/example/spec/compact/conv_classifier.yaml
@@ -0,0 +1,57 @@
+# modules:
+# conv:
+# Sequential:
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 16
+# kernel_size: 2
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 32
+# kernel_size: 3
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# - LazyBatchNorm2d:
+# - LazyConv2d:
+# out_channels: 64
+# kernel_size: 4
+# - ReLU:
+# - Dropout:
+# p: 0.1
+# classifier:
+# Sequential:
+# - Flatten:
+# - LazyLinear:
+# out_features: 10
+
+# the above is equivalent to the compact spec below
+
+modules:
+ conv:
+ compact:
+ prelayer:
+ - LazyBatchNorm2d:
+ layer:
+ type: LazyConv2d
+ keys: [out_channels, kernel_size]
+ args: [[16, 2], [32, 3], [64, 4]]
+ postlayer:
+ - ReLU:
+ - Dropout:
+ p: 0.1
+ classifier:
+ Sequential:
+ - Flatten:
+ - LazyLinear:
+ out_features: 10
+
+graph:
+ input: image
+ modules:
+ conv: [image]
+ classifier: [conv]
+ output: classifier
diff --git a/torcharc/example/spec/compact/mlp.yaml b/torcharc/example/spec/compact/mlp.yaml
new file mode 100644
index 0000000..2298740
--- /dev/null
+++ b/torcharc/example/spec/compact/mlp.yaml
@@ -0,0 +1,33 @@
+# modules:
+# mlp:
+# Sequential:
+# - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 32
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 16
+# - ReLU:
+
+# the above is equivalent to the compact spec below
+
+modules:
+ mlp:
+ compact:
+ layer:
+ type: LazyLinear
+ keys: [out_features]
+ args: [64, 64, 32, 16]
+ postlayer:
+ - ReLU:
+
+graph:
+ input: x
+ modules:
+ mlp: [x]
+ output: mlp
diff --git a/torcharc/example/spec/compact/mlp_classifier.yaml b/torcharc/example/spec/compact/mlp_classifier.yaml
new file mode 100644
index 0000000..273d348
--- /dev/null
+++ b/torcharc/example/spec/compact/mlp_classifier.yaml
@@ -0,0 +1,41 @@
+# modules:
+# mlp:
+# Sequential:
+# - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 64
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 32
+#       - ReLU:
+#       - LazyLinear:
+#           out_features: 16
+# - ReLU:
+# classifier:
+# LazyLinear:
+# out_features: 10
+
+# the above is equivalent to the compact spec below
+
+modules:
+ mlp:
+ compact:
+ layer:
+ type: LazyLinear
+ keys: [out_features]
+ args: [64, 64, 32, 16]
+ postlayer:
+ - ReLU:
+ classifier:
+ LazyLinear:
+ out_features: 10
+
+graph:
+ input: x
+ modules:
+    mlp: [x]
+    classifier: [mlp]
+ output: classifier
diff --git a/torcharc/validator/graph.py b/torcharc/validator/graph.py
index 196e8d9..cf405ba 100644
--- a/torcharc/validator/graph.py
+++ b/torcharc/validator/graph.py
@@ -9,7 +9,7 @@ class GraphSpec(BaseModel):
"""
input: str | list[str] = Field(
- description=" input placeholder nodes of fx.Graph",
+ description="Input placeholder nodes of fx.Graph",
examples=["x", ["x_0, x_1"]],
)
modules: dict[str, list[str | list[str]] | dict[str, str]] = Field(
diff --git a/torcharc/validator/modules.py b/torcharc/validator/modules.py
index df0d60a..95e1fa1 100644
--- a/torcharc/validator/modules.py
+++ b/torcharc/validator/modules.py
@@ -1,5 +1,5 @@
# Pydantic validation for modules spec
-from pydantic import Field, RootModel, field_validator
+from pydantic import BaseModel, Field, RootModel, field_validator
from torch import nn
@@ -91,9 +91,156 @@ def build(self) -> nn.Sequential:
return nn.Sequential(*[nn_spec.build() for nn_spec in nn_specs])
+class CompactLayerSpec(BaseModel):
+ """
+    Spec to compactly specify multiple layers sharing common kwargs keys, with one list of kwargs values per layer.
+    The following
+
+    type: <nn class name>
+    keys: [<kwargs keys>]
+    args: [<kwargs values for each layer>]
+
+    expands into
+
+    [{<type>: {<key>: <value>, ...}}, ..., {<type>: {<key>: <value>, ...}}]
+
+    E.g. type: LazyLinear, keys: [out_features], args: [64, 32] expands into
+    [{LazyLinear: {out_features: 64}}, {LazyLinear: {out_features: 32}}]
+ """
+
+ type: str = Field(
+ description="Name of a torch.nn class", examples=["LazyLinear", "LazyConv2d"]
+ )
+ keys: list[str] = Field(
+ description="The class' kwargs keys, to be expanded and zipped with args.",
+ examples=[["out_features"], ["out_channels", "kernel_size"]],
+ )
+ args: list[list] = Field(
+ description="The class' kwargs values for each layer. For convenience this will be casted to list of lists to allow a list of singleton values.",
+ examples=[[64, 64, 32, 16], [[16, 2], [32, 3], [64, 4]]],
+ )
+
+ @field_validator("args", mode="before")
+ def cast_list_of_list(value: list) -> list[list]:
+ return [v if isinstance(v, list) else [v] for v in value]
+
+
+class CompactValueSpec(BaseModel):
+ """Intermediate spec defining the values of CompactSpec"""
+
+ prelayer: list[NNSpec] | None = Field(
+ None,
+ description="The optional list of NNSpec layers that repeat before the mid layer.",
+ )
+ layer: CompactLayerSpec = Field(
+ description="The mid layer to be expanded, wrapped between prelayer and postlayer, and repeated."
+ )
+ postlayer: list[NNSpec] | None = Field(
+ None,
+ description="The optional list of NNSpec layers that repeat after the mid layer.",
+ )
+
+
+class CompactSpec(RootModel):
+ """
+    Higher level compact spec that expands into a Sequential spec. This is useful for architecture search.
+ Compact spec has the format:
+
+ compact:
+ prelayer: [NNSpec]
+ layer:
+            type: <nn class name>
+            keys: [<kwargs keys>]
+            args: [<kwargs values for each layer>]
+ postlayer: [NNSpec]
+
+ E.g.
+ compact:
+ layer:
+ type: LazyLinear
+ keys: [out_features]
+ args: [64, 64, 32, 16]
+ postlayer:
+ - ReLU:
+
+ E.g.
+ compact:
+ prelayer:
+ - LazyBatchNorm2d:
+ layer:
+ type: LazyConv2d
+ keys: [out_channels, kernel_size]
+ args: [[16, 2], [32, 3], [64, 4]]
+ postlayer:
+ - ReLU:
+ - Dropout:
+ p: 0.1
+ """
+
+ root: dict[str, CompactValueSpec] = Field(
+ description="Higher level compact spec that expands into Sequential spec.",
+ examples=[
+ {
+ "compact": {
+ "layer": {
+ "type": "LazyLinear",
+ "keys": ["out_features"],
+ "args": [64, 64, 32, 16],
+ },
+ "postlayer": [{"ReLU": {}}],
+ }
+ },
+ {
+ "compact": {
+ "prelayer": [{"LazyBatchNorm2d": {}}],
+ "layer": {
+ "type": "LazyConv2d",
+ "keys": ["out_channels", "kernel_size"],
+ "args": [[16, 2], [32, 3], [64, 4]],
+ },
+ "postlayer": [{"ReLU": {}}, {"Dropout": {"p": 0.1}}],
+ }
+ },
+ ],
+ )
+
+ @field_validator("root", mode="before")
+ def is_single_key_dict(value: dict) -> dict:
+ return NNSpec.is_single_key_dict(value)
+
+ @field_validator("root", mode="before")
+ def key_is_compact(value: dict) -> dict:
+ assert (
+ next(iter(value)) == "compact"
+ ), "Key must be 'compact' if using CompactSpec."
+ return value
+
+ def __expand_spec(self, compact_layer: dict) -> list[dict]:
+ class_name = compact_layer["type"]
+ keys = compact_layer["keys"]
+ args = compact_layer["args"]
+ nn_specs = []
+ for vals in args:
+ nn_spec = {class_name: dict(zip(keys, vals))}
+ nn_specs.append(nn_spec)
+ return nn_specs
+
+ def expand_to_sequential_spec(self) -> SequentialSpec:
+ compact_spec = next(iter(self.root.values())).model_dump()
+ prelayer = compact_spec.get("prelayer")
+ postlayer = compact_spec.get("postlayer")
+ nn_specs = []
+ for midlayer in self.__expand_spec(compact_spec["layer"]):
+            if prelayer:
+                nn_specs.extend(prelayer)
+            nn_specs.append(midlayer)
+            if postlayer:
+                nn_specs.extend(postlayer)
+ return SequentialSpec(**{"Sequential": nn_specs})
+
+ def build(self) -> nn.Sequential:
+ """Build nn.Sequential from compact spec expanded into sequential spec"""
+ return self.expand_to_sequential_spec().build()
+
+
class ModuleSpec(RootModel):
"""
- Higher level module spec where value can be either NNSpec or SequentialSpec.
+ Higher level module spec where value can be NNSpec, SequentialSpec, or CompactSpec.
E.g. (plain NN)
Linear:
in_features: 10
@@ -108,13 +255,33 @@ class ModuleSpec(RootModel):
- Linear:
in_features: 64
out_features: 10
+
+ E.g. (compact)
+ compact:
+ layer:
+ type: LazyLinear
+ keys: [out_features]
+ args: [64, 64, 32, 16]
+ postlayer:
+ - ReLU:
"""
- root: SequentialSpec | NNSpec = Field(
- description="Higher level module spec where value can be either NNSpec or SequentialSpec.",
+ root: NNSpec | SequentialSpec | CompactSpec = Field(
+ description="Higher level module spec where value can be NNSpec, SequentialSpec, or CompactSpec.",
examples=[
{"Linear": {"in_features": 128, "out_features": 64}},
{"Sequential": [{"Linear": {"in_features": 128, "out_features": 64}}]},
+ {
+ "compact": {
+ "prelayer": [{"LazyBatchNorm2d": {}}],
+ "layer": {
+ "type": "LazyConv2d",
+ "keys": ["out_channels", "kernel_size"],
+ "args": [[16, 2], [32, 3], [64, 4]],
+ },
+ "postlayer": [{"ReLU": {}}, {"Dropout": {"p": 0.1}}],
+ }
+ },
],
)
diff --git a/uv.lock b/uv.lock
index 028b69e..8446497 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1354,7 +1354,7 @@ wheels = [
[[package]]
name = "torcharc"
-version = "2.0.0"
+version = "2.1.0"
source = { editable = "." }
dependencies = [
{ name = "pydantic" },