Skip to content

Commit

Permalink
Fix PytorchModel when last layer doesn't support out_features (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbybp authored Jan 17, 2025
1 parent 7496eec commit 2bbff7a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
11 changes: 8 additions & 3 deletions ext/MathOptAIPythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,14 @@ function MathOptAI.GrayBox(
torch_model = torch_model.to(device)
J = torch.func.jacrev(torch_model)
H = torch.func.hessian(torch_model)
# TODO(odow): I'm not sure if there is a better way to get the output
# dimension of a torch model object?
output_size(::Any) = PythonCall.pyconvert(Int, torch_model[-1].out_features)
function output_size(x::Vector)
# Get the output size by passing a zero vector through the torch model.
# We do this instead of `torch_model[-1].out_features` as the last layer
# may not support out_features.
z = torch.zeros(length(x))
y = torch_model(z)
return PythonCall.pyconvert(Int, PythonCall.pybuiltins.len(y))
end
function callback(x)
py_x = torch.tensor(collect(x); device = device)
py_value = torch_model(py_x).detach().cpu().numpy()
Expand Down
49 changes: 49 additions & 0 deletions test/test_PythonCall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,55 @@ function test_model_Tanh_vector_GrayBox_hessian()
return
end

function test_model_Sigmoid_last_layer_GrayBox()
dir = mktempdir()
filename = joinpath(dir, "model_Sigmoid_last_layer_GrayBox.pt")
PythonCall.pyexec(
"""
import torch
model = torch.nn.Sequential(
torch.nn.Linear(3, 16),
torch.nn.Sigmoid(),
)
torch.save(model, filename)
""",
@__MODULE__,
(; filename = filename),
)
# Full-space
model = Model(Ipopt.Optimizer)
set_silent(model)
@variable(model, x[i in 1:3] == i)
ml_model = MathOptAI.PytorchModel(filename)
y, formulation =
MathOptAI.add_predictor(model, ml_model, x; gray_box = true)
@test num_variables(model) == 19
@test num_constraints(model; count_variable_in_set_constraints = true) == 19
optimize!(model)
@test is_solved_and_feasible(model)
@test (_evaluate_model(filename, value.(x)), value.(y); atol = 1e-5)
# Reduced-space
model = Model(Ipopt.Optimizer)
set_silent(model)
@variable(model, x[i in 1:3] == i)
ml_model = MathOptAI.PytorchModel(filename)
y, formulation = MathOptAI.add_predictor(
model,
ml_model,
x;
gray_box = true,
reduced_space = true,
)
@test num_variables(model) == 3
@test num_constraints(model; count_variable_in_set_constraints = true) == 3
optimize!(model)
@test is_solved_and_feasible(model)
@test (_evaluate_model(filename, value.(x)), value.(y); atol = 1e-5)
return
end

end # module

TestPythonCallExt.runtests()

0 comments on commit 2bbff7a

Please sign in to comment.