Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: local detuning validation for ahs #244

Merged
merged 14 commits into from
Apr 16, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
MAGNITUDE_PATTERN_VALUE_MIN = 0.0
MAGNITUDE_PATTERN_VALUE_MAX = 1.0

# Maximum absolute net detuning (rad/s) allowed for any atom; exceeding it triggers a warning
MAX_NET_DETUNING = 2e8


def capabilities_constants() -> CapabilitiesConstants:
return CapabilitiesConstants(
Expand All @@ -60,4 +63,5 @@ def capabilities_constants() -> CapabilitiesConstants:
LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX=LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX,
MAGNITUDE_PATTERN_VALUE_MIN=MAGNITUDE_PATTERN_VALUE_MIN,
MAGNITUDE_PATTERN_VALUE_MAX=MAGNITUDE_PATTERN_VALUE_MAX,
MAX_NET_DETUNING=MAX_NET_DETUNING,
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ class CapabilitiesConstants(BaseModel):

MAGNITUDE_PATTERN_VALUE_MIN: Decimal
MAGNITUDE_PATTERN_VALUE_MAX: Decimal
MAX_NET_DETUNING: Decimal
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,69 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import warnings
from copy import deepcopy

import numpy as np
from braket.ir.ahs.program_v1 import Program
from pydantic.v1 import root_validator

from braket.analog_hamiltonian_simulator.rydberg.rydberg_simulator_helpers import _get_coefs
from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import (
CapabilitiesConstants,
)


def _check_threshold(
    values: dict,
    time_points: list,
    global_detuning_coefs: list,
    local_detuning_patterns: list,
    local_detuning_coefs: list,
    capabilities: "CapabilitiesConstants",
) -> dict:
    """Warn if any atom's net detuning exceeds the device capability.

    Given the coefficients of the global detuning terms
    (``global_detuning_coefs``), the local detuning patterns
    (``local_detuning_patterns``) and the local detuning coefficients
    (``local_detuning_coefs``), check that every atom has a net detuning
    within ``capabilities.MAX_NET_DETUNING`` at all time points.

    Args:
        values: The validator's ``values`` dict; returned unchanged.
        time_points: Time points (seconds) at which the detunings are evaluated.
        global_detuning_coefs: One sequence of detuning coefficients per global
            driving field, indexed by time point.
        local_detuning_patterns: One per-atom pattern (scale factors) per local
            detuning term.
        local_detuning_coefs: One sequence of magnitude coefficients per local
            detuning term, indexed by time point.
        capabilities: Device capability constants providing ``MAX_NET_DETUNING``.

    Returns:
        dict: ``values``, unchanged. A ``UserWarning`` is issued — and checking
        stops immediately — as soon as one atom exceeds the threshold at one
        time point.
    """
    for time_ind, time in enumerate(time_points):
        # Contributions from all the global detunings (there could be
        # multiple global driving fields) at this time point.
        values_global_detuning = sum(
            detuning_coef[time_ind] for detuning_coef in global_detuning_coefs
        )

        for atom_index in range(len(local_detuning_patterns[0])):
            # Contribution from all local detunings at this time point,
            # each scaled by the field's per-atom pattern value.
            values_local_detuning = sum(
                shift_coef[time_ind] * float(detuning_pattern[atom_index])
                for detuning_pattern, shift_coef in zip(
                    local_detuning_patterns, local_detuning_coefs
                )
            )

            # The net detuning is the sum of both the global and local
            # detunings; coefficients may be complex, the real part is checked.
            detuning_to_check = np.real(values_local_detuning + values_global_detuning)

            # Issue a warning if the absolute value of the net detuning is
            # beyond MAX_NET_DETUNING.
            if abs(detuning_to_check) > capabilities.MAX_NET_DETUNING:
                warnings.warn(
                    f"Atom {atom_index} has net detuning {detuning_to_check} rad/s "
                    f"at time {time} seconds, which is outside the typical range "
                    f"[{-capabilities.MAX_NET_DETUNING}, {capabilities.MAX_NET_DETUNING}]."
                    f"Numerical instabilities may occur during simulation."
                )

                # Return immediately once an atom has a net detuning exceeding
                # MAX_NET_DETUNING at some time point; one warning is enough.
                return values

    return values


class ProgramValidator(Program):
capabilities: CapabilitiesConstants

Expand All @@ -34,7 +89,56 @@ def local_detuning_pattern_has_the_same_length_as_atom_array_sites(cls, values):
pattern_size = len(local_detuning["magnitude"]["pattern"])
if num_sites != pattern_size:
raise ValueError(
f"The length of pattern ({pattern_size}) of shifting field {idx} must equal "
f"The length of pattern ({pattern_size}) of local detuning {idx} must equal "
f"the number of atom array sites ({num_sites})."
)
return values

# If there is local detuning, the net value of detuning for each atom
# should not exceed a max detuning value
@root_validator(pre=True, skip_on_failure=True)
def net_detuning_must_not_exceed_max_net_detuning(cls, values):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we split this function? Also, let's discuss more on this offline.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few questions we discussed offline and we might need confirmation from the science team on these:

  1. Where to cross-verify those schema values, trying to understand what does these detuning patterns, magnitude mean?

capabilities = values["capabilities"] # device_capabilities

# Extract the program and the fields
program = deepcopy(values)
del program["capabilities"]
program = Program.parse_obj(program)
driving_fields = program.hamiltonian.drivingFields
local_detuning = program.hamiltonian.localDetuning

# If no local detuning, we simply return the values
# because there are separate validators to validate
# the global driving fields in the program
if not len(local_detuning):
return values

Comment on lines +64 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. if there are no local detuning why are we are returning values are we supposed to return an error or warning if this is specific for local detuning?

detuning_times = [
local_detune.magnitude.time_series.times for local_detune in local_detuning
]

# Merge the time points for different shifting terms and detuning term
all_times = set(sum(detuning_times, []))
for driving_field in driving_fields:
all_times.update(driving_field.detuning.time_series.times)
time_points = sorted(all_times)

# Get the time-dependent functions for the detuning and shifts
_, global_detuning_coefs, local_detuning_coefs = _get_coefs(program, time_points)

# Get the detuning pattern
local_detuning_patterns = [
local_detune.magnitude.pattern for local_detune in local_detuning
]

# For each time point, check that each atom has net detuning less than the threshold
_check_threshold(
values,
time_points,
global_detuning_coefs,
local_detuning_patterns,
local_detuning_coefs,
capabilities,
)

return values
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,58 @@ def mock_program_data():
return Program.parse_obj(data)


# False example with net detuning larger than the MAX_NET_DETUNING
@pytest.fixture
def mock_program_with_large_net_detuning_data():
    """Two-atom program whose second atom's net detuning exceeds the limit."""
    # Every time series in this fixture shares the same time grid, and the
    # global detuning and local detuning share the same value ramp.
    times = [0, 1e-07, 3.9e-06, 4e-06]
    ramp = [-125000000, -125000000, 125000000, 125000000]

    driving_field = {
        "amplitude": {
            "pattern": "uniform",
            "time_series": {
                "times": list(times),
                "values": [0, 12566400.0, 12566400.0, 0],
            },
        },
        "phase": {
            "pattern": "uniform",
            "time_series": {
                "times": list(times),
                "values": [0, 0, -16.0832, -16.0832],
            },
        },
        "detuning": {
            "pattern": "uniform",
            "time_series": {"times": list(times), "values": list(ramp)},
        },
    }
    local_detuning = {
        "magnitude": {
            "time_series": {"times": list(times), "values": list(ramp)},
            # Only the second atom is subject to the local detuning.
            "pattern": [0.0, 1.0],
        }
    }
    data = {
        "setup": {
            "ahs_register": {
                "sites": [[0, 0], [0, 1e-6]],
                "filling": [1, 1],
            }
        },
        "hamiltonian": {
            "drivingFields": [driving_field],
            "localDetuning": [local_detuning],
        },
    }
    return Program.parse_obj(data)


def test_program(program_data, device_capabilities_constants):
try:
ProgramValidator(capabilities=device_capabilities_constants, **program_data.dict())
Expand All @@ -88,11 +140,31 @@ def test_program_local_detuning_pattern_has_the_same_length_as_atom_array_sites(
}
}
]
error_message = "The length of pattern (3) of shifting field 0 must equal the number "
error_message = "The length of pattern (3) of local detuning 0 must equal the number "
"of atom array sites (4)."
_assert_program(mock_program_data.dict(), error_message, device_capabilities_constants)


def test_mock_program_with_large_net_detuning_data(
    mock_program_with_large_net_detuning_data: Program, device_capabilities_constants
):
    """The validator emits a UserWarning when an atom's net detuning
    exceeds MAX_NET_DETUNING."""
    max_detuning = device_capabilities_constants.MAX_NET_DETUNING
    warning_message = (
        "Atom 1 has net detuning -250000000.0 rad/s "
        "at time 0 seconds, which is outside the typical range "
        f"[{-max_detuning}, {max_detuning}]."
        "Numerical instabilities may occur during simulation."
    )

    with pytest.warns(UserWarning) as record:
        ProgramValidator(
            capabilities=device_capabilities_constants,
            **mock_program_with_large_net_detuning_data.dict(),
        )
    assert warning_message in str(record[-1].message)


def _assert_program(data, error_message, device_capabilities_constants):
with pytest.raises(ValidationError) as e:
ProgramValidator(capabilities=device_capabilities_constants, **data)
Expand Down