Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: noise correction statistics. #1983

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 134 additions & 31 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
numpy~=2.1.0
scipy~=1.14.1
cvxopt~=1.3.2
highs~=1.7.2
highspy~=1.7.2
qdldl~=0.1.7
osqp~=0.6.3
qpsolvers~=4.3.3
tzdata~=2024.1
six~=1.16.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ message MeasurementDetail {
repeated string right_hand_side_targets = 7;

message MeasurementResult {
int64 reach = 1;
double reach = 1;
double standard_deviation = 2;
string metric = 3;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ py_library(
imports = ["../"],
visibility = ["//visibility:public"],
deps = [
requirement("cvxopt"),
requirement("highspy"),
requirement("numpy"),
requirement("osqp"),
requirement("qpsolvers"),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from noiseninja.noised_measurements import SetMeasurementsSpec
from qpsolvers import solve_problem, Problem, Solution
from scipy.sparse import csc_matrix
from threading import Semaphore
from typing import Any

SOLVER = "highs"
HIGHS_SOLVER = "highs"
OSQP_SOLVER = "osqp"
MAX_ATTEMPTS = 10
SEMAPHORE = Semaphore()

Expand Down Expand Up @@ -76,7 +77,7 @@ def _add_measurement_targets(self, set_measurement_spec: SetMeasurementsSpec,
self._add_eq_term(variables, measurement.value)
else:
self._add_loss_term(
np.multiply(variables, 1 / measurement.sigma),
np.multiply(variables, 1.0 / measurement.sigma),
-measurement.value / measurement.sigma)

def _map_sets_to_variables(set_measurement_spec: SetMeasurementsSpec) -> dict[
Expand Down Expand Up @@ -149,20 +150,16 @@ def _add_gt_term(self, variables: np.array):
self.G.append(variables)
self.h.append([0])

def _solve(self):
x0 = np.random.randn(self.num_variables)
return self._solve_with_initial_value(x0)

def _solve_with_initial_value(self, x0) -> Solution:
def _solve_with_initial_value(self, solver_name, x0) -> Solution:
problem = self._problem()
solution = solve_problem(problem, solver=SOLVER, verbose=False)
solution = solve_problem(problem, solver=solver_name, initvals=x0, verbose=False)
return solution

def _problem(self):
problem: Problem
if len(self.A) > 0:
problem = Problem(
self.P, self.q, np.array(self.G), np.array(self.h),
csc_matrix(self.P), self.q, csc_matrix(np.array(self.G)), np.array(self.h),
np.array(self.A), np.array(self.b))
else:
problem = Problem(
Expand All @@ -181,14 +178,29 @@ def solve(self) -> Solution:
# TODO: check if qpsolvers is thread safe,
# and remove this semaphore.
SEMAPHORE.acquire()
solution = self._solve()
solution = self._solve_with_initial_value(HIGHS_SOLVER, self.base_value)
SEMAPHORE.release()

if solution.found:
break
else:
attempt_count += 1

# If the highs solver does not converge, switch to the osqp solver which
# is more robust.
if not solution.found:
attempt_count = 0
while attempt_count < MAX_ATTEMPTS:
SEMAPHORE.acquire()
solution = self._solve_with_initial_value(OSQP_SOLVER, self.base_value)
SEMAPHORE.release()

if solution.found:
break
else:
attempt_count += 1

# Raise the exception when both solvers do not converge.
if not solution.found:
raise SolutionNotFoundError(solution)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,96 @@ def get_cover_relationships(edp_combinations: list[FrozenSet[str]]) -> list[
return cover_relationships


def get_subset_relationships(edp_combinations: list[FrozenSet[str]]):
    """Finds all (parent, child) pairs among the given EDP combinations.

    Every unordered pair of combinations is examined exactly once; whenever
    one is contained in the other, a tuple is recorded whose first element
    is the superset (parent) and whose second element is the subset (child).
    """
    subset_relationships = []
    for left, right in combinations(edp_combinations, 2):
        if right.issuperset(left):
            # left is contained in right: right is the parent.
            subset_relationships.append((right, left))
        elif left.issuperset(right):
            subset_relationships.append((left, right))
    return subset_relationships


def is_cover(target_set, possible_cover):
    """Checks if a collection of sets covers a target set.

    Args:
        target_set: The set that should be covered.
        possible_cover: A collection of sets that may cover the target set.

    Returns:
        True if the union of the sets in `possible_cover` equals
        `target_set`, False otherwise. An empty `possible_cover` covers
        only the empty set.
    """
    # frozenset().union(*sets) folds the union in C and, unlike the previous
    # reduce(lambda x, y: x.union(y), ...), does not raise on an empty
    # collection. The comparison result is returned directly instead of
    # via an if/else returning True/False.
    return target_set == frozenset().union(*possible_cover)


def get_covers(target_set, other_sets):
    """Finds all combinations of sets from `other_sets` that cover `target_set`.

    Only sets that are subsets of `target_set` are considered as building
    blocks; every non-empty combination of them is tested, smallest
    combinations first, and those whose union equals `target_set` are
    reported.

    Args:
        target_set: The set that needs to be covered.
        other_sets: A collection of sets that may be used to cover the
            `target_set`.

    Returns:
        A list of tuples, where each tuple represents a covering
        relationship. The first element of the tuple is the `target_set`,
        and the second element is a tuple containing the sets from
        `other_sets` that cover it.
    """
    # Restrict the search to candidates that can participate in a cover.
    candidates = [candidate for candidate in other_sets
                  if candidate.issubset(target_set)]

    covers = []
    # Enumerate combinations by increasing size; note this is exponential
    # in the number of candidate subsets.
    for size in range(1, len(candidates) + 1):
        for selection in combinations(candidates, size):
            if is_cover(target_set, selection):
                covers.append((target_set, selection))
    return covers


def get_cover_relationships(edp_combinations: list[FrozenSet[str]]):
    """Returns covers as defined here: https://en.wikipedia.org/wiki/Cover_(topology).

    For each set s_i in the list, combinations of all the other sets are
    enumerated; any combination whose union equals s_i is a cover of s_i.
    """
    cover_relationships = []
    for index, covered in enumerate(edp_combinations):
        # All sets except the one currently being covered.
        others = edp_combinations[:index] + edp_combinations[index + 1:]
        cover_relationships.extend(get_covers(covered, others))
    return cover_relationships


class MetricReport:
"""Represents a metric sub-report view (e.g., MRC, AMI) within a report.

Expand Down Expand Up @@ -310,7 +400,6 @@ def __init__(
measurement_index += 1

self._num_vars = measurement_index

def get_metric_report(self, metric: str) -> "MetricReport":
return self._metric_reports[metric]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
ami = "ami"
mrc = "mrc"


# TODO(@ple13): Extend the class to support custom measurements and composite
# set operations such as incremental.
class ReportSummaryProcessor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1179,61 +1179,6 @@ def test_correct_report_with_whole_campaign_has_more_edp_combinations(self):

self._assertReportsAlmostEqual(expected, corrected, corrected.to_array())

def test_allows_incorrect_time_series(self):
    """Correction leaves a non-monotonic series untouched when its EDP is exempted."""
    ami = "ami"
    report = Report(
        metric_reports={
            ami: MetricReport(
                reach_time_series={
                    # EDP_TWO's series is already monotonically increasing.
                    frozenset({EDP_TWO}): [
                        Measurement(0.00, 1, "measurement_01"),
                        Measurement(3.30, 1, "measurement_02"),
                        Measurement(4.00, 1, "measurement_03"),
                    ],
                    # EDP_ONE's series is deliberately non-monotonic
                    # (3.30 -> 1.00); it is exempted via the allowed
                    # combinations below.
                    frozenset({EDP_ONE}): [
                        Measurement(0.00, 1, "measurement_04"),
                        Measurement(3.30, 1, "measurement_05"),
                        Measurement(1.00, 1, "measurement_06"),
                    ],
                },
                reach_whole_campaign={},
            )
        },
        metric_subsets_by_parent={},
        # NOTE(review): set(frozenset({EDP_ONE})) builds a set of EDP_ONE's
        # *elements*, not a set containing the frozenset — presumably
        # {frozenset({EDP_ONE})} was intended; verify against how the
        # exemption set is consumed.
        cumulative_inconsistency_allowed_edp_combinations=set(
            frozenset({EDP_ONE})),
    )

    # The corrected report should be consistent: all the time series reaches are
    # monotonic increasing, e.g. reach[edp1][i] <= reach[edp1][i+1], except for
    # the one in the exception list, e.g. edp1.
    corrected = report.get_corrected_report()

    # Expected output is identical to the input: the exempted series stays
    # as-is and the other series needs no adjustment.
    expected = Report(
        metric_reports={
            ami: MetricReport(
                reach_time_series={
                    frozenset({EDP_TWO}): [
                        Measurement(0.00, 1, "measurement_01"),
                        Measurement(3.30, 1, "measurement_02"),
                        Measurement(4.00, 1, "measurement_03"),
                    ],
                    frozenset({EDP_ONE}): [
                        Measurement(0.00, 1, "measurement_04"),
                        Measurement(3.30, 1, "measurement_05"),
                        Measurement(1.00, 1, "measurement_06"),
                    ],
                },
                reach_whole_campaign={},
            )
        },
        metric_subsets_by_parent={},
        cumulative_inconsistency_allowed_edp_combinations=set(
            frozenset({EDP_ONE})),
    )

    self._assertReportsAlmostEqual(expected, corrected, corrected.to_array())

def test_can_correct_related_metrics(self):
ami = "ami"
mrc = "mrc"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
load("@pip//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_library")
load("@rules_python//python:defs.bzl", "py_test")


py_test(
name = "test_post_process_origin_report",
srcs = ["test_post_process_origin_report.py"],
Expand All @@ -10,6 +13,18 @@ py_test(
],
)

py_test(
name = "test_error_analysis",
srcs = ["test_error_analysis.py"],
data = [":sample_reports"],
deps = [
"//src/main/proto/wfa/measurement/reporting/postprocessing/v2alpha:report_summary_py_pb2",
"//src/main/python/wfa/measurement/reporting/postprocessing/tools:post_process_origin_report",
requirement("openpyxl"),
],
)


filegroup(
name = "sample_reports",
srcs = glob(["*.json"]),
Expand Down
Loading
Loading