Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added logging suppressor for falling back to cpu #135

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: [3.8, 3.9, "3.10", 3.11] # quoted 3.10 needed due to this bug: https://github.com/actions/runner/issues/1989
python-version: [3.9, "3.10", 3.11] # quoted 3.10 needed due to this bug: https://github.com/actions/runner/issues/1989

steps:
- name: Checkout 🛎️
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish_package_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
id-token: write
strategy:
matrix:
python-version: [3.8]
python-version: [3.9]

steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 0 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
# add sourcecode to path
import sys, os

# sys.path.insert(0, os.path.abspath("../multidms"))
# sys.path.insert(0, "{0}/..".format(os.path.abspath(".")))
sys.path.insert(0, "{}/..".format(os.path.abspath(".")))

# -- Project information -----------------------------------------------------
Expand All @@ -33,9 +31,7 @@
"sphinx.ext.mathjax",
"sphinx.ext.githubpages",
"sphinxcontrib.bibtex",
# "sphinx.ext.viewcode",
"sphinx.ext.napoleon",
# "matplotlib.sphinxext.plot_directive",
"nbsphinx",
"nbsphinx_link",
]
Expand Down
16 changes: 12 additions & 4 deletions multidms/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
import numpy as onp
import altair as alt

import logging

logging.getLogger("jax._src.xla_bridge").addFilter(
logging.Filter(
"An NVIDIA GPU may be present on this machine, "
"but a CUDA-enabled jaxlib is not installed. Falling back to cpu."
)
)


PARAMETER_NAMES_FOR_PLOTTING = {
"scale_coeff_lasso_shift": "Lasso Penalty",
Expand Down Expand Up @@ -922,10 +931,9 @@ def mut_type(mut):
.assign(mut_type=lambda x: x.mutation.apply(mut_type))
.reset_index()
.groupby(by=feature_cols)
.apply(
sparsity, include_groups=True
) # TODO This throws deprecation warning
.drop(columns=feature_cols + ["mutation"])
# .apply(sparsity, include_groups=True)
.apply(sparsity, include_groups=False)
# .drop(columns=feature_cols + ["mutation"])
.reset_index(drop=False)
.melt(id_vars=feature_cols, var_name="mut_param", value_name="sparsity")
)
Expand Down
405 changes: 161 additions & 244 deletions notebooks/simulation_validation.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
]
keywords = [
"multidms",
Expand All @@ -32,10 +33,10 @@ keywords = [


# Software Dependencies
requires-python = ">=3.8"
requires-python = ">=3.9"
dependencies = [
"polyclonal",
"jax",
"jax[cpu]==0.4.24",
"jaxopt",
"typing_extensions",
"numpy",
Expand Down
Loading