Skip to content

Commit

Permalink
ruff 0.8.0 and fix some tests (#1240)
Browse files Browse the repository at this point in the history
* ruff0.8.0

* better tests
  • Loading branch information
juanitorduz authored Nov 26, 2024
1 parent a340272 commit 7abb6e7
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- --exclude=docs/
- --exclude=scripts/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.0
hooks:
- id: ruff
types_or: [python, pyi, jupyter]
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
from pymc_marketing import clv, mmm
from pymc_marketing.version import __version__

__all__ = ["clv", "mmm", "__version__"]
__all__ = ["__version__", "clv", "mmm"]
6 changes: 3 additions & 3 deletions pymc_marketing/clv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@
)

__all__ = (
"BetaGeoModel",
"BetaGeoBetaBinomModel",
"ParetoNBDModel",
"BetaGeoModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
"ParetoNBDModel",
"ShiftedBetaGeoModelIndividual",
"customer_lifetime_value",
"plot_customer_exposure",
"plot_frequency_recency_matrix",
"plot_expected_purchases",
"plot_frequency_recency_matrix",
"plot_probability_alive_matrix",
"rfm_segments",
"rfm_summary",
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pytensor.graph import vectorize_graph
from pytensor.tensor.random.op import RandomVariable

__all__ = ["ContContract", "ContNonContract", "ParetoNBD", "BetaGeoBetaBinom"]
__all__ = ["BetaGeoBetaBinom", "ContContract", "ContNonContract", "ParetoNBD"]


class ContNonContractRV(RandomVariable):
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/clv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from pymc_marketing.clv.models.shifted_beta_geo import ShiftedBetaGeoModelIndividual

__all__ = (
"CLVModel",
"BetaGeoBetaBinomModel",
"BetaGeoModel",
"CLVModel",
"GammaGammaModel",
"GammaGammaModelIndividual",
"BetaGeoModel",
"ParetoNBDModel",
"ShiftedBetaGeoModelIndividual",
)
2 changes: 1 addition & 1 deletion pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@

__all__ = [
"plot_customer_exposure",
"plot_expected_purchases",
"plot_frequency_recency_matrix",
"plot_probability_alive_matrix",
"plot_expected_purchases",
]


Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from numpy import datetime64

__all__ = [
"to_xarray",
"customer_lifetime_value",
"rfm_segments",
"rfm_summary",
"rfm_train_test_split",
"to_xarray",
]


Expand Down
20 changes: 10 additions & 10 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,39 +52,39 @@
from pymc_marketing.mmm.validating import validation_method_X, validation_method_y

__all__ = [
"MediaTransformation",
"MediaConfigList",
"MediaConfig",
"MMM",
"AdstockTransformation",
"BaseValidateMMM",
"DelayedAdstock",
"GeometricAdstock",
"HillSaturation",
"HillSaturationSigmoid",
"LogisticSaturation",
"InverseScaledLogisticSaturation",
"MMM",
"LinearTrend",
"LogisticSaturation",
"MMMModelBuilder",
"MediaConfig",
"MediaConfigList",
"MediaTransformation",
"MichaelisMentenSaturation",
"MonthlyFourier",
"RootSaturation",
"SaturationTransformation",
"TanhSaturation",
"TanhSaturationBaselined",
"saturation_from_dict",
"register_saturation_transformation",
"WeibullCDFAdstock",
"WeibullPDFAdstock",
"adstock_from_dict",
"register_adstock_transformation",
"YearlyFourier",
"adstock_from_dict",
"base",
"mmm",
"preprocessing",
"preprocessing_method_X",
"preprocessing_method_y",
"register_adstock_transformation",
"register_saturation_transformation",
"saturation_from_dict",
"validating",
"validation_method_X",
"validation_method_y",
"LinearTrend",
]
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from pymc_marketing.model_builder import ModelBuilder

__all__ = ["MMMModelBuilder", "BaseValidateMMM"]
__all__ = ["BaseValidateMMM", "MMMModelBuilder"]

from pydantic import Field, validate_call

Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior

__all__ = ["BaseMMM", "MMM"]
__all__ = ["MMM", "BaseMMM"]


class BaseMMM(BaseValidateMMM):
Expand Down
6 changes: 3 additions & 3 deletions pymc_marketing/mmm/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from sklearn.preprocessing import MaxAbsScaler, StandardScaler

__all__ = [
"preprocessing_method_X",
"preprocessing_method_y",
"MaxAbsScaleTarget",
"MaxAbsScaleChannels",
"MaxAbsScaleTarget",
"StandardizeControls",
"preprocessing_method_X",
"preprocessing_method_y",
]


Expand Down
8 changes: 4 additions & 4 deletions pymc_marketing/mmm/validating.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import pandas as pd

__all__ = [
"validation_method_X",
"validation_method_y",
"ValidateChannelColumns",
"ValidateControlColumns",
"ValidateTargetColumn",
"ValidateDateColumn",
"ValidateChannelColumns",
"ValidateTargetColumn",
"validation_method_X",
"validation_method_y",
]


Expand Down
5 changes: 1 addition & 4 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ def test_default_prefix(adstock) -> None:


def test_adstock_no_negative_lmax():
with pytest.raises(
ValidationError,
match="1 validation error for __init__\\nl_max\\n Input should be greater than 0",
):
with pytest.raises(ValidationError, match=".*Input should be greater than 0.*"):
DelayedAdstock(l_max=-1)


Expand Down
10 changes: 2 additions & 8 deletions tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,10 +576,7 @@ def test_cant_reset_distribution() -> None:


def test_nonstring_distribution() -> None:
with pytest.raises(
ValidationError,
match="1 validation error for __init__\\n1\\n Input should be a valid string",
):
with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
Prior(pm.Normal)


Expand All @@ -590,10 +587,7 @@ def test_change_the_transform() -> None:


def test_nonstring_transform() -> None:
with pytest.raises(
ValidationError,
match="1 validation error for __init__\\ntransform\\n Input should be a valid string",
):
with pytest.raises(ValidationError, match=".*Input should be a valid string.*"):
Prior("Normal", transform=pm.math.log)


Expand Down

0 comments on commit 7abb6e7

Please sign in to comment.