feat: local detuning validation for ahs #244

Merged
merged 14 commits into from Apr 16, 2024
@@ -45,6 +45,9 @@
MAGNITUDE_PATTERN_VALUE_MIN = 0.0
MAGNITUDE_PATTERN_VALUE_MAX = 1.0

# Maximum net detuning allowed for any atom (rad/s)
MAX_NET_DETUNING = 2e8


def capabilities_constants() -> CapabilitiesConstants:
    return CapabilitiesConstants(
@@ -60,4 +63,5 @@ def capabilities_constants() -> CapabilitiesConstants:
        LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX=LOCAL_MAGNITUDE_SEQUENCE_VALUE_MAX,
        MAGNITUDE_PATTERN_VALUE_MIN=MAGNITUDE_PATTERN_VALUE_MIN,
        MAGNITUDE_PATTERN_VALUE_MAX=MAGNITUDE_PATTERN_VALUE_MAX,
        MAX_NET_DETUNING=MAX_NET_DETUNING,
    )
@@ -33,3 +33,4 @@ class CapabilitiesConstants(BaseModel):

    MAGNITUDE_PATTERN_VALUE_MIN: Decimal
    MAGNITUDE_PATTERN_VALUE_MAX: Decimal
    MAX_NET_DETUNING: Decimal
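
Editor's sketch (not part of the diff): CapabilitiesConstants is a pydantic v1 BaseModel, so the float literal 2e8 defined above is coerced to a Decimal when capabilities_constants() builds the model. A minimal check, trimmed to the new field:

from decimal import Decimal
from pydantic.v1 import BaseModel

class CapabilitiesConstants(BaseModel):  # trimmed to the field added in this PR
    MAX_NET_DETUNING: Decimal

constants = CapabilitiesConstants(MAX_NET_DETUNING=2e8)
print(constants.MAX_NET_DETUNING)  # Decimal('200000000'), i.e. 2e8 rad/s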
@@ -11,9 +11,14 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import warnings
from copy import deepcopy

import numpy as np
from braket.ir.ahs.program_v1 import Program
from pydantic.v1 import root_validator

from braket.analog_hamiltonian_simulator.rydberg.rydberg_simulator_helpers import _get_coefs
from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import (
    CapabilitiesConstants,
)
@@ -38,3 +43,60 @@ def local_detuning_pattern_has_the_same_length_as_atom_array_sites(cls, values):
f"the number of atom array sites ({num_sites})."
)
return values

    # If there is local detuning, the net value of detuning for each atom
    # should not exceed a certain value
    @root_validator(pre=True, skip_on_failure=True)
    def net_detuning_must_not_exceed_max_net_detuning(cls, values):
Contributor:
Should we split this function? Also, let's discuss more on this offline.

Contributor:
A few questions we discussed offline; we might need confirmation from the science team on these:

  1. Where do we cross-verify those schema values? Trying to understand what these detuning patterns and magnitudes mean.

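Editor's note, inferred from the implementation below: for atom k at time t, the bounded quantity is the sum of all global (driving-field) detunings plus each local detuning magnitude h_j(t) weighted by its per-atom pattern value p_j[k]:

\Delta_{\mathrm{net}}(k, t) = \sum_{g} \Delta_g(t) + \sum_{j} h_j(t)\, p_j[k],
\qquad \lvert \Delta_{\mathrm{net}}(k, t) \rvert \le \mathrm{MAX\_NET\_DETUNING}
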
        capabilities = values["capabilities"]  # device_capabilities

        # Extract the program and the fields
        program = deepcopy(values)
        del program["capabilities"]
        program = Program.parse_obj(program)
        driving_fields = program.hamiltonian.drivingFields
        local_detuning = program.hamiltonian.localDetuning

        # If no local detuning, return
        if not len(local_detuning):
            return values

Comment on lines +64 to +66
Contributor:
  1. If there is no local detuning, why are we returning values? Are we supposed to return an error or warning if this is specific to local detuning?

        detuning_times = [
            local_detune.magnitude.time_series.times for local_detune in local_detuning
        ]

        # Merge the time points for different shifting terms and detuning term
        time_points = sorted(list(set(sum(detuning_times, []))))
        for driving_field in driving_fields:
            time_points = sorted(list(set(time_points + driving_field.detuning.time_series.times)))

        # Get the time-dependent functions for the detuning and shifts
        _, detuning_coefs, shift_coefs = _get_coefs(program, time_points)

        # Get the detuning pattern
        detuning_patterns = [local_detune.magnitude.pattern for local_detune in local_detuning]

        # For each time point, check that each atom has net detuning less than the threshold
        for time_ind, time in enumerate(time_points):
            for atom_index in range(len(detuning_patterns[0])):
                # Get the contributions from global detuning at the time point
                detuning_to_check = 0
                for detuning_coef in detuning_coefs:
                    detuning_to_check += detuning_coef[time_ind]

                # Get the contributions from local detuning at the time point
                for detuning_pattern, shift_coef in zip(detuning_patterns, shift_coefs):
                    detuning_to_check += shift_coef[time_ind] * float(detuning_pattern[atom_index])
Contributor: @virajvchaudhari, Apr 15, 2024
  1. The method name says it checks that the net value doesn't exceed the max detuning, but it is also retrieving the contributions; I need a bit more context to understand why that is part of this method. And similarly for "# Merge the time points for different shifting terms and detuning term".

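Editor's sketch of the merge step being asked about, with made-up times (not values from the PR): sum(lists, []) flattens the per-term time lists, set() deduplicates, sorted() orders.

detuning_times = [[0, 1e-07, 4e-06], [0, 2e-06, 4e-06]]  # one list per local detuning term
time_points = sorted(set(sum(detuning_times, [])))
# -> [0, 1e-07, 2e-06, 4e-06]

# each driving field's detuning times are folded in the same way
time_points = sorted(set(time_points + [0, 1e-06, 4e-06]))
# -> [0, 1e-07, 1e-06, 2e-06, 4e-06]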

                # Issue a warning if the net detuning is beyond MAX_NET_DETUNING
                detuning_to_check = np.real(detuning_to_check)
                if abs(detuning_to_check) > capabilities.MAX_NET_DETUNING:
                    warnings.warn(
                        f"Atom {atom_index} has net detuning {detuning_to_check} rad/s "
                        f"at time {time} seconds, which is outside the typical range "
                        f"[{-capabilities.MAX_NET_DETUNING}, {capabilities.MAX_NET_DETUNING}]. "
                        f"Numerical instabilities may occur during simulation."
                    )
                    return values
Contributor:

Do we require this return?

Contributor:

Ok, I did some tests, and yes, this return is necessary. The idea is that we want to raise a warning immediately once we find an atom whose net detuning is larger than the allowed value, and then stop the net-detuning validator. If we don't have this return, the unit test will fail.

An alternative approach is to break out of the for loop, as we did in this validator.
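
For illustration, the two options side by side (editor's sketch; check_net_detunings is a hypothetical helper, not the PR's function):

import warnings

def check_net_detunings(net_detunings, max_net_detuning):
    # (a) early return, as in this PR: warn once, then stop validating
    for atom_index, net in enumerate(net_detunings):
        if abs(net) > max_net_detuning:
            warnings.warn(f"Atom {atom_index} has net detuning {net} rad/s")
            return  # nothing after the loop runs
    # (b) alternative: replace the `return` above with `break` and fall
    # through to shared code after the loop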


        return values
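
Editor's note on _get_coefs: the validator only relies on it returning, for each (global or local) detuning term, that term's value at every merged time point. A stand-in under that assumption (interpolate_onto is hypothetical; np.interp is piecewise-linear, which may differ from the real helper's conventions):

import numpy as np

def interpolate_onto(time_points, times, values):
    # evaluate the time series at the merged time points, interpolating
    # linearly between its sample points
    return np.interp(time_points, times, values)

time_points = [0, 1e-07, 2e-06, 3.9e-06, 4e-06]
detuning_coef = interpolate_onto(
    time_points, [0, 1e-07, 3.9e-06, 4e-06], [-1.25e8, -1.25e8, 1.25e8, 1.25e8]
)
# detuning_coef[2] == 0.0: halfway up the ramp from -1.25e8 to 1.25e8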
@@ -69,6 +69,58 @@ def mock_program_data():
    return Program.parse_obj(data)


# Invalid example with net detuning larger than MAX_NET_DETUNING
@pytest.fixture
def mock_program_with_large_net_detuning_data():
    data = {
        "setup": {
            "ahs_register": {
                "sites": [[0, 0], [0, 1e-6]],
                "filling": [1, 1],
            }
        },
        "hamiltonian": {
            "drivingFields": [
                {
                    "amplitude": {
                        "pattern": "uniform",
                        "time_series": {
                            "times": [0, 1e-07, 3.9e-06, 4e-06],
                            "values": [0, 12566400.0, 12566400.0, 0],
                        },
                    },
                    "phase": {
                        "pattern": "uniform",
                        "time_series": {
                            "times": [0, 1e-07, 3.9e-06, 4e-06],
                            "values": [0, 0, -16.0832, -16.0832],
                        },
                    },
                    "detuning": {
                        "pattern": "uniform",
                        "time_series": {
                            "times": [0, 1e-07, 3.9e-06, 4e-06],
                            "values": [-125000000, -125000000, 125000000, 125000000],
                        },
                    },
                }
            ],
            "localDetuning": [
                {
                    "magnitude": {
                        "time_series": {
                            "times": [0, 1e-07, 3.9e-06, 4e-06],
                            "values": [-125000000, -125000000, 125000000, 125000000],
                        },
                        "pattern": [0.0, 1.0],
                    }
                }
            ],
        },
    }
    return Program.parse_obj(data)
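
Why this fixture should warn (editor's worked check of the numbers above): at t = 0 the global detuning is -1.25e8 rad/s and the local detuning magnitude is -1.25e8 rad/s with pattern [0.0, 1.0], so atom 1 sees a net detuning of -2.5e8 rad/s, beyond MAX_NET_DETUNING = 2e8:

MAX_NET_DETUNING = 2e8
net = [-125000000 + -125000000 * p for p in [0.0, 1.0]]
# net == [-125000000.0, -250000000.0]; only atom 1 exceeds the bound
assert abs(net[1]) > MAX_NET_DETUNING and abs(net[0]) <= MAX_NET_DETUNING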


def test_program(program_data, device_capabilities_constants):
    try:
        ProgramValidator(capabilities=device_capabilities_constants, **program_data.dict())
@@ -88,11 +140,31 @@ def test_program_local_detuning_pattern_has_the_same_length_as_atom_array_sites(
            }
        }
    ]
    error_message = "The length of pattern (3) of shifting field 0 must equal the number "
    error_message = "The length of pattern (3) of local detuning 0 must equal the number "
    "of atom array sites (4)."
    _assert_program(mock_program_data.dict(), error_message, device_capabilities_constants)


def test_mock_program_with_large_net_detuning_data(
    mock_program_with_large_net_detuning_data: Program, device_capabilities_constants
):
    warning_message = (
        f"Atom {1} has net detuning {-250000000.0} rad/s "
        f"at time {0} seconds, which is outside the typical range "
        f"[{-device_capabilities_constants.MAX_NET_DETUNING}, "
        f"{device_capabilities_constants.MAX_NET_DETUNING}]. "
        f"Numerical instabilities may occur during simulation."
    )

    with pytest.warns(UserWarning) as e:
        ProgramValidator(
            capabilities=device_capabilities_constants,
            **mock_program_with_large_net_detuning_data.dict(),
        )
    assert warning_message in str(e[-1].message)


def _assert_program(data, error_message, device_capabilities_constants):
    with pytest.raises(ValidationError) as e:
        ProgramValidator(capabilities=device_capabilities_constants, **data)