Skip to content

Commit

Permalink
Add more unit tests + bugfixes
Browse files Browse the repository at this point in the history
This commit adds additional unit tests for the fisher information calculations when doing an approximate fisher information calculation. This also fixes a few bugs revealed by the new and existing fisher information unit tests. Also includes a refactoring of this test module splitting up the three different parts of edesign tools into separate classes.
  • Loading branch information
Corey Ostrove committed Sep 15, 2023
1 parent d158976 commit 3b16b1c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 28 deletions.
35 changes: 26 additions & 9 deletions pygsti/tools/edesigntools.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def calculate_fisher_information_per_circuit(regularized_model, circuits, approx
split_fisher_info_terms = accumulate_fim_matrix_per_circuit(split_circuit_list, num_params,
outcomes, ps, js,
printer,
hs, approx=True)
approx=True)
else:
split_fisher_info_terms, total_hterm = accumulate_fim_matrix_per_circuit(split_circuit_list, num_params,
outcomes, ps, js,
Expand Down Expand Up @@ -297,7 +297,7 @@ def calculate_fisher_information_per_circuit(regularized_model, circuits, approx
fisher_info_terms = accumulate_fim_matrix_per_circuit(circuits, num_params,
outcomes, ps, js,
printer,
hs, approx=True)
approx=True)
else:
fisher_info_terms, total_hterm = accumulate_fim_matrix_per_circuit(circuits, num_params,
outcomes, ps, js,
Expand Down Expand Up @@ -396,8 +396,15 @@ def calculate_fisher_information_matrix(model, circuits, num_shots=1, term_cache
needed_circuits = [c for c in circuits if c not in term_cache]
if len(needed_circuits):
printer.log('Adding needed terms to the per-circuit term cache.',3)
new_terms = calculate_fisher_information_per_circuit(regularized_model, needed_circuits,
approx, verbosity=verbosity, comm=comm, mem_limit=mem_limit)

#might also return hessian terms if approx is False, but we currently aren't using this in
#this function.
if approx:
new_terms = calculate_fisher_information_per_circuit(regularized_model, needed_circuits,
approx, verbosity=verbosity, comm=comm, mem_limit=mem_limit)
else:
new_terms, _ = calculate_fisher_information_per_circuit(regularized_model, needed_circuits,
approx, verbosity=verbosity, comm=comm, mem_limit=mem_limit)
if comm is None or comm.Get_rank()==0:
term_cache.update(new_terms)

Expand Down Expand Up @@ -427,10 +434,12 @@ def calculate_fisher_information_matrix(model, circuits, num_shots=1, term_cache
for i, ckt_chunk in enumerate(chunked_circuit_lists):
printer.show_progress(iteration = i, total=len(chunked_circuit_lists), bar_length=50,
suffix= f'Circuit chunk {i+1} out of {len(chunked_circuit_lists)}')

fim_term_for_chunk = _calculate_fisher_information_per_chunk(regularized_model, ckt_chunk,
if approx:
fim_term_for_chunk = _calculate_fisher_information_per_chunk(regularized_model, ckt_chunk,
approx, num_shots, verbosity=verbosity, comm=comm, mem_limit=mem_limit)
else:
fim_term_for_chunk, _ = _calculate_fisher_information_per_chunk(regularized_model, ckt_chunk,
approx, num_shots, verbosity=verbosity, comm=comm, mem_limit=mem_limit)

# Collect all terms, do this on rank zero:
if comm is None or comm.Get_rank() == 0:
fisher_information += fim_term_for_chunk
Expand Down Expand Up @@ -461,6 +470,10 @@ def calculate_fisher_information_matrices_by_L(model, circuit_lists, Ls, num_sho
Circuit lists for the experiment design for each L. Most likely from the value of
the `circuit_lists` attribute of most experiment design objects.
Ls : list of ints
A list of integer values corresponding to the circuit lengths associated with each circuit list
as past in with circuit_lists.
num_shots: int or dict
If int, specifies how many shots each circuit gets. If dict, keys must be circuits
and values are per-circuit counts.
Expand Down Expand Up @@ -537,8 +550,12 @@ def calculate_fisher_information_matrices_by_L(model, circuit_lists, Ls, num_sho
term_cache = {}
needed_circuits = [c for ckt_list in circuit_lists for c in ckt_list if c not in term_cache]
if len(needed_circuits):
new_terms = calculate_fisher_information_per_circuit(regularized_model, needed_circuits, approx, verbosity=verbosity,
comm=comm, mem_limit=mem_limit)
if approx:
new_terms = calculate_fisher_information_per_circuit(regularized_model, needed_circuits, approx, verbosity=verbosity,
comm=comm, mem_limit=mem_limit)
else:
new_terms, _ = calculate_fisher_information_per_circuit(regularized_model, needed_circuits, approx, verbosity=verbosity,
comm=comm, mem_limit=mem_limit)
if comm is None or comm.Get_rank()==0:
term_cache.update(new_terms)
#should have already used the comm in the construction of the term cache, so this is just an accumulation
Expand Down
61 changes: 42 additions & 19 deletions test/unit/tools/test_edesigntools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..util import BaseCase


class EdesignToolsTester(BaseCase):
class ExperimentDesignTimeEstimationTester(BaseCase):

def test_time_estimation(self):
edesign = smq2Q_XYICNOT.create_gst_experiment_design(256)
Expand Down Expand Up @@ -110,43 +110,66 @@ def test_time_estimation(self):
)
self.assertGreater(time3, time1)

def test_fisher_information(self):
target_model = smq1Q_XYI.target_model('TP')
edesign = smq1Q_XYI.create_gst_experiment_design(8)
class FisherInformationTester(BaseCase):

def setUp(self):

self.target_model = smq1Q_XYI.target_model('full TP')
self.edesign = smq1Q_XYI.create_gst_experiment_design(8)
self.Ls = [1,2,4,8]
self.regularized_model = self.target_model.copy().depolarize(spam_noise=1e-3)

def test_calculate_fisher_information_matrix(self):

# Basic usage
start = time.time()
fim1 = et.calculate_fisher_information_matrix(target_model, edesign.all_circuits_needing_data)
fim1 = et.calculate_fisher_information_matrix(self.target_model, self.edesign.all_circuits_needing_data,
regularize_spam= True)
fim1_time = time.time() - start

# Try external regularized model version
regularized_model = target_model.copy().depolarize(spam_noise=1e-3)
fim2 = et.calculate_fisher_information_matrix(regularized_model, edesign.all_circuits_needing_data,
fim2 = et.calculate_fisher_information_matrix(self.regularized_model, self.edesign.all_circuits_needing_data,
regularize_spam=False)
self.assertArraysAlmostEqual(fim1, fim2)

# Try pre-cached version
fim3_terms = et.calculate_fisher_information_per_circuit(regularized_model, edesign.all_circuits_needing_data)
fim3_terms, _ = et.calculate_fisher_information_per_circuit(self.regularized_model, self.edesign.all_circuits_needing_data)
start = time.time()
fim3 = et.calculate_fisher_information_matrix(target_model, edesign.all_circuits_needing_data, term_cache=fim3_terms)
fim3 = et.calculate_fisher_information_matrix(self.target_model, self.edesign.all_circuits_needing_data, term_cache=fim3_terms)
fim3_time = time.time() - start

self.assertArraysAlmostEqual(fim1, fim3)
self.assertLess(10*fim3_time, fim1_time) # Cached version should be very fast compared to uncached

def test_calculate_fisher_info_by_L(self):

fim1 = et.calculate_fisher_information_matrix(self.target_model, self.edesign.all_circuits_needing_data,
regularize_spam= True)

# Try by-L version
fim_by_L = et.calculate_fisher_information_matrices_by_L(target_model, edesign.all_circuits_needing_data)
fim_by_L = et.calculate_fisher_information_matrices_by_L(self.target_model, self.edesign.circuit_lists, self.Ls)
self.assertArraysAlmostEqual(fim1, fim_by_L[8])

# Try pre-cached by-L version
start = time.time()
fim_by_L2 = et.calculate_fisher_information_matrices_by_L(target_model, edesign.all_circuits_needing_data, term_cache=fim3_terms)
fim_by_L2_time = time.time() - start
for k,v in fim_by_L2.items():
self.assertArraysAlmostEqual(v, fim_by_L[k])
self.assertLess(5*fim_by_L2_time, fim1_time) # Cached version should be very fast compared to uncached
#test approximate versions of the fisher information calculation.
def test_fisher_information_approximate(self):

#Test approximate fisher information calculations:
fim_approx = et.calculate_fisher_information_matrix(self.target_model, self.edesign.all_circuits_needing_data,
approx=True)

#test per-circuit
fim_approx_per_circuit = et.calculate_fisher_information_per_circuit(self.regularized_model,
self.edesign.all_circuits_needing_data,
approx=True)

#Test by L:
fim_approx_by_L = et.calculate_fisher_information_matrices_by_L(self.target_model, self.edesign.circuit_lists, self.Ls,
approx=True)
self.assertArraysAlmostEqual(fim_approx, fim_approx_by_L[8])


class EdesignPaddingTester(BaseCase):


def test_generic_design_padding(self):
# Create a series of designs with some overlap when they will be padded out
design_124 = CircuitListsDesign([[
Expand Down

0 comments on commit 3b16b1c

Please sign in to comment.