Skip to content

Commit

Permalink
Merge pull request #87 from BrunoRosendo/fix/time-code-duplication
Browse files Browse the repository at this point in the history
Extracted time measurement to a function
  • Loading branch information
BrunoRosendo authored May 14, 2024
2 parents 2dea482 + 6bec399 commit cd83459
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 41 deletions.
8 changes: 5 additions & 3 deletions src/model/adapter/DWaveAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,17 @@ def expression_to_tuples(self, expr: Expr) -> list[tuple]:
else:
raise ValueError("Invalid expression type.")

def _linear_expression_to_tuples(self, expr: LinearExpr) -> list[tuple[str, int]]:
@staticmethod
def _linear_expression_to_tuples(expr: LinearExpr) -> list[tuple[str, int]]:
"""
Convert a linear expression to a list of tuples.
The first element of the tuple is the variable name and the second is the coefficient.
"""

return [(v.name, coefficient) for (v, coefficient) in expr.iter_terms()]

def _quadratic_expression_to_tuples(self, expr: QuadExpr) -> list[tuple]:
@staticmethod
def _quadratic_expression_to_tuples(expr: QuadExpr) -> list[tuple]:
"""
Convert a quadratic expression to a list of tuples, including the linear part.
The first two elements of the tuple are the variable names and the third is the coefficient.
Expand All @@ -124,6 +126,6 @@ def _quadratic_expression_to_tuples(self, expr: QuadExpr) -> list[tuple]:
(pair.first.name, pair.second.name, coefficient)
for (pair, coefficient) in expr.iter_quads()
]
linear_part = self._linear_expression_to_tuples(expr.linear_part)
linear_part = DWaveAdapter._linear_expression_to_tuples(expr.linear_part)

return quad_part + linear_part
17 changes: 8 additions & 9 deletions src/solver/ClassicSolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import time
from typing import Any

from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
if self.use_rpp:
self.add_dummy_depot()

def _solve_cvrp(self) -> any:
def _solve_cvrp(self) -> Any:
"""
Solve the CVRP using Google's OR Tools.
"""
Expand All @@ -71,25 +71,24 @@ def _solve_cvrp(self) -> any:
self.set_pickup_and_deliveries()

search_parameters = self.get_search_parameters()

start_time = time.perf_counter_ns()
or_solution = self.routing.SolveWithParameters(search_parameters)
self.run_time = (time.perf_counter_ns() - start_time) // 1000
or_solution, self.run_time = self.measure_time(
self.routing.SolveWithParameters, search_parameters
)

if or_solution is None:
raise Exception("The solution is infeasible, aborting!")

return or_solution

def distance_callback(self, from_index: any, to_index: any) -> int:
def distance_callback(self, from_index: Any, to_index: Any) -> int:
"""Returns the distance between the two nodes."""

# Convert from index to distance matrix NodeIndex.
from_node = self.manager.IndexToNode(from_index)
to_node = self.manager.IndexToNode(to_index)
return self.distance_matrix[from_node][to_node]

def demand_callback(self, from_index: any) -> int:
def demand_callback(self, from_index: Any) -> int:
"""Returns the demand of the node."""

# Convert from index to demands NodeIndex.
Expand Down Expand Up @@ -202,7 +201,7 @@ def remove_unused_locations(
for trip in trips
]

def _convert_solution(self, result: any, local_run_time: float) -> VRPSolution:
def _convert_solution(self, result: Any, local_run_time: float) -> VRPSolution:
"""Converts OR-Tools result to CVRP solution."""

routes = []
Expand Down
29 changes: 21 additions & 8 deletions src/solver/VRPSolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from abc import ABC, abstractmethod
from typing import Callable, Any

from src.model.VRP import VRP
from src.model.VRPSolution import VRPSolution
Expand Down Expand Up @@ -75,14 +76,14 @@ def compute_distance(self) -> list[list[int]]:
return distance_matrix

@abstractmethod
def _solve_cvrp(self) -> any:
def _solve_cvrp(self) -> Any:
"""
Solve the CVRP with a specific solver.
"""
pass

@abstractmethod
def _convert_solution(self, result: any, local_run_time: float) -> VRPSolution:
def _convert_solution(self, result: Any, local_run_time: float) -> VRPSolution:
"""
Convert the result from the solver to a CVRP solution.
"""
Expand All @@ -93,12 +94,7 @@ def solve(self) -> VRPSolution:
Solve the CVRP.
"""

start_time = time.perf_counter_ns()
result = self._solve_cvrp()
execution_time = (
time.perf_counter_ns() - start_time
) // 1000 # Convert to microseconds

result, execution_time = self.measure_time(self._solve_cvrp)
return self._convert_solution(result, execution_time)

@abstractmethod
Expand All @@ -107,3 +103,20 @@ def get_model(self) -> VRP:
Get a VRP instance of the model.
"""
pass

@staticmethod
def measure_time(
fun: Callable[..., Any], *args: Any, **kwargs: Any
) -> tuple[Any, int]:
"""
Measure the execution time of a function.
Returns the result and the execution time in microseconds.
"""

start_time = time.perf_counter_ns()
result = fun(*args, **kwargs)
execution_time = (
time.perf_counter_ns() - start_time
) // 1000 # Convert to microseconds

return result, execution_time
18 changes: 7 additions & 11 deletions src/solver/qubo/CplexSolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import time
from typing import Any

from docplex.util.status import JobSolveStatus
from numpy import ndarray
Expand Down Expand Up @@ -101,7 +101,8 @@ def _solve_cvrp(self) -> OptimizationResult:
self.check_feasibility(result)
return result

def convert_quadratic_program(self, qp: QuadraticProgram) -> QuadraticProgram:
@staticmethod
def convert_quadratic_program(qp: QuadraticProgram) -> QuadraticProgram:
"""
Convert the quadratic program to a canonic formulation, using the Qiskit converters.
"""
Expand All @@ -126,10 +127,7 @@ def solve_classic(self, qp: QuadraticProgram) -> OptimizationResult:
"""

optimizer = CplexOptimizer(disp=self.track_progress)

start_time = time.perf_counter_ns()
result = optimizer.solve(qp)
self.run_time = (time.perf_counter_ns() - start_time) // 1000
result, self.run_time = self.measure_time(optimizer.solve, qp)

return result

Expand All @@ -151,14 +149,12 @@ def solve_qubo(self, qp: QuadraticProgram) -> OptimizationResult:
else:
optimizer = MinimumEigenOptimizer(qaoa)

start_time = time.perf_counter_ns()
result = optimizer.solve(qp)
self.run_time = (time.perf_counter_ns() - start_time) // 1000

result, self.run_time = self.measure_time(optimizer.solve, qp)
return result

@staticmethod
def qaoa_callback(
self, iter_num: int, ansatz: ndarray, objective: float, metadata: dict[str, any]
iter_num: int, ansatz: ndarray, objective: float, metadata: dict[str, Any]
):
print(f"Iteration {iter_num}: {objective.real} objective")

Expand Down
16 changes: 6 additions & 10 deletions src/solver/qubo/DWaveSolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
from logging import warning

from dimod import (
Expand Down Expand Up @@ -141,7 +140,8 @@ def is_sample_feasible(self, s):

return s.is_feasible

def is_cqm_sampler(self, sampler: Sampler) -> bool:
@staticmethod
def is_cqm_sampler(sampler: Sampler) -> bool:
"""
Check if the sampler is a CQM sampler.
"""
Expand Down Expand Up @@ -178,10 +178,9 @@ def sample_cqm(self) -> SampleSet:
Sample the CQM using the selected sampler and time limit.
"""
kwargs = {"time_limit": self.time_limit} if self.time_limit else {}

start_time = time.perf_counter_ns()
result = self.sampler.sample_cqm(self.cqm, **kwargs)
self.run_time = (time.perf_counter_ns() - start_time) // 1000
result, self.run_time = self.measure_time(
self.sampler.sample_cqm, self.cqm, **kwargs
)

return result

Expand All @@ -197,8 +196,5 @@ def sample_bqm(
if self.time_limit:
kwargs["time_limit"] = self.time_limit

start_time = time.perf_counter_ns()
result = sampler.sample(bqm, **kwargs)
self.run_time = (time.perf_counter_ns() - start_time) // 1000

result, self.run_time = self.measure_time(sampler.sample, bqm, **kwargs)
return result

0 comments on commit cd83459

Please sign in to comment.