Skip to content

Commit

Permalink
[BUG] Fix pyproject package structure and failures in tutorial notebo…
Browse files Browse the repository at this point in the history
…oks (#1615)

* add nbconvert
* Show environment
* test which python in sh file
* remove py_modules
* add [tool.setuptools].packages
* packages to package-dir
* package discovery by exclude
* clean up build_tools/run_examples.sh
* remove warning in ar.ipynb
  • Loading branch information
XinyuWuu authored Aug 26, 2024
1 parent ccb50ab commit 119fa89
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev,github-actions,graph,mqf2]"
- name: Show dependencies
run: python -m pip list

- name: Run example notebooks
run: build_tools/run_examples.sh
shell: bash
Expand Down
5 changes: 1 addition & 4 deletions docs/source/tutorials/ar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand Down
21 changes: 12 additions & 9 deletions docs/source/tutorials/building.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand Down Expand Up @@ -1034,16 +1031,19 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"While not required, to give the user transparancy over these additional hyperparameters, it is worth passing them explicitly instead of implicitly in ``**kwargs``\n",
"\n",
"They are described in detail in the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`. \n",
"They are described in detail in the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`.\n",
"\n",
".. automethod:: pytorch_forecasting.models.base_model.BaseModel.__init__\n",
" :noindex:\n",
" \n",
"\n",
"You can simply copy this docstring into your model implementation:"
]
},
Expand Down Expand Up @@ -2238,15 +2238,18 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"Now that we have established the basics, we can move on to more advanced use cases, e.g. how can we make use of covariates - static and continuous alike. We can leverage the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` for this. The difference to the :py:class:`~pytorch_forecasting.models.base_model.BaseModel` is a :py:meth:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates.from_dataset` method that pre-defines hyperparameters for architectures with covariates.\n",
"\n",
".. autoclass:: pytorch_forecasting.models.base_model.BaseModelWithCovariates\n",
" :noindex:\n",
" :members: from_dataset\n",
" \n",
"\n",
"\n",
"Here is a from the BaseModelWithCovariates docstring to copy:"
]
Expand Down
5 changes: 1 addition & 4 deletions docs/source/tutorials/deepar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand Down
5 changes: 1 addition & 4 deletions docs/source/tutorials/nhits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand Down
16 changes: 8 additions & 8 deletions docs/source/tutorials/stallion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,9 @@
},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths"
]
},
{
Expand Down Expand Up @@ -1557,16 +1554,19 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"Hyperparamter tuning with [optuna](https://optuna.org/) is directly build into pytorch-forecasting. For example, we can use the \n",
"Hyperparamter tuning with [optuna](https://optuna.org/) is directly build into pytorch-forecasting. For example, we can use the\n",
":py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters` function to optimize the TFT's hyperparameters.\n",
"\n",
".. code-block:: python\n",
"\n",
" import pickle\n",
" \n",
"\n",
" from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters\n",
"\n",
" # create study\n",
Expand Down Expand Up @@ -1917,7 +1917,7 @@
"source": [
"# calcualte metric by which to display\n",
"predictions = best_tft.predict(val_dataloader, return_y=True)\n",
"mean_losses = SMAPE(reduction=\"none\")(predictions.output, predictions.y).mean(1)\n",
"mean_losses = SMAPE(reduction=\"none\").loss(predictions.output, predictions.y[0]).mean(1)\n",
"indices = mean_losses.argsort(descending=True) # sort losses\n",
"for idx in range(10): # plot 10 examples\n",
" best_tft.plot_prediction(\n",
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ dev = [
"pyarrow",
# jupyter notebook
"ipykernel",
"nbconvert",
"black[extras]",
# documentatation
"sphinx",
Expand All @@ -103,8 +104,8 @@ github-actions = ["pytest-github-actions-annotate-failures"]
graph = ["networkx"]
mqf2 = ["cpflows"]

[tool.setuptools]
py-modules = ["pytorch_forecasting"]
[tool.setuptools.packages.find]
exclude = ["build_tools"]

[build-system]
build-backend = "setuptools.build_meta"
Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@
"unpack_sequence",
]

__version__ = "0.0.0"
__version__ = "1.0.0"

0 comments on commit 119fa89

Please sign in to comment.