Skip to content

Commit

Permalink
Add a thumbnail (image shows up now!)
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Dec 12, 2022
1 parent ff65c49 commit 403c191
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions examples/tutorials/plot_5_warm_starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,16 @@ class Net(nn.Sequential):

def __init__(self, n_classes: int = 10):
super().__init__(
nn.LazyConv2d(
32, 3, 1
), # NOTE: `in_channels` is determined in the first forward pass
# NOTE: `in_channels` is determined in the first forward pass
nn.LazyConv2d(32, 3, 1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.25),
nn.Flatten(),
nn.LazyLinear(
128
), # NOTE: `in_features` is determined in the first forward pass
# NOTE: `in_features` is determined in the first forward pass
nn.LazyLinear(128),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(128, n_classes),
Expand Down Expand Up @@ -174,12 +172,10 @@ def test_epoch(model: Net, device: torch.device, test_loader: DataLoader) -> flo
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction="sum").item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= num_batches
Expand Down Expand Up @@ -379,4 +375,7 @@ def main(**kwargs):
},
)
fig.show()
fig.write_image("../../docs/src/_static/warm_start_thumbnail.png")
fig

# sphinx_gallery_thumbnail_path = '_static/warm_start_thumbnail.png'

0 comments on commit 403c191

Please sign in to comment.