-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: local detuning validation for ahs #244
Changes from 12 commits
50919da
ac99143
5854ce5
2ba6021
9b0e615
5dd67c9
1a08243
744cab4
28e5640
65b37d8
57d1908
fc283c3
6efa159
499bba3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,14 +11,69 @@ | |
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
import warnings | ||
from copy import deepcopy | ||
|
||
import numpy as np | ||
from braket.ir.ahs.program_v1 import Program | ||
from pydantic.v1 import root_validator | ||
|
||
from braket.analog_hamiltonian_simulator.rydberg.rydberg_simulator_helpers import _get_coefs | ||
from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import ( | ||
CapabilitiesConstants, | ||
) | ||
|
||
|
||
def _check_threshold( | ||
values, | ||
time_points, | ||
global_detuning_coefs, | ||
local_detuning_patterns, | ||
local_detuning_coefs, | ||
capabilities, | ||
): | ||
# Given a set of global detuning coefficients (global_detuning_coefs), | ||
# a set of local detuning patterns (local_detuning_patterns) | ||
# and values (local_detuning_coefs), check that all the atoms have net detuninig | ||
maolinml marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# within the capability at all the time points | ||
|
||
for time_ind, time in enumerate(time_points): | ||
|
||
# Get the contributions from all the global detunings | ||
# (there could be multiple global driving fields) at the time point | ||
values_global_detuning = sum( | ||
[detuning_coef[time_ind] for detuning_coef in global_detuning_coefs] | ||
) | ||
|
||
for atom_index in range(len(local_detuning_patterns[0])): | ||
# Get the contributions from local detuning at the time point | ||
values_local_detuning = sum( | ||
[ | ||
shift_coef[time_ind] * float(detuning_pattern[atom_index]) | ||
for detuning_pattern, shift_coef in zip( | ||
local_detuning_patterns, local_detuning_coefs | ||
) | ||
] | ||
) | ||
|
||
# The net detuning is the sum of both the global and local detunings | ||
detuning_to_check = np.real(values_local_detuning + values_global_detuning) | ||
|
||
# Issue a warning if the absolute value of the net detuning is | ||
# beyond MAX_NET_DETUNING | ||
if abs(detuning_to_check) > capabilities.MAX_NET_DETUNING: | ||
warnings.warn( | ||
f"Atom {atom_index} has net detuning {detuning_to_check} rad/s " | ||
f"at time {time} seconds, which is outside the typical range " | ||
f"[{-capabilities.MAX_NET_DETUNING}, {capabilities.MAX_NET_DETUNING}]." | ||
f"Numerical instabilities may occur during simulation." | ||
) | ||
|
||
# Return immediately if there is an atom has net detuning | ||
# exceeding MAX_NET_DETUNING at a time point | ||
return values | ||
|
||
|
||
class ProgramValidator(Program): | ||
capabilities: CapabilitiesConstants | ||
|
||
|
@@ -34,7 +89,56 @@ def local_detuning_pattern_has_the_same_length_as_atom_array_sites(cls, values): | |
pattern_size = len(local_detuning["magnitude"]["pattern"]) | ||
if num_sites != pattern_size: | ||
raise ValueError( | ||
f"The length of pattern ({pattern_size}) of shifting field {idx} must equal " | ||
f"The length of pattern ({pattern_size}) of local detuning {idx} must equal " | ||
f"the number of atom array sites ({num_sites})." | ||
) | ||
return values | ||
|
||
# If there is local detuning, the net value of detuning for each atom | ||
# should not exceed a max detuning value | ||
@root_validator(pre=True, skip_on_failure=True) | ||
def net_detuning_must_not_exceed_max_net_detuning(cls, values): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we split this function? Also, let's discuss more on this offline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Few questions we discussed offline and we might need confirmation from the science team on these:
|
||
capabilities = values["capabilities"] # device_capabilities | ||
|
||
# Extract the program and the fields | ||
program = deepcopy(values) | ||
del program["capabilities"] | ||
program = Program.parse_obj(program) | ||
driving_fields = program.hamiltonian.drivingFields | ||
local_detuning = program.hamiltonian.localDetuning | ||
|
||
# If no local detuning, we simply return the values | ||
# because there are separate validators to validate | ||
# the global driving fields in the program | ||
if not len(local_detuning): | ||
return values | ||
|
||
Comment on lines
+64
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
detuning_times = [ | ||
local_detune.magnitude.time_series.times for local_detune in local_detuning | ||
] | ||
|
||
# Merge the time points for different shifting terms and detuning term | ||
all_times = set(sum(detuning_times, [])) | ||
for driving_field in driving_fields: | ||
all_times.update(driving_field.detuning.time_series.times) | ||
time_points = sorted(all_times) | ||
|
||
# Get the time-dependent functions for the detuning and shifts | ||
_, global_detuning_coefs, local_detuning_coefs = _get_coefs(program, time_points) | ||
|
||
# Get the detuning pattern | ||
local_detuning_patterns = [ | ||
local_detune.magnitude.pattern for local_detune in local_detuning | ||
] | ||
|
||
# For each time point, check that each atom has net detuning less than the threshold | ||
_check_threshold( | ||
values, | ||
time_points, | ||
global_detuning_coefs, | ||
local_detuning_patterns, | ||
local_detuning_coefs, | ||
capabilities, | ||
) | ||
|
||
return values |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing docstring and type info, similar to this, if this belong in the helper file let's move it there.