-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: local detuning validation for ahs #244
Changes from all commits
50919da
ac99143
5854ce5
2ba6021
9b0e615
5dd67c9
1a08243
744cab4
28e5640
65b37d8
57d1908
fc283c3
6efa159
499bba3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,12 +11,18 @@ | |
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
from copy import deepcopy | ||
|
||
from braket.ir.ahs.program_v1 import Program | ||
from pydantic.v1 import root_validator | ||
|
||
from braket.analog_hamiltonian_simulator.rydberg.rydberg_simulator_helpers import _get_coefs | ||
from braket.analog_hamiltonian_simulator.rydberg.validators.capabilities_constants import ( | ||
CapabilitiesConstants, | ||
) | ||
from braket.analog_hamiltonian_simulator.rydberg.validators.field_validator_util import ( | ||
validate_net_detuning_with_warning, | ||
) | ||
|
||
|
||
class ProgramValidator(Program): | ||
|
@@ -34,7 +40,56 @@ def local_detuning_pattern_has_the_same_length_as_atom_array_sites(cls, values): | |
pattern_size = len(local_detuning["magnitude"]["pattern"]) | ||
if num_sites != pattern_size: | ||
raise ValueError( | ||
f"The length of pattern ({pattern_size}) of shifting field {idx} must equal " | ||
f"The length of pattern ({pattern_size}) of local detuning {idx} must equal " | ||
f"the number of atom array sites ({num_sites})." | ||
) | ||
return values | ||
|
||
# If there is local detuning, the net value of detuning for each atom | ||
# should not exceed a max detuning value | ||
@root_validator(pre=True, skip_on_failure=True) | ||
def net_detuning_must_not_exceed_max_net_detuning(cls, values): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we split this function? Also, let's discuss more on this offline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Few questions we discussed offline and we might need confirmation from the science team on these:
|
||
capabilities = values["capabilities"] # device_capabilities | ||
|
||
# Extract the program and the fields | ||
program = deepcopy(values) | ||
del program["capabilities"] | ||
program = Program.parse_obj(program) | ||
driving_fields = program.hamiltonian.drivingFields | ||
local_detuning = program.hamiltonian.localDetuning | ||
|
||
# If no local detuning, we simply return the values | ||
# because there are separate validators to validate | ||
# the global driving fields in the program | ||
if not len(local_detuning): | ||
return values | ||
|
||
Comment on lines
+64
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
detuning_times = [ | ||
local_detune.magnitude.time_series.times for local_detune in local_detuning | ||
] | ||
|
||
# Merge the time points for different shifting terms and detuning term | ||
all_times = set(sum(detuning_times, [])) | ||
for driving_field in driving_fields: | ||
all_times.update(driving_field.detuning.time_series.times) | ||
time_points = sorted(all_times) | ||
|
||
# Get the time-dependent functions for the detuning and shifts | ||
_, global_detuning_coefs, local_detuning_coefs = _get_coefs(program, time_points) | ||
|
||
# Get the detuning pattern | ||
local_detuning_patterns = [ | ||
local_detune.magnitude.pattern for local_detune in local_detuning | ||
] | ||
|
||
# For each time point, check that each atom has net detuning less than the threshold | ||
validate_net_detuning_with_warning( | ||
values, | ||
time_points, | ||
global_detuning_coefs, | ||
local_detuning_patterns, | ||
local_detuning_coefs, | ||
capabilities, | ||
) | ||
|
||
return values |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Do we need this to be public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, but given that the previous function
validate_value_range_with_warning
is public, I figured that we want to be consistent here. Should I change both to private functions or can I keep them like that?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, let's keep it like that, we can change to private if necessary.