Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic CLI #97

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions carla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import pathlib

import yaml
import warnings

warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @Philoso-Fish , this one supresses the ugly:

/home/morty/dev/carla/env/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/morty/dev/carla/env/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.

warnings

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really great 👍


lib_path = pathlib.Path(__file__).parent.resolve()
with open(os.path.join(lib_path, "logging.yaml"), "r") as f:
config = yaml.safe_load(f.read())
logging.config.dictConfig(config)


log = logging.getLogger(__name__)

from ._version import __version__
Expand Down
67 changes: 67 additions & 0 deletions carla/carla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

import click

from carla import DataCatalog, MLModelCatalog
from carla.recourse_methods import *


@click.group()
def cli():
click.echo("Hello World")


@cli.command()
@click.option(
"--data",
"data_name",
default="adult",
help="The dataset to generate counterfactuals on",
type=click.Choice(["adult", "compas"], case_sensitive=False),
)
@click.option(
"--model",
"model_name",
required=True,
default="ann",
help="The black-box model to use",
type=click.Choice(["ann", "lr"], case_sensitive=False),
)
@click.option(
"--method",
"method_name",
required=True,
default="gs",
help="The counterfactual method to run",
type=click.Choice(["gs", "face"], case_sensitive=False),
)
@click.option(
"--sample-size",
required=True,
default=5,
help="The number of factual samples from the dataset",
)
def run(data_name, method_name, model_name, sample_size):
click.echo("Run a single counterfactual method")
dataset = DataCatalog(data_name)
model = MLModelCatalog(dataset, model_name)

if method_name == "gs":
method = GrowingSpheres(model)
elif method_name == "face":
method = Face(model)
else:
raise ValueError(f"Recourse model {model_name} unknown.")

factuals = dataset.raw.sample(sample_size)
counterfactuals = method.get_counterfactuals(factuals)
click.echo(counterfactuals)


@cli.command()
def benchmark():
click.echo("Benchmark")


if __name__ == "__main__":
cli()
5 changes: 3 additions & 2 deletions carla/data/load_catalog/load_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def load_catalog(filename: str, dataset: str, keys: List[str]):
catalog = yaml.safe_load(f)

if dataset not in catalog:
raise KeyError("Dataset not in catalog.")
raise KeyError(f"Dataset '{dataset}' not in catalog.")

# TODO: Use schema validation instead of passing required keys
for key in keys:
if key not in catalog[dataset].keys():
raise KeyError("Important key {} is not in Catalog".format(key))
raise KeyError(f"Required key {key} is not in Catalog")
if catalog[dataset][key] is None:
catalog[dataset][key] = []

Expand Down
5 changes: 5 additions & 0 deletions carla/models/catalog/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
from urllib.error import HTTPError
from urllib.request import urlretrieve

import logging

import tensorflow as tf
import torch


tf.get_logger().setLevel(logging.ERROR)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove more tensorflow warnings



def load_model(
name: str,
dataset: str,
Expand Down
8 changes: 5 additions & 3 deletions carla/recourse_methods/catalog/face/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import pandas as pd

Expand Down Expand Up @@ -45,9 +45,11 @@ class Face(RecourseMethod):
of the AAAI/ACM Conference on AI, Ethics, and Society (AIES)
"""

_DEFAULT_HYPERPARAMS = {"mode": None, "fraction": 0.1}
_DEFAULT_HYPERPARAMS = {"mode": "knn", "fraction": 0.1}

def __init__(self, mlmodel: MLModel, hyperparams: Dict[str, Any]) -> None:
def __init__(
self, mlmodel: MLModel, hyperparams: Optional[Dict[str, Any]] = None
) -> None:
super().__init__(mlmodel)

checked_hyperparams = merge_default_parameters(
Expand Down
3 changes: 3 additions & 0 deletions carla/recourse_methods/processing/counterfactuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def merge_default_parameters(hyperparams: Dict, default: Dict) -> Dict:
dict
Dictionary with every necessary key.
"""
if hyperparams is None:
return default

keys = default.keys()
dict_output = dict()

Expand Down
14 changes: 10 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
packages=find_packages(exclude=("test",)),
include_package_data=True,
install_requires=[
"Click==8.0.1",
"dice-ml==0.5",
"h5py==2.10.0",
"ipython",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Philoso-Fish do we really need ipyhon btw? That seems to be more for notebooks etc?

"keras==2.3.0",
"lime==0.2.0.1",
"mip==1.12.0",
"numpy==1.19.4",
Expand All @@ -45,9 +50,10 @@
"tensorflow==1.14.0",
"torch==1.7.0",
"torchvision==0.8.1",
"h5py==2.10.0",
"dice-ml==0.5",
"ipython",
"keras==2.3.0",
],
entry_points={
"console_scripts": [
"carla = carla.carla:cli",
],
},
)