Skip to content

Commit

Permalink
feat: local detuning validation for ahs (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbeCoull authored Apr 16, 2024
1 parent 3915b30 commit 9e872f0
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/braket/analog_hamiltonian_simulator/rydberg/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
MAGNITUDE_PATTERN_VALUE_MIN = 0.0
MAGNITUDE_PATTERN_VALUE_MAX = 1.0

# Maximum net detuning for all atoms
MAX_NET_DETUNING = 2e8


def capabilities_constants() -> CapabilitiesConstants:
return CapabilitiesConstants(
Expand All @@ -60,4 +63,5 @@ def capabilities_constants() -> CapabilitiesConstants:
LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX=LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX,
MAGNITUDE_PATTERN_VALUE_MIN=MAGNITUDE_PATTERN_VALUE_MIN,
MAGNITUDE_PATTERN_VALUE_MAX=MAGNITUDE_PATTERN_VALUE_MAX,
MAX_NET_DETUNING=MAX_NET_DETUNING,
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ class CapabilitiesConstants(BaseModel):

MAGNITUDE_PATTERN_VALUE_MIN: Decimal
MAGNITUDE_PATTERN_VALUE_MAX: Decimal
MAX_NET_DETUNING: Decimal
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@

import warnings
from decimal import Decimal
from typing import List

import numpy as np
from braket.ir.ahs.program_v1 import Program

from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import (
CapabilitiesConstants,
)


def validate_value_range_with_warning(
Expand All @@ -36,3 +44,63 @@ def validate_value_range_with_warning(
f"[{min_value}, {max_value}]. The values should be specified in SI units."
)
break # Only one warning messasge will be issued


def validate_net_detuning_with_warning(
program: Program,
time_points: np.ndarray,
global_detuning_coefs: np.ndarray,
local_detuning_patterns: List,
local_detuning_coefs: np.ndarray,
capabilities: CapabilitiesConstants,
) -> Program:
"""
Validate the given program for the net detuning of all the atoms at all time points
Args:
program (Program): The given program
time_points (np.ndarray): The time points for both global and local detunings
global_detuning_coefs (np.ndarray): The values of global detuning
local_detuning_patterns (List): The pattern of local detuning
local_detuning_coefs (np.ndarray): The values of local detuning
capabilities (CapabilitiesConstants): The capability constants
Returns:
program (Program): The given program
"""

for time_ind, time in enumerate(time_points):

# Get the contributions from all the global detunings
# (there could be multiple global driving fields) at the time point
values_global_detuning = sum(
[detuning_coef[time_ind] for detuning_coef in global_detuning_coefs]
)

for atom_index in range(len(local_detuning_patterns[0])):
# Get the contributions from local detuning at the time point
values_local_detuning = sum(
[
shift_coef[time_ind] * float(detuning_pattern[atom_index])
for detuning_pattern, shift_coef in zip(
local_detuning_patterns, local_detuning_coefs
)
]
)

# The net detuning is the sum of both the global and local detunings
detuning_to_check = np.real(values_local_detuning + values_global_detuning)

# Issue a warning if the absolute value of the net detuning is
# beyond MAX_NET_DETUNING
if abs(detuning_to_check) > capabilities.MAX_NET_DETUNING:
warnings.warn(
f"Atom {atom_index} has net detuning {detuning_to_check} rad/s "
f"at time {time} seconds, which is outside the typical range "
f"[{-capabilities.MAX_NET_DETUNING}, {capabilities.MAX_NET_DETUNING}]."
f"Numerical instabilities may occur during simulation."
)

# Return immediately if there is an atom has net detuning
# exceeding MAX_NET_DETUNING at a time point
return program
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from copy import deepcopy

from braket.ir.ahs.program_v1 import Program
from pydantic.v1 import root_validator

from braket.analog_hamiltonian_simulator.rydberg.rydberg_simulator_helpers import _get_coefs
from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import (
CapabilitiesConstants,
)
from braket.analog_hamiltonian_simulator.rydberg.validators.field_validator_util import (
validate_net_detuning_with_warning,
)


class ProgramValidator(Program):
Expand All @@ -34,7 +40,56 @@ def local_detuning_pattern_has_the_same_length_as_atom_array_sites(cls, values):
pattern_size = len(local_detuning["magnitude"]["pattern"])
if num_sites != pattern_size:
raise ValueError(
f"The length of pattern ({pattern_size}) of shifting field {idx} must equal "
f"The length of pattern ({pattern_size}) of local detuning {idx} must equal "
f"the number of atom array sites ({num_sites})."
)
return values

# If there is local detuning, the net value of detuning for each atom
# should not exceed a max detuning value
@root_validator(pre=True, skip_on_failure=True)
def net_detuning_must_not_exceed_max_net_detuning(cls, values):
capabilities = values["capabilities"] # device_capabilities

# Extract the program and the fields
program = deepcopy(values)
del program["capabilities"]
program = Program.parse_obj(program)
driving_fields = program.hamiltonian.drivingFields
local_detuning = program.hamiltonian.localDetuning

# If no local detuning, we simply return the values
# because there are separate validators to validate
# the global driving fields in the program
if not len(local_detuning):
return values

detuning_times = [
local_detune.magnitude.time_series.times for local_detune in local_detuning
]

# Merge the time points for different shifting terms and detuning term
all_times = set(sum(detuning_times, []))
for driving_field in driving_fields:
all_times.update(driving_field.detuning.time_series.times)
time_points = sorted(all_times)

# Get the time-dependent functions for the detuning and shifts
_, global_detuning_coefs, local_detuning_coefs = _get_coefs(program, time_points)

# Get the detuning pattern
local_detuning_patterns = [
local_detune.magnitude.pattern for local_detune in local_detuning
]

# For each time point, check that each atom has net detuning less than the threshold
validate_net_detuning_with_warning(
values,
time_points,
global_detuning_coefs,
local_detuning_patterns,
local_detuning_coefs,
capabilities,
)

return values
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,58 @@ def mock_program_data():
return Program.parse_obj(data)


# False example with net detuning larger than the MAX_NET_DETUNING
@pytest.fixture
def mock_program_with_large_net_detuning_data():
data = {
"setup": {
"ahs_register": {
"sites": [[0, 0], [0, 1e-6]],
"filling": [1, 1],
}
},
"hamiltonian": {
"drivingFields": [
{
"amplitude": {
"pattern": "uniform",
"time_series": {
"times": [0, 1e-07, 3.9e-06, 4e-06],
"values": [0, 12566400.0, 12566400.0, 0],
},
},
"phase": {
"pattern": "uniform",
"time_series": {
"times": [0, 1e-07, 3.9e-06, 4e-06],
"values": [0, 0, -16.0832, -16.0832],
},
},
"detuning": {
"pattern": "uniform",
"time_series": {
"times": [0, 1e-07, 3.9e-06, 4e-06],
"values": [-125000000, -125000000, 125000000, 125000000],
},
},
}
],
"localDetuning": [
{
"magnitude": {
"time_series": {
"times": [0, 1e-07, 3.9e-06, 4e-06],
"values": [-125000000, -125000000, 125000000, 125000000],
},
"pattern": [0.0, 1.0],
}
}
],
},
}
return Program.parse_obj(data)


def test_program(program_data, device_capabilities_constants):
try:
ProgramValidator(capabilities=device_capabilities_constants, **program_data.dict())
Expand All @@ -88,11 +140,31 @@ def test_program_local_detuning_pattern_has_the_same_length_as_atom_array_sites(
}
}
]
error_message = "The length of pattern (3) of shifting field 0 must equal the number "
error_message = "The length of pattern (3) of local detuning 0 must equal the number "
"of atom array sites (4)."
_assert_program(mock_program_data.dict(), error_message, device_capabilities_constants)


def test_mock_program_with_large_net_detuning_data(
mock_program_with_large_net_detuning_data: Program, device_capabilities_constants
):

warning_message = (
f"Atom {1} has net detuning {-250000000.0} rad/s "
f"at time {0} seconds, which is outside the typical range "
f"[{-device_capabilities_constants.MAX_NET_DETUNING}, "
f"{device_capabilities_constants.MAX_NET_DETUNING}]."
f"Numerical instabilities may occur during simulation."
)

with pytest.warns(UserWarning) as e:
ProgramValidator(
capabilities=device_capabilities_constants,
**mock_program_with_large_net_detuning_data.dict(),
)
assert warning_message in str(e[-1].message)


def _assert_program(data, error_message, device_capabilities_constants):
with pytest.raises(ValidationError) as e:
ProgramValidator(capabilities=device_capabilities_constants, **data)
Expand Down

0 comments on commit 9e872f0

Please sign in to comment.