diff --git a/src/stochatreat/__about__.py b/src/stochatreat/__about__.py index 39e62e7..5d7942c 100644 --- a/src/stochatreat/__about__.py +++ b/src/stochatreat/__about__.py @@ -1,2 +1,2 @@ # pragma: no cover -__version__ = "0.0.19" +__version__ = "0.0.20" diff --git a/src/stochatreat/stochatreat.py b/src/stochatreat/stochatreat.py index 58949dc..a913a5a 100644 --- a/src/stochatreat/stochatreat.py +++ b/src/stochatreat/stochatreat.py @@ -1,5 +1,15 @@ +"""Stratified random assignment of treatments to units. + +This module provides a function to assign treatments to units in a +stratified manner. The function is designed to work with pandas +dataframes and is able to handle multiple strata. There are also different +strategies to deal with misfits (units that are left over after the +stratified assignment procedure). +""" + from __future__ import annotations +import math from typing import Literal import numpy as np @@ -90,7 +100,7 @@ def stochatreat( probs_np = np.array([frac] * len(treatment_ids)) elif probs is not None: probs_np = np.array(probs) - if probs_np.sum() != 1: + if not math.isclose(probs_np.sum(), 1, rel_tol=1e-9): error_msg = "The probabilities must add up to 1" raise ValueError(error_msg) if len(probs_np) != len(treatment_ids): diff --git a/tests/test_assignment.py b/tests/test_assignment.py index e079b2f..f8f2749 100644 --- a/tests/test_assignment.py +++ b/tests/test_assignment.py @@ -30,6 +30,7 @@ def df(request): [0.5, 0.5], [2 / 3, 1 / 3], [0.9, 0.1], + [1 / 2, 1 / 3, 1 / 6], ] # a set of stratum column combinations from the above df fixture to throw at