Skip to content

Commit

Permalink
Overhaul action space and environment
Browse files Browse the repository at this point in the history
- Create a valid schedule in the beginning
- Fix number of actions
- Restrict the possibilities of cutting
  • Loading branch information
Gistbatch committed Mar 29, 2024
1 parent 5eed017 commit 721b820
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 138 deletions.
27 changes: 3 additions & 24 deletions src/scheduling/learning/action_space.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,18 @@
"""Custom action space for the scheduling environment."""

from gymnasium.spaces import Discrete, Dict, MultiDiscrete
from qiskit import QuantumCircuit

from src.scheduling.common import Schedule


class ActionSpace(Dict):
    """Composite action space for the scheduling environment.

    The space is a dictionary with two entries:

    * ``"action"``: the kind of action, encoded as
      0: cut, 1: move, 2: swap, 3: terminate (the former "combine"
      action was removed).
    * ``"params"``: three integers; the first two range over
      ``n_circuits`` and the third over ``n_buckets``.

    Both sizes are fixed at construction time, so the flattened space
    keeps the same shape for the whole lifetime of the environment.
    """

    def __init__(self, n_circuits: int, n_buckets: int) -> None:
        """Create the action space.

        Args:
            n_circuits: Number of values allowed for each of the two
                circuit parameters.
            n_buckets: Number of values allowed for the bucket parameter.
        """
        super().__init__(
            {
                "action": Discrete(4),
                "params": MultiDiscrete([n_circuits, n_circuits, n_buckets]),
            }
        )
203 changes: 89 additions & 114 deletions src/scheduling/learning/environment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from enum import Enum
from typing import Any

Expand All @@ -10,14 +12,15 @@
from src.common import UserCircuit
from src.provider import Accelerator
from src.scheduling.common import (
Machine,
Schedule,
Bucket,
CircuitProxy,
evaluate_solution,
convert_circuits,
cut_proxies,
)
from src.scheduling.heuristics.initialize import _better_partitioning, _bin_schedule

from .action_space import ActionSpace


Expand Down Expand Up @@ -55,36 +58,14 @@ def __init__(
self.accelerators = accelerators
self.circuits: list[QuantumCircuit | UserCircuit] = circuits
self.noise_weight = noise_weight
self._schedule = Schedule([], np.inf)
self.jobs = []
self._n_jobs = self._init_schedule()

self._schedule = Schedule(
[
(
Machine(
accelerator.qubits,
str(accelerator.uuid),
[],
len(accelerator.queue),
)
if accelerator is not None
else Machine(0, "None", [], 0)
)
for accelerator in accelerators
],
np.inf,
) # Initialize with empty schedules for each device
for circuit in self.circuits:
[proxy] = convert_circuits([circuit], accelerators)
choice = next(
(
idx
for idx, machine in enumerate(self._schedule.machines)
if machine.id == proxy.preselection
),
np.random.choice(len(self._schedule.machines)),
)
self._schedule.machines[choice].buckets.append(Bucket([proxy]))
# Define the action and observation spaces
self._action_space = ActionSpace(circuits, self._schedule)
self._action_space = ActionSpace(
len(self.circuits), self._n_jobs * len(self._schedule.machines)
)
self.action_space = spaces.flatten_space(self._action_space)
self.observation_space = spaces.Dict(
{
Expand All @@ -98,27 +79,28 @@ def step(self, action: np.ndarray) -> tuple[Any, float, bool, bool, dict[str, An
truncated = False
penalty = 1.0 # multiplcative penalty for invalid cuts
# Perform the specified action and update the schedule
logging.info("Binary action %s", len(action))
dict_action = self._unflatten_action(action)
logging.info("%d steps.", self.steps)
logging.info("Action: %s", dict_action)
match dict_action["action"]:
case Actions.CUT_CIRCUIT:
case Actions.CUT_CIRCUIT.value:
logging.info("Action: %s", Actions.CUT_CIRCUIT)
penalty = self._cut(*dict_action["params"])
# case Actions.COMBINE_CIRCUIT:
# self._combine(*dict_action["params"])
case Actions.MOVE_CIRCUIT:
self._move(*dict_action["params"])
case Actions.SWAP_CIRCUITS:
self._swap(*dict_action["params"])
case Actions.TERMINATE:
case Actions.MOVE_CIRCUIT.value:
logging.info("Action: %s", Actions.MOVE_CIRCUIT)
penalty = self._move(*dict_action["params"])
case Actions.SWAP_CIRCUITS.value:
logging.info("Action: %s", Actions.SWAP_CIRCUITS)
penalty = self._swap(*dict_action["params"])
case Actions.TERMINATE.value:
logging.info("Action: %s", Actions.TERMINATE)
terminated = True

# Calculate the completion time and noise based on the updated schedule
# Return the new schedule, completion time, noise, and whether the task is done
self._action_space.update_actions(self._schedule)
if self._schedule.is_feasible():
self._action_space.enable_terminate()
else:
self._action_space.disable_terminate()
self.action_space = spaces.flatten_space(self._action_space)
# self._action_space.update_actions(self._schedule)

if self.steps >= self.max_steps:
truncated = True
self.steps += 1
Expand All @@ -128,9 +110,10 @@ def step(self, action: np.ndarray) -> tuple[Any, float, bool, bool, dict[str, An
self._calculate_reward(float(obs["makespan"]), float(obs["noise"]))
* penalty
)
logging.info("Feasible: %s", self._schedule.is_feasible())
if terminated and not self._schedule.is_feasible():
reward = -np.inf

logging.info("Reward: %s", reward)
return (
obs,
reward,
Expand Down Expand Up @@ -164,33 +147,12 @@ def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
super().reset(seed=seed)
schedule = Schedule(
[
(
Machine(
accelerator.qubits,
str(accelerator.uuid),
[],
len(accelerator.queue),
)
if accelerator is not None
else Machine(0, "None", [], 0)
)
for accelerator in self.accelerators
],
np.inf,
)
for circuit in self.circuits:
[proxy] = convert_circuits([circuit], self.accelerators)
choice = next(
(
idx
for idx, machine in enumerate(self._schedule.machines)
if machine.id == proxy.preselection
),
np.random.choice(len(schedule.machines)),
)
schedule.machines[choice].buckets.append(Bucket([proxy]))
self._schedule = Schedule(_bin_schedule(self.jobs, self.accelerators), np.inf)
for machine in self._schedule.machines:
# pad each machine with empty buckets until it holds self._n_jobs buckets
machine.buckets += [
Bucket(jobs=[]) for _ in range(self._n_jobs - len(machine.buckets))
]

obs = self._get_observation()
info = self._get_info()
Expand Down Expand Up @@ -220,6 +182,7 @@ def _get_info(self) -> dict[str, Any]:

def _calculate_reward(self, completion_time: float, expected_noise: float) -> float:
# Calculate the reward based on the completion time and expected noise
# TODO: possible 1/this to have positive reward
return -completion_time + (expected_noise * self.noise_weight)

def _cut(self, index: int, cut_index: int, *_) -> float:
Expand All @@ -230,68 +193,82 @@ def _cut(self, index: int, cut_index: int, *_) -> float:
job = self._schedule.machines[machine_id].buckets[bucket_id].jobs.pop(job_id)

if job.num_qubits < 3 or job.num_qubits - cut_index < 2 or cut_index < 2:
self._schedule.machines[machine_id].buckets[bucket_id].jobs.insert(
job_id, job
)
logging.info("Invalid cut")
return self.penalty
logging.info(
"Cutting circuit with %s qubits at index %d", job.num_qubits, cut_index
)
new_jobs = cut_proxies(
[job],
[[0] * cut_index + [1] * (job.circuit.num_qubits - cut_index)],
[[0] * cut_index + [1] * (job.num_qubits - cut_index)],
)
self._schedule.machines[machine_id].buckets[bucket_id].jobs += new_jobs
return 1

# def _combine(self, index1: int, index2: int, *_) -> None:
# # Combine two circuits into a single larger circuit
# # remove the two circuits from the machine and add the larger circuit
# # adds to the bucket of the first circuit
# (machine_id1, bucket_id1, job_id1) = _find_job(self._schedule, index1)
# (machine_id2, bucket_id2, job_id2) = _find_job(self._schedule, index2)
# job_1 = (
# self._schedule.machines[machine_id1].buckets[bucket_id1].jobs.pop(job_id1)
# )
# job_2 = (
# self._schedule.machines[machine_id2].buckets[bucket_id2].jobs.pop(job_id2)
# )
# # TODO keep track of origins / indices properly
# combined_circuit = CircuitProxy(
# origin=job_1.origin,
# processing_time=(
# job_1.processing_time
# if job_1.processing_time > job_2.processing_time
# else job_2.processing_time
# ),
# num_qubits=job_1.num_qubits + job_2.num_qubits,
# uuid=job_1.uuid,
# indices=job_1.indices + job_2.indices,
# n_shots=job_1.n_shots if job_1.n_shots > job_2.n_shots else job_2.n_shots,
# )

# self._schedule.machines[machine_id1].buckets[bucket_id1].jobs.append(
# combined_circuit
# )

def _move(self, index1: int, _: int, move_to: int) -> float:
    """Move a job into another bucket.

    Args:
        index1: Flat index of the job, resolved through ``_find_job``.
        _: Unused; present so every action receives the same three
            parameters.
        move_to: Flat index of the destination bucket, resolved through
            ``_find_bucket``.

    Returns:
        1 if the job was moved, ``self.penalty`` if one of the resolved
        positions does not exist in the schedule.
    """
    machine_id, bucket_id, job_id = _find_job(self._schedule, index1)
    new_machine_id, new_bucket_id = _find_bucket(self._schedule, move_to)
    machines = self._schedule.machines
    try:
        # Resolve the destination before removing the job, so a failing
        # lookup can never drop the job from the schedule.
        target_jobs = machines[new_machine_id].buckets[new_bucket_id].jobs
        job = machines[machine_id].buckets[bucket_id].jobs.pop(job_id)
    except IndexError:
        return self.penalty
    target_jobs.append(job)
    return 1

def _swap(self, index1: int, index2: int, *_) -> float:
    """Exchange the positions of two jobs in the schedule.

    Args:
        index1: Flat index of the first job, resolved through ``_find_job``.
        index2: Flat index of the second job, resolved through ``_find_job``.
        *_: Ignored extra action parameters.

    Returns:
        1 if the jobs were swapped, ``self.penalty`` if one of the
        resolved positions does not exist in the schedule.
    """
    machine_id1, bucket_id1, job_id1 = _find_job(self._schedule, index1)
    machine_id2, bucket_id2, job_id2 = _find_job(self._schedule, index2)
    machines = self._schedule.machines
    try:
        jobs1 = machines[machine_id1].buckets[bucket_id1].jobs
        jobs2 = machines[machine_id2].buckets[bucket_id2].jobs
        # Both reads happen before either write, so a bad index leaves the
        # schedule untouched.
        jobs1[job_id1], jobs2[job_id2] = jobs2[job_id2], jobs1[job_id1]
    except IndexError:
        return self.penalty
    return 1

def _init_schedule(self) -> int:
    """Build the initial schedule and report how many jobs it contains.

    The circuits are partitioned with ``_better_partitioning``, converted
    to job proxies stored in ``self.jobs``, and placed with
    ``_bin_schedule`` into ``self._schedule``. Every machine is then
    padded with empty buckets until it holds at least as many buckets as
    there are jobs.

    Returns:
        The number of jobs in ``self.jobs``.
    """
    quantum_circuits = []
    for circuit in self.circuits:
        # Anything that is not a QuantumCircuit carries it in `.circuit`.
        if isinstance(circuit, QuantumCircuit):
            quantum_circuits.append(circuit)
        else:
            quantum_circuits.append(circuit.circuit)

    partitions = _better_partitioning(quantum_circuits, self.accelerators)
    self.jobs: list[CircuitProxy] = convert_circuits(
        self.circuits, self.accelerators, partitions
    )
    n_jobs = len(self.jobs)

    machines = _bin_schedule(self.jobs, self.accelerators)
    self._schedule = Schedule(machines, np.inf)
    for machine in self._schedule.machines:
        # pad with empty buckets
        missing = n_jobs - len(machine.buckets)
        machine.buckets += [Bucket(jobs=[]) for _ in range(missing)]

    return n_jobs


def _find_job(schedule: Schedule, index: int) -> tuple[int, int, int]:
count = 0
for machine_id, machine in enumerate(schedule.machines):
for bucket_id, bucket in enumerate(machine.buckets):
if len(bucket.jobs) == 0:
count += 1
continue
for job_id, _ in enumerate(bucket.jobs):
if count == index:
return machine_id, bucket_id, job_id
Expand All @@ -302,10 +279,8 @@ def _find_job(schedule: Schedule, index: int) -> tuple[int, int, int]:
def _find_bucket(schedule: Schedule, index: int) -> tuple[int, int]:
count = 0
for machine_id, machine in enumerate(schedule.machines):
machine.buckets.append(Bucket([])) # allow to create new bucket
for bucket_id, _ in enumerate(machine.buckets):
if count == index:
return machine_id, bucket_id
count += 1
machine.buckets.pop()
raise ValueError("Index out of range")

0 comments on commit 721b820

Please sign in to comment.