From ae3ef9582087d3b332a406b8ab7e50fb93cff616 Mon Sep 17 00:00:00 2001
From: Uncertainty Baselines Team
Date: Tue, 9 May 2023 15:38:31 -0700
Subject: [PATCH] Migrate files from
 /third_party/py/uncertainty_baselines/experimental/shoshin to
 /third_party/py/skai. Small fixes to adhere to /third_party/py build
 requirements.

PiperOrigin-RevId: 530732086
---
 experimental/shoshin/configs/base_config.py   |  238 ----
 .../shoshin/configs/celeb_a_resnet_config.py  |   34 -
 experimental/shoshin/configs/skai_config.py   |   54 -
 .../configs/skai_hurricane_ian_config.py      |   52 -
 .../shoshin/configs/skai_two_tower_config.py  |   59 -
 .../configs/waterbirds10k_resnet_config.py    |   43 -
 .../configs/waterbirds_baseline_config.py     |   44 -
 .../configs/waterbirds_resnet_config.py       |   44 -
 ...terbirds_resnet_config_no_introspection.py |   46 -
 .../waterbirds_resnet_config_reweighting.py   |   38 -
 .../configs/waterbirds_resnet_eval_config.py  |   37 -
 .../configs/waterbirds_upsampling_config.py   |   46 -
 experimental/shoshin/data.py                  | 1094 -----------------
 experimental/shoshin/data_loader_test.py      |  268 ----
 experimental/shoshin/data_test.py             |  140 ---
 experimental/shoshin/evaluate_model_lib.py    |  126 --
 experimental/shoshin/generate_bias_table.py   |  146 ---
 .../shoshin/generate_bias_table_lib.py        |  588 ---------
 experimental/shoshin/log_metrics_callback.py  |  182 ---
 experimental/shoshin/metrics.py               |   68 -
 experimental/shoshin/metrics_test.py          |   63 -
 experimental/shoshin/models.py                |  304 -----
 experimental/shoshin/read_predictions.py      |   38 -
 experimental/shoshin/sample_ids.py            |   59 -
 experimental/shoshin/sampling_policies.py     |  143 ---
 experimental/shoshin/train_tf.py              |  208 ----
 experimental/shoshin/train_tf_lib.py          |  923 --------------
 .../shoshin/train_tf_sequential_active.py     |  208 ----
 28 files changed, 5293 deletions(-)
 delete mode 100644 experimental/shoshin/configs/base_config.py
 delete mode 100644 experimental/shoshin/configs/celeb_a_resnet_config.py
 delete mode 100644 experimental/shoshin/configs/skai_config.py
 delete mode 100644 experimental/shoshin/configs/skai_hurricane_ian_config.py
 delete mode 100644 experimental/shoshin/configs/skai_two_tower_config.py
 delete mode 100644 experimental/shoshin/configs/waterbirds10k_resnet_config.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_baseline_config.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_resnet_config.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_resnet_config_no_introspection.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_resnet_eval_config.py
 delete mode 100644 experimental/shoshin/configs/waterbirds_upsampling_config.py
 delete mode 100644 experimental/shoshin/data.py
 delete mode 100644 experimental/shoshin/data_loader_test.py
 delete mode 100644 experimental/shoshin/data_test.py
 delete mode 100644 experimental/shoshin/evaluate_model_lib.py
 delete mode 100644 experimental/shoshin/generate_bias_table.py
 delete mode 100644 experimental/shoshin/generate_bias_table_lib.py
 delete mode 100644 experimental/shoshin/log_metrics_callback.py
 delete mode 100644 experimental/shoshin/metrics.py
 delete mode 100644 experimental/shoshin/metrics_test.py
 delete mode 100644 experimental/shoshin/models.py
 delete mode 100644 experimental/shoshin/read_predictions.py
 delete mode 100644 experimental/shoshin/sample_ids.py
 delete mode 100644 experimental/shoshin/sampling_policies.py
 delete mode 100644 experimental/shoshin/train_tf.py
 delete mode 100644 experimental/shoshin/train_tf_lib.py
 delete mode 100644 experimental/shoshin/train_tf_sequential_active.py

diff --git a/experimental/shoshin/configs/base_config.py b/experimental/shoshin/configs/base_config.py
deleted file mode 100644
index b68574b20..000000000
--- a/experimental/shoshin/configs/base_config.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base configuration file.
-
-Serves as the base config for custom configs, which specify the model and
-dataloader to use as well as experiment-level arguments, such as whether or
-not to generate a bias table or train the stage 2 model as an ensemble.
-"""
-
-import ml_collections
-
-
-def check_flags(config: ml_collections.ConfigDict):
-  """Checks validity of certain config values."""
-  if not config.output_dir:
-    raise ValueError('output_dir has to be specified.')
-  if not config.data.name:
-    raise ValueError('config.data.name has to be specified.')
-  if not config.model.name:
-    raise ValueError('config.model.name has to be specified.')
-  if 100 % config.data.num_splits != 0:
-    raise ValueError('100 should be divisible by config.data.num_splits '
-                     'because we use the TFDS split-by-percent feature.')
-  if config.bias_percentile_threshold < 0 or config.bias_percentile_threshold > 100:
-    raise ValueError(
-        'config.bias_percentile_threshold must be between 0 and 100.')
-  if config.bias_value_threshold and (config.bias_value_threshold < 0. or
-                                      config.bias_value_threshold > 1.):
-    raise ValueError('config.bias_value_threshold must be between 0. and 1.')
-
-
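Since the divisibility check above is load-bearing for the loaders in data.py, here is a minimal sketch of what it protects (plain Python; values are illustrative and initial_sample_proportion is assumed to be 1.0):

# Sketch (not part of the patch): the loaders build TFDS percent-based
# split strings, which only tile the full [0%, 100%] range when the
# per-split width divides 100 evenly.
num_splits = 5
split_size_in_pct = 100 // num_splits  # 20
splits = [f'train[{k}%:{k + split_size_in_pct}%]'
          for k in range(0, 100, split_size_in_pct)]
print(splits)  # ['train[0%:20%]', ..., 'train[80%:100%]']
# With num_splits = 3, 100 % 3 != 0 and the last slice would stop short
# of 100%, silently dropping data — hence the ValueError above.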
-def get_data_config():
-  """Get dataset config."""
-  config = ml_collections.ConfigDict()
-  config.name = ''
-  config.num_classes = 2
-  config.batch_size = 64
-  # Number of slices into which train and val will be split.
-  config.num_splits = 5
-  # Ratio of splits that will be considered out-of-distribution from each
-  # combination, e.g. when num_splits == 5 and ood_ratio == 0.4, 2 out of 5
-  # slices will be excluded for every combination of training data.
-  config.ood_ratio = 0.4
-  # Indices of data splits to include in training. All by default.
-  config.included_splits_idx = (0, 1, 2, 3, 4)
-  # Subgroup IDs. Specify them in an experiment config. For example, for
-  # Waterbirds, the subgroup IDs might be ('0_1', '1_0') for landbirds on
-  # water and waterbirds on land, respectively.
-  config.subgroup_ids = ()
-  # Subgroup proportions. Specify them in an experiment config. For example,
-  # for Waterbirds, the subgroup proportions might be (0.05, 0.05), meaning
-  # each subgroup will represent 5% of the dataset.
-  config.subgroup_proportions = ()
-
-  # Proportion of training set to sample initially. The rest is considered
-  # the pool for active sampling.
-  config.initial_sample_proportion = 0.5
-  # Whether to use data splits for the creation of an ensemble or filtering.
-  # When filtering is used instead of creating combinations of splits, each
-  # model is trained on a random subsample of the dataset. Splitting
-  # guarantees that each point appears in the exact number of splits defined
-  # by the OOD ratio; filtering only guarantees this in expectation.
-  config.use_splits = True
-  config.use_filtering = False
-
-  # The following arguments are only used when use_filtering=True.
-  # The sum of split_seed and split_id forms the sampling seed for subset
-  # selection.
-  config.split_seed = 0
-  config.split_id = 0
-  # Seed for initial sample selection when filtering is used.
-  config.initial_sample_seed = 0
-  # Proportion of split to size of initial training set (similar to ood_ratio
-  # but can have an arbitrary value between 0 and 1).
-  config.split_proportion = 0.7
-
-  # Leave-one-out training.
-  config.loo_id = ''
-  config.loo_training = False
-  # Correlation strength of the minority vs. majority group. This is
-  # equivalent to the proportion of majority group examples in the data.
-  config.corr_strength = 0.95
-  # Specify whether to load small images or not.
-  config.load_small_images = True
-  return config
-
-
-def get_training_config():
-  """Get training config."""
-  config = ml_collections.ConfigDict()
-  config.num_epochs = 60
-  config.save_model_checkpoints = False
-  config.save_best_model = True
-  # TODO(jihyeonlee): Allow user to specify early stopping patience.
-  # When True, stops training when val AUC does not improve after 3 epochs.
-  config.early_stopping = False
-  config.log_to_xm = True
-  return config
-
-
-def get_optimizer_config():
-  """Get optimizer config."""
-  config = ml_collections.ConfigDict()
-  # With Adam, use lr 1e-4.
-  config.learning_rate = 1e-5
-  config.type = 'sgd'
-  return config
-
-
-def get_model_config():
-  """Get model config."""
-  config = ml_collections.ConfigDict()
-  config.name = ''
-  config.hidden_sizes = None
-  config.num_channels = 3
-  config.l2_regularization_factor = 0.5
-  # TODO(jihyeonlee): Debug why loading ImageNet weights causes model train/val
-  # acc to decrease rather than improving over time.
-  config.load_pretrained_weights = False
-  # If using a ResNet, config.use_pytorch_style_resnet makes adjustments to
-  # the TF ResNet model to match the PyTorch implementation, such as using He
-  # normal initialization for convolution layers.
-  config.use_pytorch_style_resnet = True
-  return config
-
-
-def get_active_sampling_config():
-  """Get active sampling config."""
-  config = ml_collections.ConfigDict()
-  config.sampling_score = 'ensemble_uncertainty'
-  config.num_samples_per_round = 50
-  return config
-
-
-def get_reweighting_config():
-  """Get config for performing reweighting during training."""
-  config = ml_collections.ConfigDict()
-  config.do_reweighting = False
-  config.signal = 'bias'  # Options are bias, error.
-  # Weight that underrepresented group examples will receive. Between 0 and 1.
-  config.lambda_value = 0.
-  config.error_percentile_threshold = 0.2
-  return config
-
-
-def get_upsampling_config():
-  """Get config for performing upsampling during training."""
-  config = ml_collections.ConfigDict()
-  config.do_upsampling = False
-  # TODO(jihyeonlee): Add support for upsampling signal being bias or error.
-  config.signal = 'subgroup_label'
-  # Lambda determines how many times each example of the group to be
-  # upsampled is repeated in the dataset.
-  config.lambda_value = 60
-  return config
-
-
-def get_evaluation_config():
-  """Get config for performing introspection signal computation."""
-  config = ml_collections.ConfigDict()
-  # An iterable tuple of epochs to compute checkpoints for.
-  config.signal_ckpt_epochs = ()
-  # Number of training epochs to check for computing introspection signals.
-  # Used if `signal_ckpt_epochs` is empty. If 0, then compute signals based
-  # on the latest epoch using `tf.train.latest_checkpoint`.
-  config.num_signal_ckpts = 0
-  return config
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get config."""
-  config = ml_collections.ConfigDict()
-
-  config.output_dir = ''
-  config.save_dir = ''
-  config.ids_dir = ''
-
-  config.eval_splits = ('val', 'test')
-
-  # Number of rounds of active sampling to conduct.
-  config.num_rounds = 4
-
-  # Threshold to generate bias labels. Can be specified as a percentile or a
-  # value.
-  config.bias_percentile_threshold = 80
-  config.bias_value_threshold = None
-  config.tracin_percentile_threshold = 80
-  config.tracin_value_threshold = None
-  config.save_bias_table = True
-  # Path to existing bias table to use in training the bias head. If
-  # unspecified, generates a new one.
-  config.path_to_existing_bias_table = ''
-  # The signal used to train the bias head.
-  config.bias_head_prediction_signal = 'bias_label'
-
-  config.train_bias = True
-  # When True, trains the stage 2 model (stage 1 is calculating the bias
-  # table) as an ensemble of models. When True and only a single model is
-  # being trained, trains that model as an ensemble.
-  config.train_stage_2_as_ensemble = True
-
-  # Combo index to train.
-  config.combo_index = 0
-
-  # Round of active sampling being performed.
-  config.round_idx = -1
-
-  # Whether to generate a bias table (from stage one models) or a prediction
-  # table (from stage two models).
-  config.generate_bias_table = True
-
-  # Whether or not to do introspective training.
-  config.introspective_training = True
-
-  # Whether to save the ids used during training (for bias estimation).
-  config.save_train_ids = True
-
-  config.data = get_data_config()
-  config.training = get_training_config()
-  config.optimizer = get_optimizer_config()
-  config.model = get_model_config()
-  config.active_sampling = get_active_sampling_config()
-  config.reweighting = get_reweighting_config()
-  config.upsampling = get_upsampling_config()
-  config.eval = get_evaluation_config()
-  return config
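Every experiment config in this patch builds on the base config above by overriding a handful of fields. A minimal sketch of that pattern, with a hypothetical dataset name rather than one of the configs below:

import ml_collections
from configs import base_config  # local file import from experimental.shoshin


def get_config() -> ml_collections.ConfigDict:
  """Hypothetical experiment config built on the base config."""
  config = base_config.get_config()
  # Only fields that differ from the base need to be overridden.
  config.data.name = 'my_dataset'  # hypothetical dataset name
  config.model.name = 'resnet50v2'
  config.optimizer.learning_rate = 1e-4
  config.optimizer.type = 'adam'
  return config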
- -"""Configuration file for experiment with Cardiotoxicity data and MLP model.""" - -import ml_collections -from configs import base_config # local file import from experimental.shoshin - - -def get_config() -> ml_collections.ConfigDict: - """Get mlp config.""" - config = base_config.get_config() - - data = config.data - data.name = 'celeb_a' - data.num_classes = 2 - - model = config.model - model.name = 'resnet50v2' - model.dropout_rate = 0.2 - - return config diff --git a/experimental/shoshin/configs/skai_config.py b/experimental/shoshin/configs/skai_config.py deleted file mode 100644 index 026ccdb86..000000000 --- a/experimental/shoshin/configs/skai_config.py +++ /dev/null @@ -1,54 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Configuration file for experiment with SKAI data and ResNet model. - -""" - - -import ml_collections -from configs import base_config # local file import from experimental.shoshin - - -def get_config() -> ml_collections.ConfigDict: - """Get mlp config.""" - config = base_config.get_config() - - config.train_bias = False - config.num_rounds = 1 - config.round_idx = 0 - config.train_stage_2_as_ensemble = False - config.save_train_ids = False - - data = config.data - data.name = 'skai' - data.num_classes = 2 - # TODO(jihyeonlee): Determine what are considered subgroups in SKAI domain - # and add support for identifying by ID. - data.subgroup_ids = () - data.subgroup_proportions = () - data.initial_sample_proportion = 1. - data.tfds_dataset_name = 'skai_dataset' - data.tfds_data_dir = '/tmp/skai_dataset' - data.labeled_train_pattern = '' - data.unlabeled_train_pattern = '' - data.validation_pattern = '' - data.use_post_disaster_only = False - - model = config.model - model.name = 'resnet50v2' - model.num_channels = 6 - - return config diff --git a/experimental/shoshin/configs/skai_hurricane_ian_config.py b/experimental/shoshin/configs/skai_hurricane_ian_config.py deleted file mode 100644 index ead5c8269..000000000 --- a/experimental/shoshin/configs/skai_hurricane_ian_config.py +++ /dev/null @@ -1,52 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Configuration file for experiment with SKAI data and ResNet model. 
- -""" - - -import ml_collections -from configs import base_config # local file import from experimental.shoshin - - -def get_config() -> ml_collections.ConfigDict: - """Get mlp config.""" - config = base_config.get_config() - - config.train_bias = False - config.num_rounds = 1 - config.round_idx = 0 - config.train_stage_2_as_ensemble = False - config.save_train_ids = False - - data = config.data - data.name = 'skai' - data.num_classes = 2 - # TODO(jihyeonlee): Determine what are considered subgroups in SKAI domain - # and add support for identifying by ID. - data.subgroup_ids = () - data.subgroup_proportions = () - data.initial_sample_proportion = 1. - data.labeled_train_pattern = '' - data.unlabeled_train_pattern = '' - data.validation_pattern = '' - data.use_post_disaster_only = False - - model = config.model - model.name = 'resnet' - model.num_channels = 6 - - return config diff --git a/experimental/shoshin/configs/skai_two_tower_config.py b/experimental/shoshin/configs/skai_two_tower_config.py deleted file mode 100644 index 9e29d791f..000000000 --- a/experimental/shoshin/configs/skai_two_tower_config.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Configuration file for experiment with SKAI data and TwoTower model. - -""" - - -import ml_collections -from configs import base_config # local file import from experimental.shoshin - - -def get_config() -> ml_collections.ConfigDict: - """Get two tower config.""" - config = base_config.get_config() - - config.train_bias = False - config.num_rounds = 1 - config.round_idx = 0 - config.train_stage_2_as_ensemble = False - config.save_train_ids = False - - data = config.data - data.name = 'skai' - data.num_classes = 2 - data.subgroup_ids = () - data.subgroup_proportions = () - data.initial_sample_proportion = 1. - data.tfds_dataset_name = 'skai_dataset' - data.tfds_data_dir = '/tmp/skai_dataset' - data.labeled_train_pattern = '' - data.unlabeled_train_pattern = '' - data.validation_pattern = '' - data.use_post_disaster_only = False - data.batch_size = 32 - - model = config.model - model.load_pretrained_weights = True - model.name = 'two_tower' - model.num_channels = 6 - - config.optimizer.learning_rate = 1e-4 - config.optimizer.type = 'adam' - - config.training.num_epochs = 100 - - return config diff --git a/experimental/shoshin/configs/waterbirds10k_resnet_config.py b/experimental/shoshin/configs/waterbirds10k_resnet_config.py deleted file mode 100644 index cf912f2c8..000000000 --- a/experimental/shoshin/configs/waterbirds10k_resnet_config.py +++ /dev/null @@ -1,43 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
diff --git a/experimental/shoshin/configs/waterbirds10k_resnet_config.py b/experimental/shoshin/configs/waterbirds10k_resnet_config.py
deleted file mode 100644
index cf912f2c8..000000000
--- a/experimental/shoshin/configs/waterbirds10k_resnet_config.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for experiment with Waterbirds 10K data and ResNet model."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get ResNet config."""
-  config = base_config.get_config()
-
-  # Consider landbirds on water and waterbirds on land as subgroups.
-  config.data.subgroup_ids = ()  # ('0_1', '1_0')
-  config.data.subgroup_proportions = ()  # (0.04, 0.012)
-  config.data.initial_sample_proportion = .25
-
-  config.active_sampling.num_samples_per_round = 500
-  config.num_rounds = 4
-
-  data = config.data
-  data.name = 'waterbirds10k'
-  data.num_classes = 2
-  data.corr_strength = 0.95
-
-  model = config.model
-  model.name = 'resnet'
-  model.dropout_rate = 0.2
-
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_baseline_config.py b/experimental/shoshin/configs/waterbirds_baseline_config.py
deleted file mode 100644
index 1e17d3aff..000000000
--- a/experimental/shoshin/configs/waterbirds_baseline_config.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for the Waterbirds baseline experiment."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get baseline config."""
-  config = base_config.get_config()
-
-  config.train_bias = False
-  config.num_rounds = 1
-  config.round_idx = 0
-  config.train_stage_2_as_ensemble = False
-  config.save_train_ids = False
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-  data.subgroup_ids = ()  # ('0_1', '1_0')
-  data.subgroup_proportions = ()  # (0.04, 0.012)
-  data.initial_sample_proportion = 1.
-
-  model = config.model
-  model.name = 'resnet50v2'
-
-  # Set to 0 to compute introspection signal based on the best epoch.
-  config.eval.num_signal_ckpts = 0
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_resnet_config.py b/experimental/shoshin/configs/waterbirds_resnet_config.py
deleted file mode 100644
index 4b992dcfb..000000000
--- a/experimental/shoshin/configs/waterbirds_resnet_config.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for experiment with Waterbirds data and ResNet model."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get ResNet config."""
-  config = base_config.get_config()
-
-  # Consider landbirds on water and waterbirds on land as subgroups.
-  config.data.subgroup_ids = ()  # ('0_1', '1_0')
-  config.data.subgroup_proportions = ()  # (0.04, 0.012)
-  config.data.initial_sample_proportion = .25
-
-  config.active_sampling.num_samples_per_round = 500
-  config.num_rounds = 4
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-
-  model = config.model
-  model.name = 'resnet50v2'
-  model.dropout_rate = 0.2
-
-  # Set to 0 to compute introspection signal based on the best epoch.
-  config.eval.num_signal_ckpts = 0
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_resnet_config_no_introspection.py b/experimental/shoshin/configs/waterbirds_resnet_config_no_introspection.py
deleted file mode 100644
index 3f60abb9c..000000000
--- a/experimental/shoshin/configs/waterbirds_resnet_config_no_introspection.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for Waterbirds ResNet experiment without introspection."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get ResNet config."""
-  config = base_config.get_config()
-
-  # Consider landbirds on water and waterbirds on land as subgroups.
-  config.data.subgroup_ids = ()  # ('0_1', '1_0')
-  config.data.subgroup_proportions = ()  # (0.04, 0.012)
-  config.data.initial_sample_proportion = .25
-
-  config.active_sampling.num_samples_per_round = 500
-  config.active_sampling.sampling_score = 'ensemble_uncertainty'
-  config.num_rounds = 4
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-
-  model = config.model
-  model.name = 'resnet50v2'
-  model.dropout_rate = 0.2
-  config.train_bias = False
-  config.generate_bias_table = False
-  config.introspective_training = False
-
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py b/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py
deleted file mode 100644
index ab519359f..000000000
--- a/experimental/shoshin/configs/waterbirds_resnet_config_reweighting.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for Waterbirds ResNet experiment with reweighting."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get ResNet config."""
-  config = base_config.get_config()
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-
-  model = config.model
-  model.name = 'resnet50v2'
-  model.dropout_rate = 0.2
-
-  config.train_bias = False
-  reweighting = config.reweighting
-  reweighting.do_reweighting = True
-
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_resnet_eval_config.py b/experimental/shoshin/configs/waterbirds_resnet_eval_config.py
deleted file mode 100644
index 1a73b0aec..000000000
--- a/experimental/shoshin/configs/waterbirds_resnet_eval_config.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for Waterbirds ResNet evaluation."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get ResNet eval config."""
-  config = base_config.get_config()
-
-  # Consider landbirds on water and waterbirds on land as subgroups.
-  config.data.subgroup_ids = ('0_1', '1_0')
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-
-  model = config.model
-  model.name = 'resnet50v2'
-  model.dropout_rate = 0.2
-
-  return config
diff --git a/experimental/shoshin/configs/waterbirds_upsampling_config.py b/experimental/shoshin/configs/waterbirds_upsampling_config.py
deleted file mode 100644
index 27759e670..000000000
--- a/experimental/shoshin/configs/waterbirds_upsampling_config.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration file for the Waterbirds upsampling experiment."""
-
-import ml_collections
-from configs import base_config  # local file import from experimental.shoshin
-
-
-def get_config() -> ml_collections.ConfigDict:
-  """Get upsampling config."""
-  config = base_config.get_config()
-
-  config.train_bias = False
-  config.num_rounds = 1
-  config.round_idx = 0
-  config.train_stage_2_as_ensemble = False
-  config.save_train_ids = False
-
-  data = config.data
-  data.name = 'waterbirds'
-  data.num_classes = 2
-  data.subgroup_ids = ()  # ('0_1', '1_0')
-  data.subgroup_proportions = ()  # (0.04, 0.012)
-  data.initial_sample_proportion = 1.
-
-  model = config.model
-  model.name = 'resnet50v2'
-
-  config.upsampling.do_upsampling = True
-
-  # Set to 0 to compute introspection signal based on the best epoch.
-  config.eval.num_signal_ckpts = 0
-  return config
diff --git a/experimental/shoshin/data.py b/experimental/shoshin/data.py
deleted file mode 100644
index e0328235e..000000000
--- a/experimental/shoshin/data.py
+++ /dev/null
@@ -1,1094 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Library of dataloaders to use in Introspective Active Sampling.
-
-This file contains a library of dataloaders that return three features for
-each example: example_id, input feature, and label. The example_id is a unique
-ID that will be used to keep track of the bias label for that example. The
-input feature will vary depending on the type of data (feature vector, image,
-etc.), and the label is specific to the main task.
-"""
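A small sketch of the three-feature contract the module docstring describes (field names come from this file; the concrete shapes are illustrative):

import tensorflow as tf

# Illustrative element matching the contract described above; shapes are
# examples only (Waterbirds images are 224x224x3 in this file).
example = {
    'example_id': tf.constant('img_0001.jpg'),
    'input_feature': tf.zeros([224, 224, 3]),
    'label': tf.constant(1, dtype=tf.int64),
}
dataset = tf.data.Dataset.from_tensors(example)
for elem in dataset:
  print(elem['example_id'].numpy(), elem['label'].numpy())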
-""" - -import collections -import dataclasses -import os -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -import uuid - -import numpy as np -import tensorflow as tf -import tensorflow_datasets as tfds - -DATASET_REGISTRY = {} -DATA_DIR = '/tmp/data' -_WATERBIRDS_DATA_DIR = '' -_WATERBIRDS10K_DATA_DIR = '' -_WATERBIRDS_TRAIN_PATTERN = '' -_WATERBIRDS_VALIDATION_PATTERN = '' -# Smaller subsample for testing. -_WATERBIRDS_TRAIN_SAMPLE_PATTERN = '' -_WATERBIRDS_TEST_PATTERN = '' -_WATERBIRDS_NUM_SUBGROUP = 4 -_WATERBIRDS_TRAIN_SIZE = 4780 -_WATERBIRDS10K_TRAIN_SIZE = 9549 -_WATERBIRDS10K_SUPPORTED_CORR_STRENGTH = (0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, - 0.95) - -# TODO(dvij,martinstrobel): Set Celeb-A number of subgroups. -_CELEB_A_NUM_SUBGROUP = 2 - -RESNET_IMAGE_SIZE = 224 -CROP_PADDING = 32 - - -def register_dataset(name: str): - """Provides decorator to register functions that return dataset.""" - - def save(dataset_builder): - DATASET_REGISTRY[name] = dataset_builder - return dataset_builder - - return save - - -def get_dataset(name: str): - """Retrieves dataset based on name.""" - if name not in DATASET_REGISTRY: - raise ValueError( - f'Unknown dataset: {name}\nPossible choices: {DATASET_REGISTRY.keys()}') - return DATASET_REGISTRY[name] - - -@dataclasses.dataclass -class Dataloader: - num_subgroups: int # Number of subgroups in data. - subgroup_sizes: Dict[str, int] # Number of examples by subgroup. - train_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. - val_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. - train_ds: tf.data.Dataset # Dataset with all the train splits combined. - num_train_examples: Optional[int] = 0 # Number of training examples. - worst_group_label: Optional[int] = 2 # Label of the worst subgroup. - train_sample_ds: Optional[tf.data.Dataset] = None # Subsample of train set. - eval_ds: Optional[Dict[ - str, - tf.data.Dataset]] = None # Validation and any additional test datasets. - - -def get_subgroup_sizes(dataloader: tf.data.Dataset) -> Dict[str, int]: - """Gets the number examples of each subgroup.""" - subgroup_sizes = dict( - collections.Counter( - dataloader.map(lambda x: x['subgroup_label']).as_numpy_iterator() - ) - ) - return {str(key): val for key, val in subgroup_sizes.items()} - - -def upsample_subgroup( - dataset: tf.data.Dataset, - lambda_value: int = 60, - signal: str = 'subgroup_label', - subgroup_sizes: Optional[Dict[str, int]] = None, -) -> tf.data.Dataset: - """Creates dataset that has upsampled subgroup. - - Args: - dataset: Dataset to be transformed. - lambda_value: Number of times each example of the underrepresented group - should be repeated in dataset. - signal: String for the value that determines whether or not an example - belongs to an underrepresented group. - subgroup_sizes: Dictionary mapping subgroup index to size. - - Returns: - Transformed dataset. - """ - if signal != 'subgroup_label': - raise ValueError( - 'Upsampling with signals other than subgroup_label is not supported.' - ) - # In this case, we assume that the data has ground-truth subgroup labels. - # Identify the group that is smallest and upsample it. - if not subgroup_sizes: - raise ValueError( - 'When using ground-truth subgroup label as upsampling signal,' - ' dictionary of subgroup sizes must be available.' 
-
-
-def apply_batch(dataloader, batch_size):
-  """Apply batching to dataloader."""
-  dataloader.train_splits = [
-      data.batch(batch_size) for data in dataloader.train_splits
-  ]
-  dataloader.val_splits = [
-      data.batch(batch_size) for data in dataloader.val_splits
-  ]
-  num_splits = len(dataloader.train_splits)
-  train_ds = gather_data_splits(
-      list(range(num_splits)), dataloader.train_splits)
-  val_ds = gather_data_splits(list(range(num_splits)), dataloader.val_splits)
-  dataloader.train_ds = train_ds
-  dataloader.eval_ds['val'] = val_ds
-  for (k, v) in dataloader.eval_ds.items():
-    if k != 'val':
-      dataloader.eval_ds[k] = v.batch(batch_size)
-  return dataloader
-
-
-def gather_data_splits(
-    slice_idx: List[int],
-    dataset: Union[tf.data.Dataset, List[tf.data.Dataset]]) -> tf.data.Dataset:
-  """Gathers slices of a split dataset based on passed indices."""
-  data_slice = dataset[slice_idx[0]]
-  for idx in slice_idx[1:]:
-    data_slice = data_slice.concatenate(dataset[idx])
-  return data_slice
-
-
-def get_ids_from_dataset(dataset: tf.data.Dataset) -> List[str]:
-  """Gets example ids from dataset."""
-  ids_list = list(dataset.map(lambda x: x['example_id']).as_numpy_iterator())
-  if isinstance(ids_list[0], np.ndarray):
-    new_ids_list = []
-    for ids in ids_list:
-      new_ids_list += ids.tolist()
-    return new_ids_list
-  else:
-    return ids_list
- """ - ids = get_ids_from_dataset(dataloader.train_ds) - initial_sample_size = int(len(ids) * initial_sample_proportion) - np.random.seed(initial_sample_seed) - subset_ids = np.random.choice(ids, initial_sample_size, replace=False) - # ids_dir is populated by the sample_and_split_ids function above - tf.compat.v1.logging.info('Seed number %d', split_num + split_seed) - np.random.seed(split_num + split_seed) - ids_i = np.random.choice( - subset_ids, int(split_proportion * initial_sample_size), replace=False) - tf.compat.v1.logging.info('Subset size %d', len(ids_i)) - if not training: - ids_i = subset_ids[~np.isin(subset_ids, ids_i)] - keys = tf.convert_to_tensor(ids_i, dtype=tf.string) - values = tf.ones(shape=keys.shape, dtype=tf.int64) - init = tf.lookup.KeyValueTensorInitializer( - keys=keys, - values=values, - key_dtype=tf.string, - value_dtype=tf.int64) - return tf.lookup.StaticHashTable(init, default_value=0) - - -def filter_set( - dataloader: Dataloader, - initial_sample_proportion: float, - initial_sample_seed: int, - split_proportion: float, - split_id: int, - split_seed: int, - training: bool, -) -> tf.data.Dataset: - """Filters training set to create subsets of arbitrary size. - - First, a set of initial_sample_proportion is selected from which the different - training an validation sets are sampled according to split_proportion. The - indiviudal splits are used to train an ensemble of models where each model is - trained on an equally sized training set. - - Args: - dataloader: Dataloader for the unfilterd dataset. - initial_sample_proportion: Proportion of larger subset to initial dataset. - initial_sample_seed: Seed to select the larger subset (identical for all - splits.) - split_proportion: Proportion of split to larger subset. - split_id: Number of split. Used to set the sampling seed for the split - subset. - split_seed: Split seed second part of the sampling seed. - training: Whether to create a training set or a validation set. - - Returns: - A filtered dataset. - """ - filter_table = create_ids_table( - dataloader, - initial_sample_proportion=initial_sample_proportion, - initial_sample_seed=initial_sample_seed, - split_proportion=split_proportion, - split_num=split_id, - split_seed=split_seed, - training=training, - ) - return dataloader.train_ds.filter( - lambda datapoint: filter_table.lookup(datapoint['example_id']) == 1 - ) - - -class WaterbirdsDataset(tfds.core.GeneratorBasedBuilder): - """DatasetBuilder for Waterbirds dataset.""" - - VERSION = tfds.core.Version('1.0.0') - RELEASE_NOTES = { - '1.0.0': 'Initial release.', - } - - def __init__(self, - subgroup_ids: List[str], - subgroup_proportions: Optional[List[float]] = None, - train_dataset_size: int = _WATERBIRDS_TRAIN_SIZE, - source_data_dir: str = _WATERBIRDS_DATA_DIR, - include_train_sample: bool = True, - **kwargs): - super(WaterbirdsDataset, self).__init__(**kwargs) - self.subgroup_ids = subgroup_ids - self.train_dataset_size = train_dataset_size - # Path to original TFRecords to sample data from. - self.source_data_dir = source_data_dir - self.include_train_sample = include_train_sample - if subgroup_proportions: - self.subgroup_proportions = subgroup_proportions - else: - self.subgroup_proportions = [1.] 
-
-
-class WaterbirdsDataset(tfds.core.GeneratorBasedBuilder):
-  """DatasetBuilder for Waterbirds dataset."""
-
-  VERSION = tfds.core.Version('1.0.0')
-  RELEASE_NOTES = {
-      '1.0.0': 'Initial release.',
-  }
-
-  def __init__(self,
-               subgroup_ids: List[str],
-               subgroup_proportions: Optional[List[float]] = None,
-               train_dataset_size: int = _WATERBIRDS_TRAIN_SIZE,
-               source_data_dir: str = _WATERBIRDS_DATA_DIR,
-               include_train_sample: bool = True,
-               **kwargs):
-    super(WaterbirdsDataset, self).__init__(**kwargs)
-    self.subgroup_ids = subgroup_ids
-    self.train_dataset_size = train_dataset_size
-    # Path to original TFRecords to sample data from.
-    self.source_data_dir = source_data_dir
-    self.include_train_sample = include_train_sample
-    if subgroup_proportions:
-      self.subgroup_proportions = subgroup_proportions
-    else:
-      self.subgroup_proportions = [1.] * len(subgroup_ids)
-
-  def _info(self) -> tfds.core.DatasetInfo:
-    """Dataset metadata (homepage, citation,...)."""
-    return tfds.core.DatasetInfo(
-        builder=self,
-        features=tfds.features.FeaturesDict({
-            'example_id':
-                tfds.features.Text(),
-            'subgroup_id':
-                tfds.features.Text(),
-            'subgroup_label':
-                tfds.features.ClassLabel(num_classes=4),
-            'input_feature':
-                tfds.features.Image(
-                    shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3)),
-            'label':
-                tfds.features.ClassLabel(num_classes=2),
-            'place':
-                tfds.features.ClassLabel(num_classes=2),
-            'image_filename':
-                tfds.features.Text(),
-            'place_filename':
-                tfds.features.Text(),
-        }),
-    )
-
-  def _decode_and_center_crop(self, image_bytes: tf.Tensor):
-    """Crops to center of image with padding, then scales to RESNET_IMAGE_SIZE."""
-    shape = tf.io.extract_jpeg_shape(image_bytes)
-    image_height = shape[0]
-    image_width = shape[1]
-
-    padded_center_crop_size = tf.cast(
-        ((RESNET_IMAGE_SIZE / (RESNET_IMAGE_SIZE + CROP_PADDING)) *
-         tf.cast(tf.math.minimum(image_height, image_width), tf.float32)),
-        tf.int32)
-
-    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
-    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
-    crop_window = tf.stack([
-        offset_height, offset_width, padded_center_crop_size,
-        padded_center_crop_size
-    ])
-    image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
-    return image
-
-  def _preprocess_image(self, image_bytes: tf.Tensor) -> tf.Tensor:
-    """Preprocesses the given image for evaluation.
-
-    Args:
-      image_bytes: `Tensor` representing an image binary of arbitrary size.
-
-    Returns:
-      A preprocessed image `Tensor`.
-    """
-    image = self._decode_and_center_crop(image_bytes)
-    # No data augmentation, as in the JTT paper.
-    # image = tf.image.random_flip_left_right(image)
-    image = tf.image.resize([image], [RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE],
-                            method='nearest')[0]
-    return image
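A quick numeric check of the padded center-crop arithmetic in _decode_and_center_crop (plain Python; the 500x400 input size is illustrative):

# For a 500x400 source image with RESNET_IMAGE_SIZE=224, CROP_PADDING=32:
RESNET_IMAGE_SIZE, CROP_PADDING = 224, 32
image_height, image_width = 500, 400

scale = RESNET_IMAGE_SIZE / (RESNET_IMAGE_SIZE + CROP_PADDING)  # 0.875
padded_center_crop_size = int(scale * min(image_height, image_width))  # 350
offset_height = (image_height - padded_center_crop_size + 1) // 2  # 75
offset_width = (image_width - padded_center_crop_size + 1) // 2    # 25
# A 350x350 window at (75, 25) is cropped, then resized to 224x224.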
- """ - if tf.math.equal(label, place): - return label - else: - if tf.math.equal(label, 1): # and place == 0, so waterbird on land - return tf.constant(2, dtype=tf.int32) - else: - return tf.constant(3, dtype=tf.int32) - - def _dataset_parser(self, value): - """Parse a Waterbirds record from a serialized string Tensor.""" - keys_to_features = { - 'image/filename/raw': tf.io.FixedLenFeature([], tf.string, ''), - 'image/class/place': tf.io.FixedLenFeature([], tf.int64, -1), - 'image/encoded': tf.io.FixedLenFeature([], tf.string, ''), - 'image/filename/places': tf.io.FixedLenFeature([], tf.string, ''), - 'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1), - } - - parsed = tf.io.parse_single_example(value, keys_to_features) - - image = self._preprocess_image(image_bytes=parsed['image/encoded']) - label = tf.cast(parsed['image/class/label'], dtype=tf.int32) - place = tf.cast(parsed['image/class/place'], dtype=tf.int32) - image_filename = tf.cast(parsed['image/filename/raw'], dtype=tf.string) - place_filename = tf.cast(parsed['image/filename/places'], dtype=tf.string) - subgroup_id = tf.strings.join( - [tf.strings.as_string(label), - tf.strings.as_string(place)], - separator='_') - subgroup_label = self._get_subgroup_label(label, place) - - return image_filename, { - 'example_id': image_filename, - 'label': label, - 'place': place, - 'input_feature': image, - 'image_filename': image_filename, - 'place_filename': place_filename, - 'subgroup_id': subgroup_id, - 'subgroup_label': subgroup_label - } - - def _split_generators(self, dl_manager: tfds.download.DownloadManager): - """Download the data and define splits.""" - split_generators = { - 'train': - self._generate_examples( - os.path.join(self.source_data_dir, _WATERBIRDS_TRAIN_PATTERN), - is_training=True), - 'validation': - self._generate_examples( - os.path.join(self.source_data_dir, - _WATERBIRDS_VALIDATION_PATTERN)), - 'test': - self._generate_examples( - os.path.join(self.source_data_dir, _WATERBIRDS_TEST_PATTERN)), - } - - if self.include_train_sample: - split_generators['train_sample'] = self._generate_examples( - os.path.join(self.source_data_dir, _WATERBIRDS_TRAIN_SAMPLE_PATTERN)) - - return split_generators - - def _generate_examples(self, - file_pattern: str, - is_training: Optional[bool] = False - ) -> Iterator[Tuple[str, Dict[str, Any]]]: - """Generator of examples for each split.""" - dataset = tf.data.Dataset.list_files(file_pattern, shuffle=is_training) - - def _fetch_dataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) - return dataset - - # Reads the data from disk in parallel. - dataset = dataset.interleave( - _fetch_dataset, - cycle_length=16, - num_parallel_calls=tf.data.experimental.AUTOTUNE) - - # Parses and pre-processes the data in parallel. - dataset = dataset.map(self._dataset_parser, num_parallel_calls=2) - - # Prefetches overlaps in-feed with training. - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) - - if is_training: - options = tf.data.Options() - options.experimental_deterministic = False - dataset = dataset.with_options(options) - - # Prepare initial training set. - # Pre-computed dataset size or large number >= estimated dataset size. - dataset_size = self.train_dataset_size - dataset = dataset.shuffle(dataset_size) - sampled_datasets = [] - remaining_proportion = 1. 
-
-
-class Waterbirds10kDataset(WaterbirdsDataset):
-  """DatasetBuilder for Waterbirds10K dataset."""
-
-  VERSION = tfds.core.Version('1.0.0')
-  RELEASE_NOTES = {
-      '1.0.0': 'Initial release.',
-  }
-
-  def __init__(self,
-               subgroup_ids: List[str],
-               subgroup_proportions: Optional[List[float]] = None,
-               corr_strength: float = 0.95,
-               train_dataset_size: int = _WATERBIRDS10K_TRAIN_SIZE,
-               source_data_parent_dir: str = _WATERBIRDS10K_DATA_DIR,
-               include_train_sample: bool = False,
-               **kwargs):
-    if corr_strength not in _WATERBIRDS10K_SUPPORTED_CORR_STRENGTH:
-      raise ValueError(
-          f'corr_strength {corr_strength} not supported. '
-          f'Should be one of: {_WATERBIRDS10K_SUPPORTED_CORR_STRENGTH}')
-
-    # Makes the source data directory based on `corr_strength`.
-    # The final data directory should follow the format
-    # `{parent_dir}/corr_strength_{int(corr_strength * 100)}`.
-    corr_strength_name = str(int(corr_strength * 100))
-    source_data_folder_name = f'corr_strength_{corr_strength_name}'
-    source_data_dir = os.path.join(source_data_parent_dir,
-                                   source_data_folder_name)
-
-    if not tf.io.gfile.exists(source_data_dir):
-      raise ValueError(f'Required data dir `{source_data_dir}` does not exist.')
-    else:
-      tf.compat.v1.logging.info(f'Loading from `{source_data_dir}`.')
-
-    self.corr_strength = corr_strength
-    super().__init__(subgroup_ids, subgroup_proportions, train_dataset_size,
-                     source_data_dir, include_train_sample, **kwargs)
-
-
-@dataclasses.dataclass
-class SkaiDatasetConfig(tfds.core.BuilderConfig):
-  """Configuration for SKAI datasets.
-
-  Any of the attributes can be left blank if they don't exist.
-
-  Attributes:
-    name: Name of the dataset.
-    labeled_train_pattern: Pattern for labeled training examples tfrecords.
-    labeled_test_pattern: Pattern for labeled test examples tfrecords.
-    unlabeled_pattern: Pattern for unlabeled examples tfrecords.
-    use_post_disaster_only: Whether to use post-disaster imagery only rather
-      than full 6-channel stacked image input.
-    image_size: Width and height, in pixels, to which images are resized.
-    max_examples: If nonzero, maximum number of examples to read per split.
-    load_small_images: Whether to also load the small image variant.
-  """
-  labeled_train_pattern: str = ''
-  labeled_test_pattern: str = ''
-  unlabeled_pattern: str = ''
-  use_post_disaster_only: bool = False
-  image_size: int = RESNET_IMAGE_SIZE
-  max_examples: int = 0
-  load_small_images: bool = True
- """ - labeled_train_pattern: str = '' - labeled_test_pattern: str = '' - unlabeled_pattern: str = '' - use_post_disaster_only: bool = False - image_size: int = RESNET_IMAGE_SIZE - max_examples: int = 0 - load_small_images: bool = True - - -def _decode_and_resize_image( - image_bytes: tf.Tensor, size: int) -> tf.Tensor: - return tf.image.resize( - tf.io.decode_image( - image_bytes, - channels=3, - expand_animations=False, - dtype=tf.float32, - ), - [size, size], - ) - - -class SkaiDataset(tfds.core.GeneratorBasedBuilder): - """TFDS dataset for SKAI. - - Example usage: - import tensorflow_datasets.public_api as tfds - from skai import dataset - - ds = tfds.load('skai_dataset', builder_kwargs={ - 'config': SkaiDatasetConfig( - name='example', - labeled_train_pattern='gs://path/to/train_labeled_examples.tfrecord', - labeled_test_pattern='gs://path/to/test_labeled_examples.tfrecord', - unlabeled_pattern='gs://path/to/unlabeled_examples-*.tfrecord') - }) - labeled_train_dataset = ds['labeled_train'] - labeled_test_dataset = ds['labeled_test'] - unlabeled_test_dataset = ds['unlabeled'] - """ - - VERSION = tfds.core.Version('1.0.0') - - - def __init__(self, - subgroup_ids: Optional[List[str]] = None, - subgroup_proportions: Optional[List[float]] = None, - include_train_sample: bool = True, - **kwargs): - super(SkaiDataset, self).__init__(**kwargs) - self.subgroup_ids = subgroup_ids - # Path to original TFRecords to sample data from. - if self.subgroup_ids: - if subgroup_proportions: - self.subgroup_proportions = subgroup_proportions - else: - self.subgroup_proportions = [1.] * len(subgroup_ids) - else: - self.subgroup_proportions = None - self.include_train_sample = include_train_sample - - def _info(self): - # TODO(jihyeonlee): Change label and subgroup_label to - # tfds.features.ClassLabel. 
-
-  def _info(self):
-    # TODO(jihyeonlee): Change label and subgroup_label to
-    # tfds.features.ClassLabel.
-    num_channels = 3 if self.builder_config.use_post_disaster_only else 6
-    input_shape = (
-        self.builder_config.image_size,
-        self.builder_config.image_size,
-        num_channels,
-    )
-    if self.builder_config.load_small_images:
-      input_type = tfds.features.FeaturesDict({
-          'large_image': tfds.features.Tensor(
-              shape=input_shape, dtype=tf.float32
-          ),
-          'small_image': tfds.features.Tensor(
-              shape=input_shape, dtype=tf.float32
-          ),
-      })
-    else:
-      input_type = tfds.features.FeaturesDict({
-          'large_image': tfds.features.Tensor(
-              shape=input_shape, dtype=tf.float32
-          ),
-      })
-    return tfds.core.DatasetInfo(
-        builder=self,
-        description='Skai',
-        features=tfds.features.FeaturesDict({
-            'input_feature': input_type,
-            'example_id': tfds.features.Text(),
-            'coordinates': tfds.features.Tensor(shape=(2,), dtype=tf.float32),
-            'label': tfds.features.Tensor(shape=(), dtype=tf.int64),
-            'string_label': tfds.features.Text(),
-            'subgroup_label': tfds.features.Tensor(shape=(), dtype=tf.int64),
-        }),
-    )
-
-  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
-    splits = {}
-    if self.builder_config.labeled_train_pattern:
-      splits['labeled_train'] = self._generate_examples(
-          self.builder_config.labeled_train_pattern
-      )
-    if self.builder_config.labeled_test_pattern:
-      splits['labeled_test'] = self._generate_examples(
-          self.builder_config.labeled_test_pattern
-      )
-    if self.builder_config.unlabeled_pattern:
-      splits['unlabeled'] = self._generate_examples(
-          self.builder_config.unlabeled_pattern
-      )
-    return splits
-
-  def _decode_record(self, record_bytes):
-    example = tf.io.parse_single_example(
-        record_bytes,
-        {
-            'coordinates': tf.io.FixedLenFeature([2], dtype=tf.float32),
-            'encoded_coordinates': tf.io.FixedLenFeature([], dtype=tf.string),
-            'example_id': tf.io.FixedLenFeature([], dtype=tf.string),
-            'pre_image_png_large': tf.io.FixedLenFeature([], dtype=tf.string),
-            'pre_image_png': tf.io.FixedLenFeature([], dtype=tf.string),
-            'post_image_png_large': tf.io.FixedLenFeature([], dtype=tf.string),
-            'post_image_png': tf.io.FixedLenFeature([], dtype=tf.string),
-            'label': tf.io.FixedLenFeature([], dtype=tf.float32),
-            'string_label': tf.io.FixedLenFeature(
-                [], dtype=tf.string, default_value=''
-            ),
-        },
-    )
-
-    features = {
-        'input_feature': {}
-    }
-    large_image_concat = _decode_and_resize_image(
-        example['post_image_png_large'], self.builder_config.image_size
-    )
-    small_image_concat = _decode_and_resize_image(
-        example['post_image_png'], self.builder_config.image_size
-    )
-
-    if not self.builder_config.use_post_disaster_only:
-      before_image = _decode_and_resize_image(
-          example['pre_image_png_large'], self.builder_config.image_size
-      )
-      before_image_small = _decode_and_resize_image(
-          example['pre_image_png'], self.builder_config.image_size
-      )
-      large_image_concat = tf.concat(
-          [before_image, large_image_concat], axis=-1
-      )
-      small_image_concat = tf.concat(
-          [before_image_small, small_image_concat], axis=-1
-      )
-    features['input_feature']['large_image'] = large_image_concat
-    if self.builder_config.load_small_images:
-      features['input_feature']['small_image'] = small_image_concat
-    features['label'] = tf.cast(example['label'], tf.int64)
-    features['example_id'] = example['example_id']
-    features['subgroup_label'] = features['label']
-    features['coordinates'] = example['coordinates']
-    features['string_label'] = example['string_label']
-    return features
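For reference, a sketch of a serialized record that _decode_record above would accept; the field names come from the parser, while the byte payloads and values are placeholders:

import tensorflow as tf

# Placeholder PNG bytes; a real pipeline would serialize actual images.
fake_png = b'\x89PNG...'

def _bytes(v): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
def _floats(v): return tf.train.Feature(float_list=tf.train.FloatList(value=v))

example = tf.train.Example(features=tf.train.Features(feature={
    'coordinates': _floats([12.34, 56.78]),
    'encoded_coordinates': _bytes(b'12.34_56.78'),
    'example_id': _bytes(b'example_0'),
    'pre_image_png_large': _bytes(fake_png),
    'pre_image_png': _bytes(fake_png),
    'post_image_png_large': _bytes(fake_png),
    'post_image_png': _bytes(fake_png),
    'label': _floats([1.0]),
    'string_label': _bytes(b'destroyed'),
}))
serialized = example.SerializeToString()  # what _decode_record parses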
tf.data.TFRecordDataset(paths).map(
-        self._decode_record, num_parallel_calls=tf.data.AUTOTUNE)
-    if self.builder_config.max_examples:
-      ds = ds.take(self.builder_config.max_examples)
-    for features in ds.as_numpy_iterator():
-      yield uuid.uuid4().hex, features
-
-
-@register_dataset('waterbirds')
-def get_waterbirds_dataset(num_splits: int,
-                           initial_sample_proportion: float,
-                           subgroup_ids: List[str],
-                           subgroup_proportions: List[float],
-                           tfds_dataset_name: str = 'waterbirds_dataset',
-                           include_train_sample: bool = True,
-                           data_dir: str = DATA_DIR,
-                           upsampling_lambda: int = 1,
-                           upsampling_signal: str = 'subgroup_label',
-                           **additional_builder_kwargs) -> Dataloader:
-  """Returns datasets for training, validation, and possibly test sets.
-
-  Args:
-    num_splits: Integer for number of slices of the dataset.
-    initial_sample_proportion: Float for proportion of entire training dataset
-      to sample initially before active sampling begins.
-    subgroup_ids: List of strings of IDs indicating subgroups.
-    subgroup_proportions: List of floats indicating proportion that each
-      subgroup should take in initial training dataset.
-    tfds_dataset_name: The name of the TFDS dataset to load from.
-    include_train_sample: Whether to include the `train_sample` split.
-    data_dir: Default data directory to store the sampled waterbirds data.
-    upsampling_lambda: Number of times subgroup examples should be repeated.
-    upsampling_signal: Signal to use to determine subgroup to upsample.
-    **additional_builder_kwargs: Additional keyword arguments to data builder.
-
-  Returns:
-    A Dataloader containing the split training data, split validation data,
-    the combined training dataset, and a dictionary mapping evaluation dataset
-    names to their respective combined datasets.
-  """
-  split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
-  reduced_dataset_sz = int(100 * initial_sample_proportion)
-  builder_kwargs = {
-      'subgroup_ids': subgroup_ids,
-      'subgroup_proportions': subgroup_proportions,
-      'include_train_sample': include_train_sample,
-      **additional_builder_kwargs
-  }
-  val_splits = tfds.load(
-      tfds_dataset_name,
-      split=[
-          f'validation[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_dataset_sz, split_size_in_pct)
-      ],
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False)
-
-  train_splits = tfds.load(
-      tfds_dataset_name,
-      split=[
-          f'train[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_dataset_sz, split_size_in_pct)
-      ],
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False)
-
-  test_ds = tfds.load(
-      tfds_dataset_name,
-      split='test',
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False,
-      with_info=False)
-
-  train_sample = ()
-  if include_train_sample:
-    train_sample = tfds.load(
-        tfds_dataset_name,
-        split='train_sample',
-        data_dir=data_dir,
-        builder_kwargs=builder_kwargs,
-        try_gcs=False,
-        with_info=False)
-
-  train_ds = gather_data_splits(list(range(num_splits)), train_splits)
-  val_ds = gather_data_splits(list(range(num_splits)), val_splits)
-  eval_datasets = {
-      'val': val_ds,
-      'test': test_ds,
-  }
-  subgroup_sizes = get_subgroup_sizes(train_ds)
-  if upsampling_lambda > 1:
-    train_ds = upsample_subgroup(
-        train_ds, upsampling_lambda, upsampling_signal, subgroup_sizes
-    )
-
-  return Dataloader(
-      _WATERBIRDS_NUM_SUBGROUP,
-      subgroup_sizes,
-      train_splits,
-      val_splits,
-      train_ds,
-      num_train_examples=_WATERBIRDS_TRAIN_SIZE,
-      worst_group_label=2,  # 1_0, waterbirds on land.
-      train_sample_ds=train_sample,
-      eval_ds=eval_datasets)
-
-
-@register_dataset('waterbirds10k')
-def get_waterbirds10k_dataset(
-    num_splits: int,
-    initial_sample_proportion: float,
-    subgroup_ids: List[str],
-    subgroup_proportions: List[float],
-    corr_strength: float = 0.95,
-    data_dir: str = DATA_DIR,
-    upsampling_lambda: int = 1,
-    upsampling_signal: str = 'subgroup_label',
-) -> Dataloader:
-  """Returns datasets for Waterbirds 10K."""
-  # Create a unique `waterbirds10k` directory for each correlation strength.
-  data_folder_name = int(corr_strength * 100)
-  data_folder_name = f'waterbirds10k_corr_strength_{data_folder_name}'
-  data_dir = os.path.join(data_dir, data_folder_name)
-
-  return get_waterbirds_dataset(
-      num_splits,
-      initial_sample_proportion,
-      subgroup_ids,
-      subgroup_proportions,
-      tfds_dataset_name='waterbirds10k_dataset',
-      include_train_sample=False,
-      corr_strength=corr_strength,
-      data_dir=data_dir,
-      upsampling_lambda=upsampling_lambda,
-      upsampling_signal=upsampling_signal)
-
-
-@register_dataset('celeb_a')
-def get_celeba_dataset(
-    num_splits: int,
-    initial_sample_proportion: float,
-    subgroup_ids: List[str],
-    subgroup_proportions: List[float],
-    upsampling_lambda: int = 1,
-    upsampling_signal: str = 'subgroup_label',
-) -> Dataloader:
-  """Returns datasets for training, validation, and possibly test sets.
-
-  Args:
-    num_splits: Integer for number of slices of the dataset.
-    initial_sample_proportion: Float for proportion of entire training dataset
-      to sample initially before active sampling begins.
-    subgroup_ids: List of strings of IDs indicating subgroups.
-    subgroup_proportions: List of floats indicating proportion that each
-      subgroup should take in initial training dataset.
-    upsampling_lambda: Number of times subgroup examples should be repeated.
-    upsampling_signal: Signal to use to determine subgroup to upsample.
-
-  Returns:
-    A Dataloader containing the split training data, split validation data,
-    the combined training dataset, and a dictionary mapping evaluation dataset
-    names to their respective combined datasets.
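-
-  Example (an illustrative sketch; assumes the TFDS `celeb_a` data and the
-  custom `train_sample` split are available under DATA_DIR):
-
-    dataloader = get_celeba_dataset(
-        num_splits=5,
-        initial_sample_proportion=0.5,
-        subgroup_ids=(),
-        subgroup_proportions=())
-    val_batch = next(
-        dataloader.eval_ds['val'].batch(64).as_numpy_iterator())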
- """ - del subgroup_proportions, subgroup_ids - read_config = tfds.ReadConfig() - read_config.add_tfds_id = True # Set `True` to return the 'tfds_id' key - split_size_in_pct = int(100 * initial_sample_proportion / num_splits) - reduced_dataset_sz = int(100 * initial_sample_proportion) - train_splits = tfds.load( - 'celeb_a', - read_config=read_config, - split=[ - f'train[:{k}%]+train[{k+split_size_in_pct}%:]' - for k in range(0, reduced_dataset_sz, split_size_in_pct) - ], - data_dir=DATA_DIR, - try_gcs=False, - as_supervised=True) - val_splits = tfds.load( - 'celeb_a', - read_config=read_config, - split=[ - f'validation[{k}%:{k+split_size_in_pct}%]' - for k in range(0, reduced_dataset_sz, split_size_in_pct) - ], - data_dir=DATA_DIR, - try_gcs=False, - as_supervised=True) - train_sample = tfds.load( - 'celeb_a', - split='train_sample', - data_dir=DATA_DIR, - try_gcs=False, - as_supervised=True, - with_info=False) - - test_ds = tfds.load( - 'celeb_a', - split='test', - data_dir=DATA_DIR, - try_gcs=False, - as_supervised=True, - with_info=False) - - train_ds = gather_data_splits(list(range(num_splits)), train_splits) - val_ds = gather_data_splits(list(range(num_splits)), val_splits) - eval_datasets = { - 'val': val_ds, - 'test': test_ds, - } - subgroup_sizes = get_subgroup_sizes(train_ds) - if upsampling_lambda > 1: - train_ds = upsample_subgroup( - train_ds, upsampling_lambda, upsampling_signal, subgroup_sizes - ) - - return Dataloader( - _CELEB_A_NUM_SUBGROUP, - subgroup_sizes, - train_splits, - val_splits, - train_ds, - train_sample_ds=train_sample, - eval_ds=eval_datasets) - - -@register_dataset('skai') -def get_skai_dataset(num_splits: int, - initial_sample_proportion: float, - subgroup_ids: List[str], - subgroup_proportions: List[float], - tfds_dataset_name: str = 'skai_dataset', - data_dir: str = DATA_DIR, - include_train_sample: Optional[bool] = False, - labeled_train_pattern: str = '', - unlabeled_train_pattern: str = '', - validation_pattern: str = '', - use_post_disaster_only: Optional[bool] = False, - upsampling_lambda: int = 1, - upsampling_signal: str = 'subgroup_label', - load_small_images: bool = False, - **additional_builder_kwargs) -> Dataloader: - """Returns datasets for training, validation, and possibly test sets. - - Args: - num_splits: Integer for number of slices of the dataset. - initial_sample_proportion: Float for proportion of entire training dataset - to sample initially before active sampling begins. - subgroup_ids: List of strings of IDs indicating subgroups. - subgroup_proportions: List of floats indicating proportion that each - subgroup should take in initial training dataset. - tfds_dataset_name: The name of the tfd dataset to load from. - data_dir: Default data directory to store the sampled data. - include_train_sample: Whether to include the `train_sample` split. - labeled_train_pattern: File pattern for labeled training data. - unlabeled_train_pattern: File pattern for unlabeled training data. - validation_pattern: File pattern for validation data. - use_post_disaster_only: Whether to use post-disaster imagery only rather - than full 6-channel stacked image input. - upsampling_lambda: Number of times subgroup examples should be repeated. - upsampling_signal: Signal to use to determine subgroup to upsample. - load_small_images: A flag controls loading small images or not. - **additional_builder_kwargs: Additional keyword arguments to data builder. 
-
-  Returns:
-    A Dataloader containing the split training data, split validation data,
-    the combined training dataset, and a dictionary mapping evaluation dataset
-    names to their respective combined datasets.
-  """
-
-  builder_kwargs = {
-      'subgroup_ids': subgroup_ids,
-      'subgroup_proportions': subgroup_proportions,
-      'include_train_sample': include_train_sample,
-      **additional_builder_kwargs
-  }
-  if '/' not in tfds_dataset_name:
-    # No named config variant specified, so provide the config explicitly.
-    # pylint: disable=unexpected-keyword-arg
-    builder_kwargs['config'] = SkaiDatasetConfig(
-        name='skai_dataset',
-        labeled_train_pattern=labeled_train_pattern,
-        labeled_test_pattern=validation_pattern,
-        unlabeled_pattern=unlabeled_train_pattern,
-        use_post_disaster_only=use_post_disaster_only,
-        load_small_images=load_small_images,
-    )
-    # pylint: enable=unexpected-keyword-arg
-  split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
-  reduced_dataset_sz = int(100 * initial_sample_proportion)
-
-  val_splits = tfds.load(
-      tfds_dataset_name,
-      split=[
-          f'labeled_test[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_dataset_sz, split_size_in_pct)
-      ],
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False)
-
-  train_splits = tfds.load(
-      tfds_dataset_name,
-      split=[
-          f'labeled_train[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_dataset_sz, split_size_in_pct)
-      ],
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False)
-
-  # TODO(jihyeonlee): Utilize unlabeled data.
-
-  # No separate test set, so use validation for now.
-  test_ds = tfds.load(
-      tfds_dataset_name,
-      split='labeled_test',
-      data_dir=data_dir,
-      builder_kwargs=builder_kwargs,
-      try_gcs=False,
-      with_info=False)
-
-  train_ds = gather_data_splits(list(range(num_splits)), train_splits)
-  val_ds = gather_data_splits(list(range(num_splits)), val_splits)
-  eval_datasets = {
-      'val': val_ds,
-      'test': test_ds,
-  }
-  subgroup_sizes = get_subgroup_sizes(train_ds)
-  if upsampling_lambda > 1:
-    train_ds = upsample_subgroup(
-        train_ds, upsampling_lambda, upsampling_signal, subgroup_sizes
-    )
-  return Dataloader(
-      2,
-      subgroup_sizes,
-      train_splits,
-      val_splits,
-      train_ds,
-      train_sample_ds=None,
-      eval_ds=eval_datasets)
diff --git a/experimental/shoshin/data_loader_test.py b/experimental/shoshin/data_loader_test.py
deleted file mode 100644
index bd2486b4d..000000000
--- a/experimental/shoshin/data_loader_test.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Tests for data loaders.""" - -import os -import tempfile -from typing import List -import numpy as np -import tensorflow as tf - -import data # local file import from experimental.shoshin -from google3.testing.pybase import googletest - -RESNET_IMAGE_SIZE = 224 - - -def _make_temp_dir() -> str: - return tempfile.mkdtemp(dir=os.environ.get('TEST_TMPDIR')) - - -def _make_serialized_image(size: int, pixel_value: int) -> bytes: - image = np.ones((size, size, 3), dtype=np.uint8) * pixel_value - return tf.io.encode_png(image).numpy() - - -def _make_example( - example_id: str, - longitude: float, - latitude: float, - encoded_coordinates: str, - label: float, - string_label: float, - patch_size: int, - large_patch_size: int, - before_pixel_value: int, - after_pixel_value: int, -) -> tf.train.Example: - example = tf.train.Example() - example.features.feature['example_id'].bytes_list.value.append( - example_id.encode() - ) - example.features.feature['coordinates'].float_list.value.extend( - (longitude, latitude) - ) - example.features.feature['encoded_coordinates'].bytes_list.value.append( - encoded_coordinates.encode() - ) - example.features.feature['label'].float_list.value.append(label) - example.features.feature['string_label'].bytes_list.value.append( - string_label.encode() - ) - example.features.feature['pre_image_png'].bytes_list.value.append( - _make_serialized_image(patch_size, before_pixel_value) - ) - example.features.feature['post_image_png'].bytes_list.value.append( - _make_serialized_image(patch_size, after_pixel_value) - ) - example.features.feature['pre_image_png_large'].bytes_list.value.append( - _make_serialized_image(large_patch_size, before_pixel_value) - ) - example.features.feature['post_image_png_large'].bytes_list.value.append( - _make_serialized_image(large_patch_size, after_pixel_value) - ) - return example - - -def _write_tfrecord(examples: List[tf.train.Example], path: str) -> None: - with tf.io.TFRecordWriter(path) as file_writer: - for example in examples: - file_writer.write(example.SerializeToString()) - - -def _create_test_data(): - examples_dir = _make_temp_dir() - labeled_train_path = os.path.join( - examples_dir, 'train_labeled_examples.tfrecord') - labeled_test_path = os.path.join( - examples_dir, 'test_labeled_examples.tfrecord') - unlabeled_path = os.path.join( - examples_dir, 'unlabeled_examples.tfrecord') - - _write_tfrecord([ - _make_example('1st', 0, 0, 'A0', 0, 'no_damage', 64, 256, 0, 255), - _make_example('2nd', 0, 1, 'A1', 0, 'no_damage', 64, 256, 0, 255), - _make_example('3rd', 0, 2, 'A2', 1, 'major_damage', 64, 256, 0, 255), - ], labeled_train_path) - - _write_tfrecord([ - _make_example('4th', 1, 0, 'B0', 0, 'no_damage', 64, 256, 0, 255), - ], labeled_test_path) - - _write_tfrecord([ - _make_example('5th', 2, 0, 'C0', -1, 'bad_example', 64, 256, 0, 255), - _make_example('6th', 2, 1, 'C1', -1, 'bad_example', 64, 256, 0, 255), - _make_example('7th', 2, 2, 'C2', -1, 'bad_example', 64, 256, 0, 255), - _make_example('8th', 2, 3, 'C3', -1, 'bad_example', 64, 256, 0, 255), - ], unlabeled_path) - - return labeled_train_path, labeled_test_path, unlabeled_path - - -class DataLoaderTest(googletest.TestCase): - def setUp(self): - super().setUp() - - labeled_train_path, labeled_test_path, unlabeled_path = _create_test_data() - self.labeled_train_path = labeled_train_path - self.labeled_test_path = labeled_test_path - self.unlabeled_path = unlabeled_path - - def test_get_skai_dataset_post_only(self): - dataset_builder = data.get_dataset('skai') - - kwargs = 
{ - 'labeled_train_pattern': self.labeled_train_path, - 'unlabeled_train_pattern': self.unlabeled_path, - 'validation_pattern': self.labeled_test_path, - 'use_post_disaster_only': True, - 'data_dir': _make_temp_dir(), - } - - dataloader = dataset_builder( - 1, - initial_sample_proportion=1, - subgroup_ids=(), - subgroup_proportions=(), - **kwargs - ) - ds = dataloader.train_ds - features = next(ds.as_numpy_iterator()) - self.assertIn('input_feature', features) - self.assertIn('large_image', features['input_feature']) - input_feature = features['input_feature']['large_image'] - self.assertEqual( - input_feature.shape, - (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3), - ) - self.assertEqual(input_feature.dtype, np.float32) - np.testing.assert_equal(input_feature, 1.0) - - def test_get_skai_dataset_pre_post(self): - dataset_builder = data.get_dataset('skai') - - kwargs = { - 'labeled_train_pattern': self.labeled_train_path, - 'unlabeled_train_pattern': self.unlabeled_path, - 'validation_pattern': self.labeled_test_path, - 'use_post_disaster_only': False, - 'data_dir': _make_temp_dir(), - } - - dataloader = dataset_builder( - 1, - initial_sample_proportion=1, - subgroup_ids=(), - subgroup_proportions=(), - **kwargs - ) - ds = dataloader.train_ds - features = next(ds.as_numpy_iterator()) - self.assertIn('input_feature', features) - self.assertIn('large_image', features['input_feature']) - input_feature = features['input_feature']['large_image'] - self.assertEqual( - input_feature.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) - ) - self.assertEqual(input_feature.dtype, np.float32) - np.testing.assert_equal(input_feature[:, :, :3], 0.0) - np.testing.assert_equal(input_feature[:, :, 3:], 1.0) - - def test_get_skai_dataset_small_images(self): - dataset_builder = data.get_dataset('skai') - - kwargs = { - 'labeled_train_pattern': self.labeled_train_path, - 'unlabeled_train_pattern': self.unlabeled_path, - 'validation_pattern': self.labeled_test_path, - 'use_post_disaster_only': False, - 'data_dir': _make_temp_dir(), - 'load_small_images': True, - } - - dataloader = dataset_builder( - 1, - initial_sample_proportion=1, - subgroup_ids=(), - subgroup_proportions=(), - **kwargs - ) - ds = dataloader.train_ds - features = next(ds.as_numpy_iterator()) - self.assertIn('input_feature', features) - self.assertIn('small_image', features['input_feature']) - self.assertIn('large_image', features['input_feature']) - small_image = features['input_feature']['small_image'] - large_image = features['input_feature']['large_image'] - self.assertEqual( - small_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) - ) - self.assertEqual(small_image.dtype, np.float32) - np.testing.assert_equal(small_image[:, :, :3], 0.0) - np.testing.assert_equal(small_image[:, :, 3:], 1.0) - - self.assertEqual( - large_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) - ) - self.assertEqual(small_image.dtype, np.float32) - np.testing.assert_equal(large_image[:, :, :3], 0.0) - np.testing.assert_equal(large_image[:, :, 3:], 1.0) - - def test_upsample_subgroup(self): - dataset_builder = data.get_dataset('skai') - - kwargs = { - 'labeled_train_pattern': self.labeled_train_path, - 'unlabeled_train_pattern': self.unlabeled_path, - 'validation_pattern': self.labeled_test_path, - 'use_post_disaster_only': False, - 'data_dir': _make_temp_dir(), - } - - dataloader = dataset_builder( - 1, - initial_sample_proportion=1, - subgroup_ids=(), - subgroup_proportions=(), - **kwargs) - ds = dataloader.train_ds - subgroup_sizes = 
data.get_subgroup_sizes(ds) - self.assertEqual(subgroup_sizes['0'], 2) - self.assertEqual(subgroup_sizes['1'], 1) - lambda_value = 10 - upsampled_ds = data.upsample_subgroup( - ds, lambda_value, 'subgroup_label', subgroup_sizes - ) - self.assertLen( - list( - upsampled_ds.filter( - lambda x: tf.math.equal(x['subgroup_label'], 0) - ).as_numpy_iterator() - ), - 2, - ) - self.assertLen( - list( - upsampled_ds.filter( - lambda x: tf.math.equal(x['subgroup_label'], 1) - ).as_numpy_iterator() - ), - 1 * lambda_value, - ) - - -if __name__ == '__main__': - googletest.main() diff --git a/experimental/shoshin/data_test.py b/experimental/shoshin/data_test.py deleted file mode 100644 index a7bcb1eb9..000000000 --- a/experimental/shoshin/data_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for data sets.""" - -import os -import tempfile -from typing import List -import numpy as np -import tensorflow as tf -import tensorflow_datasets as tfds - -import data # local file import from experimental.shoshin - - -def _make_temp_dir() -> str: - return tempfile.mkdtemp(dir=os.environ.get('TEST_TMPDIR')) - - -def _make_serialized_image(size: int) -> bytes: - image = np.random.randint(0, 255, size=(size, size, 3), dtype=np.uint8) - return tf.io.encode_png(image).numpy() - - -def _make_example( - example_id: str, - longitude: float, - latitude: float, - encoded_coordinates: str, - label: float, - string_label: float, - patch_size: int, - large_patch_size: int, -) -> tf.train.Example: - example = tf.train.Example() - example.features.feature['example_id'].bytes_list.value.append( - example_id.encode() - ) - example.features.feature['coordinates'].float_list.value.extend( - (longitude, latitude) - ) - example.features.feature['encoded_coordinates'].bytes_list.value.append( - encoded_coordinates.encode() - ) - example.features.feature['label'].float_list.value.append(label) - example.features.feature['string_label'].bytes_list.value.append( - string_label.encode() - ) - example.features.feature['pre_image_png'].bytes_list.value.append( - _make_serialized_image(patch_size) - ) - example.features.feature['post_image_png'].bytes_list.value.append( - _make_serialized_image(patch_size) - ) - example.features.feature['pre_image_png_large'].bytes_list.value.append( - _make_serialized_image(large_patch_size) - ) - example.features.feature['post_image_png_large'].bytes_list.value.append( - _make_serialized_image(large_patch_size) - ) - return example - - -def _write_tfrecord(examples: List[tf.train.Example], path: str) -> None: - with tf.io.TFRecordWriter(path) as file_writer: - for example in examples: - file_writer.write(example.SerializeToString()) - - -def _create_test_data(): - examples_dir = _make_temp_dir() - labeled_train_path = os.path.join( - examples_dir, 'train_labeled_examples.tfrecord') - labeled_test_path = os.path.join( - examples_dir, 'test_labeled_examples.tfrecord') - unlabeled_path = 
os.path.join( - examples_dir, 'unlabeled_examples.tfrecord') - - _write_tfrecord([ - _make_example('1st', 0, 0, 'A0', 0, 'no_damage', 64, 256), - _make_example('2nd', 0, 1, 'A1', 0, 'no_damage', 64, 256), - _make_example('3rd', 0, 2, 'A2', 1, 'major_damage', 64, 256), - ], labeled_train_path) - - _write_tfrecord([ - _make_example('4th', 1, 0, 'B0', 0, 'no_damage', 64, 256), - ], labeled_test_path) - - _write_tfrecord([ - _make_example('5th', 2, 0, 'C0', -1, 'bad_example', 64, 256), - _make_example('6th', 2, 1, 'C1', -1, 'bad_example', 64, 256), - _make_example('7th', 2, 2, 'C2', -1, 'bad_example', 64, 256), - _make_example('8th', 2, 3, 'C3', -1, 'bad_example', 64, 256), - ], unlabeled_path) - - return labeled_train_path, labeled_test_path, unlabeled_path - - -class SkaiDatasetTest(tfds.testing.DatasetBuilderTestCase): - """Tests for Skai dataset.""" - - DATASET_CLASS = data.SkaiDataset - SPLITS = { - 'labeled_train': 3, - 'labeled_test': 1, - 'unlabeled': 4 - } - EXAMPLE_DIR = _make_temp_dir() - BUILDER_CONFIG_NAMES_TO_TEST = ['test_config'] - SKIP_TF1_GRAPH_MODE = True - - @classmethod - def setUpClass(cls): - super().setUpClass() - - labeled_train_path, labeled_test_path, unlabeled_path = _create_test_data() - - cls.DATASET_CLASS.BUILDER_CONFIGS = [ - data.SkaiDatasetConfig( - name='test_config', - labeled_train_pattern=labeled_train_path, - labeled_test_pattern=labeled_test_path, - unlabeled_pattern=unlabeled_path) - ] - - -if __name__ == '__main__': - tfds.testing.test_main() diff --git a/experimental/shoshin/evaluate_model_lib.py b/experimental/shoshin/evaluate_model_lib.py deleted file mode 100644 index 871cf1361..000000000 --- a/experimental/shoshin/evaluate_model_lib.py +++ /dev/null @@ -1,126 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Library for evaluating active sampling. 
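-
-The helpers below join the bias and prediction tables written during training
-with per-example subgroup labels from the dataset, so that subgroup
-representation and model quality can be tracked across sampling rounds.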
-""" - -import os -from typing import Mapping - -import numpy as np -import pandas as pd -import tensorflow as tf -import data # local file import from experimental.shoshin - - -def merge_subgroup_labels( - ds: tf.data.Dataset, - table: pd.DataFrame, - batch_size: int, -): - """Merge table with subroup labels from ds.""" - ids = np.concatenate(list( - ds.map(lambda example: example['example_id']).batch( - batch_size).as_numpy_iterator())).tolist() - ids = list(map(lambda x: x.decode('UTF-8'), ids)) - subgroup_labels = np.concatenate(list( - ds.map(lambda example: example['subgroup_label']).batch( - batch_size).as_numpy_iterator())).tolist() - labels = np.concatenate(list( - ds.map(lambda example: example['label']).batch( - batch_size).as_numpy_iterator())).tolist() - df_a = pd.DataFrame({ - 'example_id': ids, 'subgroup_label': subgroup_labels, - 'label': labels}) - table = table[table['example_id'].isin(ids)] - return pd.merge(table, df_a, on=['example_id']) - - -def _process_table(table: pd.DataFrame, prediction: bool): - """Modify table to have cleaned up example ids and predictions.""" - table['example_id'] = table['example_id'].map( - lambda x: eval(x).decode('UTF-8')) # pylint:disable=eval-used - if prediction: - prediction_label_cols = filter(lambda x: 'label' in x, table.columns) - prediction_bias_cols = filter(lambda x: 'bias' in x, table.columns) - table['bias'] = table[prediction_bias_cols].mean(axis=1) - table['label_prediction'] = table[prediction_label_cols].mean(axis=1) - return table - - -def evaluate_active_sampling( - num_rounds: int, - output_dir: str, - dataloader: data.Dataloader, - batch_size: int, - num_subgroups: int, - ) -> pd.DataFrame: - """Evaluates model for subgroup representation vs number of rounds.""" - round_idx = [] - subgroup_ids = [] - num_samples = [] - prob_representation = [] - for idx in range(num_rounds): - ds = dataloader.train_ds - bias_table = pd.read_csv( - os.path.join( - os.path.join(output_dir, f'round_{idx}'), 'bias_table.csv')) - predictions_merge = merge_subgroup_labels(ds, bias_table, batch_size) - for subgroup_id in range(num_subgroups): - prob_i = (predictions_merge['subgroup_label'] - == subgroup_id).sum() / len(predictions_merge) - round_idx.append(idx) - subgroup_ids.append(subgroup_id) - num_samples.append(len(predictions_merge)) - prob_representation.append(prob_i) - return pd.DataFrame({ - 'num_samples': num_samples, - 'prob_representation': prob_representation, - 'round_idx': round_idx, - 'subgroup_ids': subgroup_ids, - }) - - -def evaluate_model( - round_idx: int, - output_dir: str, - dataloader: data.Dataloader, - batch_size: int, - ) -> Mapping[str, pd.DataFrame]: - """Evaluates model for subgroup representation vs number of rounds.""" - bias_table = pd.read_csv( - os.path.join( - os.path.join(output_dir, f'round_{round_idx}'), 'bias_table.csv')) - bias_table = _process_table(bias_table, False) - predictions_table = pd.read_csv( - os.path.join( - os.path.join(output_dir, f'round_{round_idx}'), - 'predictions_table.csv')) - predictions_table = _process_table(predictions_table, True) - predictions_merge = {} - predictions_merge['train_bias'] = merge_subgroup_labels( - dataloader.train_ds, bias_table, batch_size) - predictions_merge['train_predictions'] = merge_subgroup_labels( - dataloader.train_ds, predictions_table, batch_size) - for (ds_name, ds) in dataloader.eval_ds.items(): - predictions_table = _process_table(pd.read_csv( - os.path.join( - os.path.join(output_dir, f'round_{round_idx}'), - 
f'predictions_table_{ds_name}.csv')), True) - predictions_merge[f'{ds_name}_predictions'] = merge_subgroup_labels( - ds, predictions_table, batch_size) - predictions_merge[f'{ds_name}_bias'] = merge_subgroup_labels( - ds, bias_table, batch_size) - return predictions_merge diff --git a/experimental/shoshin/generate_bias_table.py b/experimental/shoshin/generate_bias_table.py deleted file mode 100644 index 6dc7da373..000000000 --- a/experimental/shoshin/generate_bias_table.py +++ /dev/null @@ -1,146 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Binary executable for generating bias label table. - -This file serves as a binary to calculate bias values and create a lookup table -that maps from example ID to bias label. - - -# pylint: enable=line-too-long - -Note: In output_dir, models trained on different splits of data must already -exist and be present in directory. -""" - -import os - -from absl import app -from absl import flags -from ml_collections import config_flags -import data # local file import from experimental.shoshin -import generate_bias_table_lib # local file import from experimental.shoshin -import models # local file import from experimental.shoshin -import sampling_policies # local file import from experimental.shoshin -import train_tf_lib # local file import from experimental.shoshin -from configs import base_config # local file import from experimental.shoshin - - -FLAGS = flags.FLAGS -config_flags.DEFINE_config_file('config') - - -def main(_) -> None: - - config = FLAGS.config - base_config.check_flags(config) - combos_dir = os.path.join(config.output_dir, - generate_bias_table_lib.COMBOS_SUBDIR) - model_params = models.ModelTrainingParameters( - model_name=config.model.name, - train_bias=config.train_bias, - num_classes=config.data.num_classes, - num_subgroups=0, - subgroup_sizes={}, - num_epochs=config.training.num_epochs, - learning_rate=config.optimizer.learning_rate, - ) - - dataset_builder = data.get_dataset(config.data.name) - if config.generate_bias_table: - # Loads data. - if config.round_idx == 0: - dataloader = dataset_builder(config.data.num_splits, - config.data.initial_sample_proportion, - config.data.subgroup_ids, - config.data.subgroup_proportions,) - else: - dataloader = dataset_builder(config.data.num_splits, 1, - config.data.subgroup_ids, - config.data.subgroup_proportions,) - # Filter each split to only have examples from example_ids_table - dataloader.train_splits = [ - dataloader.train_ds.filter( - generate_bias_table_lib.filter_ids_fn(ids_tab)) for - ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)] - dataloader = data.apply_batch(dataloader, config.data.batch_size) - model_params.num_subgroups = dataloader.num_subgroups - - # Selects training epochs to compute introspection signals from. 
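-    # (Illustrative worked example: with config.eval.num_signal_ckpts=4 and
-    # config.training.num_epochs=100, compute_signal_epochs returns
-    # [1, 5, 22, 100, -1]: log-spaced epochs that favor early training, plus
-    # -1 for the best checkpoint.)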
-    ckpt_epochs = config.eval.signal_ckpt_epochs
-    if not ckpt_epochs:
-      # If `signal_ckpt_epochs` is not provided via eval config, compute the
-      # list of epoch numbers to load checkpoints from, based on
-      # `config.eval.num_signal_ckpts`. If `num_signal_ckpts=0`, then only the
-      # latest epoch will be loaded.
-      ckpt_epochs = generate_bias_table_lib.compute_signal_epochs(
-          config.eval.num_signal_ckpts,
-          num_total_epochs=config.training.num_epochs)
-
-    # Computes introspection signal for every checkpoint epoch.
-    for ckpt_epoch in ckpt_epochs:
-      # Loads model.
-      trained_models = train_tf_lib.load_trained_models(
-          combos_dir, model_params, ckpt_epoch=ckpt_epoch)
-
-      # Generates table.
-      _ = generate_bias_table_lib.get_example_id_to_bias_label_table(
-          dataloader=dataloader,
-          combos_dir=combos_dir,
-          trained_models=trained_models,
-          num_splits=config.data.num_splits,
-          bias_percentile_threshold=config.bias_percentile_threshold,
-          tracin_percentile_threshold=config.tracin_percentile_threshold,
-          bias_value_threshold=config.bias_value_threshold,
-          tracin_value_threshold=config.tracin_value_threshold,
-          save_dir=config.output_dir,
-          ckpt_epoch=ckpt_epoch,
-          save_table=True)
-  else:
-    # Generates prediction table for all splits.
-    dataloader = dataset_builder(
-        config.data.num_splits, 1, config.data.subgroup_ids,
-        config.data.subgroup_proportions)
-    dataloader = data.apply_batch(dataloader, config.data.batch_size)
-    model_params.num_subgroups = dataloader.num_subgroups
-
-    # Loads model. Here we use the best checkpoint for the prediction table by
-    # setting `ckpt_epoch=-1`.
-    trained_models = train_tf_lib.load_trained_models(
-        combos_dir, model_params, ckpt_epoch=-1)
-
-    # Generates table.
-    if_compute_tracin = config.active_sampling.sampling_score == 'tracin'
-    _ = generate_bias_table_lib.get_example_id_to_predictions_table(
-        dataloader=dataloader,
-        trained_models=trained_models,
-        has_bias=config.train_bias,
-        split='train',
-        save_dir=config.save_dir,
-        save_table=True,
-        compute_tracin=if_compute_tracin)
-    for split_name in config.eval_splits:
-      _ = generate_bias_table_lib.get_example_id_to_predictions_table(
-          dataloader=dataloader,
-          trained_models=trained_models,
-          has_bias=config.train_bias,
-          split=split_name,
-          save_dir=config.save_dir,
-          save_table=True,
-          compute_tracin=if_compute_tracin)
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/experimental/shoshin/generate_bias_table_lib.py b/experimental/shoshin/generate_bias_table_lib.py
deleted file mode 100644
index c1b91b8d2..000000000
--- a/experimental/shoshin/generate_bias_table_lib.py
+++ /dev/null
@@ -1,588 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Utilities for Introspective Active Sampling.
-
-Library of utilities for the Introspective Active Sampling method. Includes a
-function to generate a table mapping example ID to bias label, which can be
-used to train the bias output head.
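-
-For each example x, the bias signal is the gap between the averaged
-prediction of models that trained on x (in-distribution) and models that
-held x out (out-of-distribution):
-
-  bias(x) = |avg_id_prediction(x) - avg_ood_prediction(x)|
-
-Thresholding this value, by percentile or by a fixed cutoff, yields the
-binary bias label.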
-""" - -import os -from typing import Dict, List, Optional, Tuple, Any - -from absl import logging -import numpy as np -import pandas as pd -import tensorflow as tf - -import data # local file import from experimental.shoshin - - -EXAMPLE_ID_KEY = 'example_id' -BIAS_LABEL_KEY = 'bias_label' -TRACIN_LABEL_KEY = 'tracin_label' -TRACIN_SCORE_KEY = 'tracin_score' -PREDICTION_KEY = 'prediction' -# Subdirectory for models trained on splits in FLAGS.output_dir. -COMBOS_SUBDIR = 'combos' -CHECKPOINT_SUBDIR = 'checkpoints/' - - -def compute_signal_epochs(num_signal_ckpts: int, num_total_epochs: int): - """Computes the epochs to compute introspection signals.""" - if num_signal_ckpts <= 0: - # This will inform the `train_tf_lib.load_trained_models` to just use the - # best checkpoint. - return [-1] - - # Computes the epochs in log scale. - log_epochs = np.linspace(0, np.log(num_total_epochs), num=num_signal_ckpts) - epochs = np.ceil(np.exp(log_epochs)).astype(int) - epochs = list(np.unique(epochs)) - # Still allows the computation of the best checkpoint. - epochs.append(-1) - return epochs - - -def load_existing_bias_table( - path_to_table: str, - signal: Optional[str] = BIAS_LABEL_KEY) -> tf.lookup.StaticHashTable: - """Loads bias table from file.""" - df = pd.read_csv(path_to_table) - key_tensor = np.array([eval(x).decode('UTF-8') for # pylint:disable=eval-used - x in df[EXAMPLE_ID_KEY].to_list()]) - init = tf.lookup.KeyValueTensorInitializer( - keys=tf.convert_to_tensor( - key_tensor, dtype=tf.string), - values=tf.convert_to_tensor( - df[signal].to_numpy(), dtype=tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - return tf.lookup.StaticHashTable(init, default_value=0) - - -def get_example_id_to_bias_label_table( - dataloader: data.Dataloader, - combos_dir: str, - trained_models: List[tf.keras.Model], - num_splits: int, - bias_percentile_threshold: int, - tracin_percentile_threshold: int, - bias_value_threshold: Optional[float] = None, - tracin_value_threshold: Optional[float] = None, - save_dir: Optional[str] = None, - ckpt_epoch: int = -1, - save_table: Optional[bool] = True) -> tf.lookup.StaticHashTable: - """Generates a lookup table mapping example ID to bias label. - - Args: - dataloader: Dataclass object containing training and validation data. - combos_dir: Directory of model checkpoints by the combination of data splits - used in training. - trained_models: List of trained models. - num_splits: Total number of slices that data was split into. - bias_percentile_threshold: Integer between 0 and 100 representing the - percentile of bias values to give a label of 1 (and 0 for all others). - Given a vector V of length N, the q-th percentile of V is the value q/100 - of the way from the minimum to the maximum in a sorted copy of V. - tracin_percentile_threshold: Integer between 0 and 100 representing the - percentile of tracin values to give a label of 1 (and 0 for all others). - Given a vector V of length N, the q-th percentile of V is the value q/100 - of the way from the minimum to the maximum in a sorted copy of V. - bias_value_threshold: Float representing the bias value threshold, above - which examples will receive a bias label of 1 (and 0 if below). Use - percentile by default. - tracin_value_threshold: Float representing the tracin value threshold, above - which examples will receive a tracin label of 1 (and 0 if below). Use - percentile by default. - save_dir: Directory in which bias table will be saved as CSV. 
- ckpt_epoch: The training epoch where the models in `trained_models` are - loaded from. It only impacts the file name of the bias table. - save_table: Boolean for whether or not to save table. - - Returns: - A lookup table mapping example ID to bias label and additional meta data - (i.e., target label, subgroup label, and introspection signals). - """ - is_train_all = [] - example_ids_all = [] - target_labels_all = [] - groups_labels_all = [] - - bias_labels_all = [] - bias_values_all = [] - vars_values_all = [] - nois_values_all = [] - gap_values_all = [] - error_values_all = [] - id_tracin_values_all = [] - ood_tracin_values_all = [] - tracin_labels_all = [] - for split_idx in range(num_splits): - # For each split of data, - # 1. Get the models that included this split (as in-domain training data). - # 2. Get the models that excluded this split (as out-of-distribution data). - # 3. Calculate the bias value and, using the threshold, bias label. - id_predictions_all = [] - ood_predictions_all = [] - id_tracin_values_splits = [] - ood_tracin_values_splits = [] - - # Collects target and place labels. - labels = list(dataloader.train_splits[split_idx].map( - lambda example: example['label']).as_numpy_iterator()) - labels += list(dataloader.val_splits[split_idx].map( - lambda example: example['label']).as_numpy_iterator()) - labels = np.concatenate(labels) - target_labels_all.append(labels) - - group_labels = list(dataloader.train_splits[split_idx].map( - lambda example: example['subgroup_label']).as_numpy_iterator()) - group_labels += list(dataloader.val_splits[split_idx].map( - lambda example: example['subgroup_label']).as_numpy_iterator()) - group_labels = np.concatenate(group_labels) - groups_labels_all.append(group_labels) - - # Collects in-sample and out-of-sample predictions. - for combo_idx, combo in enumerate(tf.io.gfile.listdir(combos_dir)): - splits_in_combo = [int(split_idx) for split_idx in combo.split('_')] - model = trained_models[combo_idx] - if split_idx in splits_in_combo: - # Identifies in-sample model and collects its predictions. - id_predictions_train = model.predict( - dataloader.train_splits[split_idx].map( - lambda example: example['input_feature'])) - id_predictions_val = model.predict( - dataloader.val_splits[split_idx].map( - lambda example: example['input_feature'])) - id_predictions = tf.concat( - [id_predictions_train['main'], id_predictions_val['main']], axis=0) - id_predictions = tf.gather_nd( - id_predictions, tf.expand_dims(labels, axis=1), batch_dims=1) - id_predictions_all.append(id_predictions) - _, tracin_values_train, _ = calculate_tracin_values( - dataloader.train_splits[split_idx], [model], has_bias=True) - _, tracin_values_val, _ = calculate_tracin_values( - dataloader.val_splits[split_idx], [model], has_bias=True) - id_tracin_values = tf.concat([tracin_values_train, tracin_values_val], - axis=0) - id_tracin_values_splits.append(id_tracin_values) - else: - # Identifies out-of-sample model and collects its predictions. 
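-        # (Note: each split is in-sample for the model combinations that
-        # include it and out-of-sample for the rest, so every example
-        # accumulates both kinds of predictions; the bias value computed
-        # further below is the absolute gap between their averages.)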
-        ood_predictions_train = model.predict(
-            dataloader.train_splits[split_idx].map(
-                lambda example: example['input_feature']))
-        ood_predictions_val = model.predict(
-            dataloader.val_splits[split_idx].map(
-                lambda example: example['input_feature']))
-        ood_predictions = tf.concat(
-            [ood_predictions_train['main'], ood_predictions_val['main']],
-            axis=0)
-        ood_predictions = tf.gather_nd(
-            ood_predictions, tf.expand_dims(labels, axis=1), batch_dims=1)
-        ood_predictions_all.append(ood_predictions)
-        _, tracin_values_train, _ = calculate_tracin_values(
-            dataloader.train_splits[split_idx], [model], has_bias=True)
-        _, tracin_values_val, _ = calculate_tracin_values(
-            dataloader.val_splits[split_idx], [model], has_bias=True)
-        ood_tracin_values = tf.concat([tracin_values_train, tracin_values_val],
-                                      axis=0)
-        ood_tracin_values_splits.append(ood_tracin_values)
-
-    # Collects example ids and is_train indicators.
-    # NB: The extracted example ids are byte strings.
-    example_ids_train = list(dataloader.train_splits[split_idx].map(
-        lambda example: example['example_id']).as_numpy_iterator())
-    example_ids_val = list(dataloader.val_splits[split_idx].map(
-        lambda example: example['example_id']).as_numpy_iterator())
-    example_ids = example_ids_train + example_ids_val
-    example_ids = np.concatenate(example_ids)
-    example_ids_all.append(example_ids)
-
-    is_train = tf.concat([
-        tf.ones(len(np.concatenate(example_ids_train)), dtype=tf.int64),
-        tf.zeros(len(np.concatenate(example_ids_val)), dtype=tf.int64)
-    ], axis=0)
-    is_train_all.append(is_train)
-
-    # Computes in-sample and out-of-sample predictions and bias values.
-    id_predictions_avg = np.average(np.stack(id_predictions_all), axis=0)
-    ood_predictions_avg = np.average(np.stack(ood_predictions_all), axis=0)
-    id_tracin_values_avg = np.average(np.stack(id_tracin_values_splits), axis=0)
-    ood_tracin_values_avg = np.average(
-        np.stack(ood_tracin_values_splits), axis=0)
-    tracin_values_avg = np.average(
-        np.stack(ood_tracin_values_splits + id_tracin_values_splits), axis=0)
-    bias_values = np.absolute(
-        np.subtract(id_predictions_avg, ood_predictions_avg))
-    vars_values = np.std(np.stack(ood_predictions_all), axis=0)
-    # Since `id_predictions_avg` is the predicted probability of the target
-    # class, the `noise` is simply the distance between that probability and
-    # the true probability of the target class (i.e., 1.).
-    nois_values = np.absolute(np.subtract(1., id_predictions_avg))
-    error_values = np.average(np.subtract(1., ood_predictions_all), axis=0)
-    gap_values = np.average(
-        np.absolute(np.subtract(id_predictions_avg[None, :],
-                                np.stack(ood_predictions_all))), axis=0)
-
-    # Calculates bias labels using the percentile threshold when specified,
-    # falling back to the fixed value threshold otherwise.
-    if bias_percentile_threshold:
-      threshold = np.percentile(bias_values, bias_percentile_threshold)
-    else:
-      threshold = bias_value_threshold
-    bias_labels = tf.math.greater(bias_values, threshold)
-    # Calculates tracin labels using the percentile threshold when specified,
-    # falling back to the fixed value threshold otherwise.
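-    # (Illustrative: with tracin_percentile_threshold=75 and tracin values
-    # [0.1, 0.2, 0.4, 0.9], np.percentile(values, 75) = 0.525, so only the
-    # 0.9 example receives a tracin label of 1.)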
-    if tracin_percentile_threshold:
-      tracin_threshold = np.percentile(tracin_values_avg,
-                                       tracin_percentile_threshold)
-    else:
-      tracin_threshold = tracin_value_threshold
-    tracin_labels = tf.math.greater(tracin_values_avg, tracin_threshold)
-
-    bias_labels_all.append(bias_labels)
-    tracin_labels_all.append(tracin_labels)
-    bias_values_all.append(bias_values)
-    id_tracin_values_all.append(id_tracin_values_avg)
-    ood_tracin_values_all.append(ood_tracin_values_avg)
-    vars_values_all.append(vars_values)
-    nois_values_all.append(nois_values)
-
-    gap_values_all.append(gap_values)
-    error_values_all.append(error_values)
-
-  is_train_all = np.concatenate(is_train_all)
-  example_ids_all = np.concatenate(example_ids_all)
-  target_labels_all = np.concatenate(target_labels_all)
-  groups_labels_all = np.concatenate(groups_labels_all)
-  id_tracin_values_all = np.concatenate(id_tracin_values_all)
-  ood_tracin_values_all = np.concatenate(ood_tracin_values_all)
-  tracin_labels_all = np.squeeze(np.concatenate(tracin_labels_all))
-
-  bias_labels_all = np.squeeze(np.concatenate(bias_labels_all))
-  bias_values_all = np.squeeze(np.concatenate(bias_values_all))
-  vars_values_all = np.squeeze(np.concatenate(vars_values_all))
-  nois_values_all = np.squeeze(np.concatenate(nois_values_all))
-  gap_values_all = np.squeeze(np.concatenate(gap_values_all))
-  error_values_all = np.squeeze(np.concatenate(error_values_all))
-
-  logging.info('# of examples: %s', example_ids_all.shape[0])
-  logging.info('# of target_label: %s', target_labels_all.shape[0])
-  logging.info('# of groups_label: %s', groups_labels_all.shape[0])
-  logging.info('# of bias_label: %s', bias_labels_all.shape[0])
-  logging.info('# of non-zero bias labels: %s',
-               tf.math.count_nonzero(bias_labels_all).numpy())
-
-  logging.info('# of bias: %s', bias_values_all.shape[0])
-  logging.info('# of variance: %s', vars_values_all.shape[0])
-  logging.info('# of noise: %s', nois_values_all.shape[0])
-  logging.info('# of gap: %s', gap_values_all.shape[0])
-  logging.info('# of error: %s', error_values_all.shape[0])
-  logging.info('# of is_train: %s', is_train_all.shape[0])
-  logging.info('# of train examples: %s',
-               tf.math.count_nonzero(is_train_all).numpy())
-
-  if save_table:
-    df = pd.DataFrame({
-        'example_id': example_ids_all,
-        'target_label': target_labels_all,
-        'groups_label': groups_labels_all,
-        BIAS_LABEL_KEY: bias_labels_all,
-        TRACIN_LABEL_KEY: tracin_labels_all,
-        'bias': bias_values_all,
-        'variance': vars_values_all,
-        'noise': nois_values_all,
-        'gap': gap_values_all,
-        'error': error_values_all,
-        'tracin_id': id_tracin_values_all,
-        'tracin_ood': ood_tracin_values_all,
-        # Whether this example belongs to the training data.
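-        # (The saved bias_table.csv can later be reloaded with
-        # load_existing_bias_table above to supervise the bias output head.)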
-        'is_train': is_train_all
-    })
-
-    csv_name = os.path.join(
-        save_dir,
-        'bias_table.csv' if ckpt_epoch < 0 else f'bias_table_{ckpt_epoch}.csv')
-    df.to_csv(csv_name, index=False)
-
-  init = tf.lookup.KeyValueTensorInitializer(
-      keys=tf.convert_to_tensor(example_ids_all, dtype=tf.string),
-      values=tf.convert_to_tensor(bias_labels_all, dtype=tf.int64),
-      key_dtype=tf.string,
-      value_dtype=tf.int64)
-  return tf.lookup.StaticHashTable(init, default_value=0)
-
-
-def get_example_id_to_predictions_table(
-    dataloader: data.Dataloader,
-    trained_models: List[tf.keras.Model],
-    has_bias: bool,
-    split: Optional[str] = 'train',
-    save_dir: Optional[str] = None,
-    save_table: Optional[bool] = True,
-    compute_tracin: Optional[bool] = False) -> pd.DataFrame:
-  """Generates a table mapping example ID to label and bias predictions.
-
-  Args:
-    dataloader: Dataclass object containing training and validation data.
-    trained_models: List of trained models.
-    has_bias: Whether the trained models have a bias prediction head.
-    split: Which split of the dataset to use ('train'/'val'/'test').
-    save_dir: Directory in which predictions table will be saved as CSV.
-    save_table: Boolean for whether or not to save table.
-    compute_tracin: Boolean for whether or not to calculate the tracin values
-      with respect to the model predictions.
-
-  Returns:
-    A pandas dataframe mapping example ID to all label and bias predictions.
-  """
-  table_name = 'predictions_table'
-  ds = dataloader.train_ds
-  if split != 'train':
-    ds = dataloader.eval_ds[split]
-    table_name += '_' + split
-  labels = list(
-      ds.map(
-          lambda example: example['label']).as_numpy_iterator())
-  labels = np.concatenate(labels)
-  predictions_all = []
-  tracin_values_all = []
-  if has_bias:
-    bias_predictions_all = []
-  for model in trained_models:
-    predictions = model.predict(
-        ds.map(lambda example: example['input_feature']))
-    predictions_all.append(predictions['main'][..., 1])
-    if has_bias:
-      bias_predictions_all.append(predictions['bias'][..., 1])
-    if compute_tracin:
-      _, tracin_values, _ = calculate_tracin_values(
-          ds, [model], has_bias=has_bias, use_prediction_gradient=True
-      )
-      tracin_values_all.append(tracin_values)
-  example_ids = list(ds.map(
-      lambda example: example['example_id']).as_numpy_iterator())
-  example_ids = np.concatenate(example_ids)
-  predictions_all = np.stack(predictions_all)
-  if has_bias:
-    bias_predictions_all = np.stack(bias_predictions_all)
-  if compute_tracin:
-    tracin_values_all = np.stack(tracin_values_all)
-
-  logging.info('# of examples in prediction table: %s', example_ids.shape[0])
-
-  dict_values = {'example_id': example_ids}
-  for i in range(predictions_all.shape[0]):
-    dict_values[f'predictions_label_{i}'] = predictions_all[i]
-    if has_bias:
-      dict_values[f'predictions_bias_{i}'] = bias_predictions_all[i]
-    if compute_tracin:
-      dict_values[f'predictions_tracin_{i}'] = tracin_values_all[i]
-  df = pd.DataFrame(dict_values)
-  if save_table:
-    df.to_csv(os.path.join(save_dir, table_name + '.csv'), index=False)
-  return df
-
-
-def get_example_id_to_tracin_value_table(
-    dataloader: data.Dataloader,
-    model_checkpoints: List[tf.keras.Model],
-    has_bias: bool,
-    split: Optional[str] = 'train',
-    included_layers: Optional[int] = -2,
-    save_dir: Optional[str] = None,
-    save_table: Optional[bool] = True,
-    table_name_suffix: Optional[str] = 'first',
-) -> tf.lookup.StaticHashTable:
-  """Generates a lookup table mapping example ID to tracin value.
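-
-  Here the tracin value is an example's self-influence: the sum of squared
-  gradient norms of its loss over the supplied model checkpoints (see
-  `calculate_tracin_values` below).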
-
-  Args:
-    dataloader: Dataclass object containing training and validation data.
-    model_checkpoints: List of model checkpoints.
-    has_bias: Whether the trained models have a bias prediction head.
-    split: Which split of the dataset to use ('train'/'val'/'test').
-    included_layers: Layers to include in Tracin computation (all trainable
-      layers from the index forward are included).
-    save_dir: Directory in which bias table will be saved as CSV.
-    save_table: Boolean for whether or not to save table.
-    table_name_suffix: String to add to the name of the created table.
-
-  Returns:
-    A lookup table mapping example ID to tracin score.
-  """
-
-  ds = dataloader.train_ds
-  table_name = 'tracin_table'
-  if split != 'train':
-    ds = dataloader.eval_ds[split]
-    table_name += '_' + split
-  example_ids_all, tracin_values_all, probs_all = calculate_tracin_values(
-      ds, model_checkpoints, included_layers, has_bias)
-  logging.info('# of examples: %s', example_ids_all.shape[0])
-
-  if save_table:
-    df = pd.DataFrame({
-        EXAMPLE_ID_KEY: example_ids_all,
-        TRACIN_SCORE_KEY: tracin_values_all,
-        PREDICTION_KEY: probs_all
-    })
-    df.to_csv(
-        os.path.join(save_dir, table_name + '_' + table_name_suffix + '.csv'),
-        index=False)
-
-  init = tf.lookup.KeyValueTensorInitializer(
-      keys=tf.convert_to_tensor(example_ids_all, dtype=tf.string),
-      values=tf.convert_to_tensor(tracin_values_all, dtype=tf.float64),
-      key_dtype=tf.string,
-      value_dtype=tf.float64)
-  return tf.lookup.StaticHashTable(init, default_value=0)
-
-
-def load_existing_tracin_table(path_to_table: str):
-  """Loads tracin table from file."""
-  df = pd.read_csv(path_to_table)
-  key_tensor = np.array([eval(x).decode('UTF-8') for  # pylint:disable=eval-used
-                         x in df[EXAMPLE_ID_KEY].to_list()])
-  init = tf.lookup.KeyValueTensorInitializer(
-      keys=tf.convert_to_tensor(
-          key_tensor, dtype=tf.string),
-      values=tf.convert_to_tensor(
-          df[TRACIN_SCORE_KEY].to_numpy(), dtype=tf.float64),
-      key_dtype=tf.string,
-      value_dtype=tf.float64)
-  return tf.lookup.StaticHashTable(init, default_value=0)
-
-
-def calculate_tracin_values(
-    dataset: tf.data.Dataset,
-    model_checkpoints: List[tf.keras.Model],
-    included_layers: Optional[int] = -2,
-    has_bias: Optional[bool] = False,
-    use_prediction_gradient: Optional[bool] = False
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-  """Calculates the tracin values for a given dataset [1].
-
-  Reference:
-    [1]: Pruthi et al. Estimating Training Data Influence by Tracing Gradient
-    Descent. https://arxiv.org/abs/2002.08484
-
-  Args:
-    dataset: Dataset object containing data.
-    model_checkpoints: List of model checkpoints.
-    included_layers: Layers to include in Tracin computation (all trainable
-      layers from the index forward are included). By default (-2), the last
-      two layers are included.
-    has_bias: Whether the trained models have a bias prediction head. If yes,
-      the layers that predict bias are ignored.
-    use_prediction_gradient: Whether to calculate the Tracin values using the
-      loss with respect to the predicted labels instead of the true labels.
-
-  Returns:
-    An array of example_ids and corresponding arrays of tracin values and
-    predictions.
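-
-  For a single checkpoint, the self-influence computed below is
-
-    score(x) = sum over g of ||g||^2,  g = grad(loss(x), w) for each
-    included trainable weight tensor w,
-
-  and scores are then summed over all checkpoints in `model_checkpoints`.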
- - """ - example_ids_all = [] - tracin_values_all = [] - probs_all = [] - - included_layers_start = included_layers - included_layers_end = -1 - if has_bias: - included_layers_start -= 2 - included_layers_end = -2 - - @tf.function - def run_self_influence( - batch: Dict[str, tf.Tensor], - checkpoints: List[tf.keras.Model], - included_layers_start: int, - included_layers_end: int, - use_prediction_gradient: Optional[bool] = False - ) -> Tuple[tf.Tensor, Any, Any]: - example_ids = batch['example_id'] - features = batch['input_feature'] - labels = batch['label'] - self_influences = [] - probs_np = [] - for model in checkpoints: - with tf.GradientTape(watch_accessed_variables=False) as tape: - tape.watch( - model.trainable_weights[included_layers_start:included_layers_end]) - probs = model(features)['main'] - if use_prediction_gradient: - y_pred = tf.math.argmax(probs, axis=1) - loss = tf.keras.losses.sparse_categorical_crossentropy(y_pred, probs) - grads = tape.jacobian( - loss, - model.trainable_weights[ - included_layers_start:included_layers_end - ]) - else: - loss = tf.keras.losses.sparse_categorical_crossentropy(labels, probs) - grads = tape.jacobian( - loss, - model.trainable_weights[ - included_layers_start:included_layers_end - ]) - scores = tf.add_n( - [ - tf.math.reduce_sum( - grad * grad, axis=tf.range(1, tf.rank(grad), 1) - ) - for grad in grads - ] - ) - self_influences.append(scores) - probs_np.append(probs[:, 0]) - - return example_ids, tf.math.reduce_sum( - tf.stack(self_influences, axis=-1), axis=-1), tf.math.reduce_sum( - tf.stack(probs_np, axis=-1), axis=-1) - - inputshape = dataset.element_spec['input_feature'].shape - for model in model_checkpoints: - model.build(inputshape) - - logging.info('Checkpoints built.') - for batch in dataset: - example_ids, tracin_values, probs = run_self_influence( - batch, model_checkpoints, included_layers_start, included_layers_end, - use_prediction_gradient) - example_ids_all.append(example_ids) - tracin_values_all.append(tracin_values) - probs_all.append(probs) - - example_ids_all = np.concatenate(example_ids_all) - tracin_values_all = np.squeeze(np.concatenate(tracin_values_all)) - probs_all = np.concatenate(probs_all) - - return example_ids_all, tracin_values_all, probs_all - -# Helper functions to process hash tables - - -def filter_ids_fn(hash_table, value=1): - """Filter dataset based on whether ids take a certain value in hash table.""" - def filter_fn(examples): - return hash_table.lookup(examples['example_id']) == value - return filter_fn - - - diff --git a/experimental/shoshin/log_metrics_callback.py b/experimental/shoshin/log_metrics_callback.py deleted file mode 100644 index 5ebcb0f67..000000000 --- a/experimental/shoshin/log_metrics_callback.py +++ /dev/null @@ -1,182 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Keras callback for logging metrics to XManager. 
- -Adapted from google3/third_party/py/xmanager/examples/alphabet/codelab/log_metrics_callback.py. - -The `LogMetricsCallback` defined in this module is used in the XManager codelab -(go/xmanager-codelab). It is a simplified version of the `LogMetricsCallback` in -//research/kernel/auricle/segmentation/v_1/cross_segment_bert_training_lib.py. -""" - -import abc -from typing import Mapping, Optional, Sequence, Union - -import tensorflow as tf - -from google3.learning.deepmind.xmanager2.client import xmanager_api - -_ScalarMetric = Union[float, int] -_MetricDict = Mapping[str, _ScalarMetric] - - -class MetricLogger(abc.ABC): - """Abstract base class for logging metrics. - - `MetricLoggers` are typically used in conjunction with the - `LogMetricsCallback`. - """ - - @abc.abstractmethod - def log_scalar_metric( - self, metric_label: str, metric_value: _ScalarMetric, step: int, - is_val_metric: bool - ) -> None: - """Logs a metric name and value at the specified step. - - For example, to log the training accuracy at the end of the first epoch, one - might call this function as: - - `log_scalar_metric('epoch_accuracy', 0.89, examples_per_epoch, False)`. - - Args: - metric_label: The name of the metric being logged. Typically we assume the - 'val_' prefix for validation metrics has been removed from the metric - label prior to being passed to this function, and that a prefix of - 'epoch_' or 'batch_' has been added to the metric label to indicate if - this metric corresponds to an epoch or a batch. See the - `LogMetricsCallback` for more details. - metric_value: The value of the metric being logged. - step: The step number at which to log this metric. Metrics are normally - visualized on `metric_value` vs. `step` plots (e.g., on XManager or - TensorBoard). The `LogMetricsCallback` sets this value equal to the - number of training steps that have been seen up to the point the metric - is logged. - is_val_metric: A boolean specifying whether this metric was computed on a - validation set. - """ - - -class XManagerMetricLogger(MetricLogger): - """Class for logging metrics to XManager.""" - - def __init__(self, xmanager_work_unit: xmanager_api.WorkUnit) -> None: - self._work_unit = xmanager_work_unit - - def log_scalar_metric( - self, metric_label: str, metric_value: _ScalarMetric, step: int, - is_val_metric: bool - ) -> None: - xm_label = metric_label + '_val' if is_val_metric else metric_label - measurements = self._work_unit.get_measurement_series(label=xm_label) - measurements.create_measurement(metric_value, step=step) - - -class LogMetricsCallback(tf.keras.callbacks.Callback): - """A callback for logging metrics, for example to TensorBoard or XManager. - - During training, this callback will log all metrics after every training batch - where the total number of examples seen up to that point in the training epoch - is a multiple of the specified logging frequency, as well as at the end of - every epoch. This callback logs metrics by invoking the `log_scalar_metric` - function on all the metric loggers that are provided to the callback's - constructor. The metric loggers are objects, such as `XManagerMetricLogger` - and `TensorBoardMetricLogger`, which derive from `MetricLogger`. - """ - - def __init__( - self, - metric_loggers: Sequence[MetricLogger], - logging_frequency: int, - batch_size: int, - num_train_examples_per_epoch: int, - ) -> None: - """Initializes the `LogMetricsCallback`. 
- - Args: - metric_loggers: A list of `MetricLogger` objects that are invoked to log - training/validation metrics during the course of a Keras training run. - logging_frequency: How frequently, in terms of the number of training - examples seen during an epoch, to log metrics. For example, if - `logging_frequency` is 128, and batch_size is 64, then the metrics would - get logged every other batch. `logging_frequency` must be a multiple of - `batch_size`. - batch_size: The batch size used during training. - num_train_examples_per_epoch: The total number of training examples seen - during the course of an epoch. - """ - super().__init__() - if not metric_loggers: - raise ValueError('Must specify at least one MetricLogger.') - if logging_frequency % batch_size != 0: - raise ValueError( - 'logging_frequency must be a multiple of batch_size.' - ) - self._metric_loggers = metric_loggers - self._logging_frequency = logging_frequency - self._batch_size = batch_size - self._num_train_examples_per_epoch = num_train_examples_per_epoch - self._epoch = -1 - - def _log_metrics( - self, logs: _MetricDict, num_examples_seen: int, - metric_format_str: str, is_val_metric: bool = False, - ) -> None: - """Logs all metrics in 'logs' dictionary.""" - for metric_name, metric_value in logs.items(): - metric_label = metric_format_str.format(metric_name) - for metric_logger in self._metric_loggers: - metric_logger.log_scalar_metric( - metric_label, metric_value, num_examples_seen, is_val_metric, - ) - - def on_epoch_begin( - self, epoch: int, logs: Optional[_MetricDict] = None - ) -> None: - """Stores the epoch number at the beginning of every epoch.""" - self._epoch = epoch - - def on_epoch_end( - self, epoch: int, logs: Optional[_MetricDict] = None - ) -> None: - """Logs all metrics at the end of every epoch.""" - num_examples_seen = (epoch + 1) * self._num_train_examples_per_epoch - if logs: - # Separate the train vs. validation metrics in `logs`, and log them. - train_metrics = {metric_name: metric_value - for metric_name, metric_value in logs.items() - if not metric_name.startswith('val_')} - val_metrics = {metric_name.replace('val_', ''): metric_value - for metric_name, metric_value in logs.items() - if metric_name.startswith('val_')} - self._log_metrics(train_metrics, num_examples_seen, 'epoch_{}', - is_val_metric=False) - self._log_metrics(val_metrics, num_examples_seen, 'epoch_{}', - is_val_metric=True) - - def on_train_batch_end( - self, batch: int, logs: Optional[_MetricDict] = None - ) -> None: - """Logs all metrics after training batches at specified intervals.""" - num_examples_seen_in_prior_epochs = (self._epoch * - self._num_train_examples_per_epoch) - num_examples_seen_in_this_epoch = (batch + 1) * self._batch_size - num_examples_seen_total = (num_examples_seen_in_prior_epochs + - num_examples_seen_in_this_epoch) - if logs and num_examples_seen_in_this_epoch % self._logging_frequency == 0: - # Log all the metrics in the `logs` dictionary. - self._log_metrics(logs, num_examples_seen_total, 'batch_{}', - is_val_metric=False) diff --git a/experimental/shoshin/metrics.py b/experimental/shoshin/metrics.py deleted file mode 100644 index 2d1275dc2..000000000 --- a/experimental/shoshin/metrics.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Custom evaluation metrics."""
-
-import tensorflow as tf
-
-
-@tf.keras.saving.register_keras_serializable('one_vs_rest')
-class OneVsRest(tf.keras.metrics.Metric):
-  """A wrapper that extends metrics from a binary to a multi-class setup.
-
-  This extension uses the one-vs-rest strategy, in which one class is treated
-  as positive and the remaining classes as negative, reducing the calculation
-  to a binary setup.
-  """
-
-  def __init__(
-      self, metric: tf.keras.metrics.Metric, positive_class_index: int, **kwargs
-  ):
-    """Constructor.
-
-    Args:
-      metric: A metric object that needs to be extended for the multi-class
-        setup.
-      positive_class_index: The index of the positive class.
-      **kwargs: Keyword arguments expected by the core metric.
-    """
-    super(OneVsRest, self).__init__(name=metric.name, **kwargs)
-    self.metric = metric
-    self.positive_class_index = positive_class_index
-
-  def update_state(self, y_true, y_pred, **kwargs):
-    """Accumulates metric statistics.
-
-    Args:
-      y_true: The ground truth labels. An integer tensor of shape
-        (num_examples, num_classes).
-      y_pred: The predicted values. A float tensor of shape (num_examples,
-        num_classes).
-      **kwargs: Keyword arguments expected by the core metric.
-    """
-    y_pred = y_pred[..., self.positive_class_index]
-    y_true = y_true[..., self.positive_class_index]
-    self.metric.update_state(y_true, y_pred, **kwargs)
-
-  def result(self):
-    return self.metric.result()
-
-  def reset_state(self):
-    self.metric.reset_state()
-
-  def get_config(self):
-    return {
-        'metric': self.metric,
-        'positive_class_index': self.positive_class_index
-    }
diff --git a/experimental/shoshin/metrics_test.py b/experimental/shoshin/metrics_test.py
deleted file mode 100644
index 41d263924..000000000
--- a/experimental/shoshin/metrics_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Tests for metrics.""" - -import tensorflow as tf -import metrics # local file import from experimental.shoshin - -from google3.testing.pybase import googletest - - -class MetricsTest(tf.test.TestCase): - - def test_one_vs_rest_auc(self): - auc = tf.keras.metrics.AUC() - one_vs_rest_auc = metrics.OneVsRest(tf.keras.metrics.AUC(), 1) - - y_true = tf.constant([1, 0, 1, 1], dtype=tf.int32) - one_hot_y_true = tf.one_hot(y_true, depth=2) - y_pred = tf.constant([0.9, 0.3, 0.7, 0.2], dtype=tf.float32) - one_hot_y_pred = tf.constant( - [[0.1, 0.9], [0.7, 0.3], [0.3, 0.7], [0.8, 0.2]], dtype=tf.float32 - ) - - auc.update_state(y_true, y_pred) - expected_result = auc.result() - - one_vs_rest_auc.update_state(one_hot_y_true, one_hot_y_pred) - result = one_vs_rest_auc.result() - self.assertAllClose(result, expected_result) - - def test_one_vs_rest_aucpr(self): - auc = tf.keras.metrics.AUC(curve="PR") - one_vs_rest_auc = metrics.OneVsRest(tf.keras.metrics.AUC(curve="PR"), 1) - - y_true = tf.constant([1, 0, 1, 1], dtype=tf.int32) - one_hot_y_true = tf.one_hot(y_true, depth=2) - y_pred = tf.constant([0.9, 0.3, 0.7, 0.2], dtype=tf.float32) - one_hot_y_pred = tf.constant( - [[0.1, 0.9], [0.7, 0.3], [0.3, 0.7], [0.8, 0.2]], dtype=tf.float32 - ) - - auc.update_state(y_true, y_pred) - expected_result = auc.result() - - one_vs_rest_auc.update_state(one_hot_y_true, one_hot_y_pred) - result = one_vs_rest_auc.result() - self.assertAllClose(result, expected_result) - -if __name__ == "__main__": - googletest.main() diff --git a/experimental/shoshin/models.py b/experimental/shoshin/models.py deleted file mode 100644 index 26da9a5f1..000000000 --- a/experimental/shoshin/models.py +++ /dev/null @@ -1,304 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Library of models to use in Introspective Active Sampling. - -This file contains a library of models that have two output heads: one for the -main training task and an optional second for bias. Any of these models can -serve as the base model trained in Introspective Active Sampling. 
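-
-Illustrative registry usage (a sketch; the parameter values below are
-placeholders, not recommended settings):
-
-  params = ModelTrainingParameters(
-      model_name='resnet50v2', train_bias=False, num_classes=2,
-      num_subgroups=2, subgroup_sizes={'0': 100, '1': 100}, num_epochs=10)
-  model = get_model(params.model_name)(model_params=params)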
-""" - -import dataclasses -from typing import Dict, Optional - -import tensorflow as tf - - -MODEL_REGISTRY = {} -RESNET_IMAGE_SIZE = 224 - - -def register_model(name: str): - """Provides decorator to register model classes.""" - def save(model_class): - MODEL_REGISTRY[name] = model_class - return model_class - - return save - - -def get_model(name: str): - """Retrieves dataset based on name.""" - if name not in MODEL_REGISTRY: - raise ValueError( - f'Unknown model: {name}\nPossible choices: {MODEL_REGISTRY.keys()}') - return MODEL_REGISTRY[name] - - -@dataclasses.dataclass -class ModelTrainingParameters: - """Dataclass for training parameters.""" - model_name: str - train_bias: bool - num_classes: int - num_subgroups: int - subgroup_sizes: Dict[str, int] - num_epochs: int - num_channels: int = 3 - l2_regularization_factor: float = 0.5 - optimizer: str = 'sgd' - learning_rate: float = 1e-5 - batch_size: int = 64 - load_pretrained_weights: Optional[bool] = False - worst_group_label: Optional[int] = 2 - use_pytorch_style_resnet: Optional[bool] = False - do_reweighting: Optional[bool] = False - reweighting_signal: Optional[str] = 'bias' - reweighting_lambda: Optional[float] = 0.5 - reweighting_error_percentile_threshold: Optional[float] = 0.2 - - def asdict(self): - return dataclasses.asdict(self) - - @classmethod - def from_dict(cls, kwargs): - return ModelTrainingParameters(**kwargs) - - -@register_model('resnet50v1') -@tf.keras.saving.register_keras_serializable('resnet50v1') -class ResNet50v1(tf.keras.Model): - """Defines a ResNet50v1 model class with two output heads. - - One output head is for the main training task, while the other is an optional - head to train on bias labels. Inputs are feature vectors. - """ - - def __init__(self, - model_params: ModelTrainingParameters): - super(ResNet50v1, self).__init__(name=model_params.model_name) - - self.model_params = model_params - self.resnet_model = tf.keras.applications.resnet50.ResNet50( - include_top=False, - weights='imagenet' if model_params.load_pretrained_weights else None, - input_shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, - model_params.num_channels), - classes=model_params.num_classes, - pooling='avg' - # TODO(jihyeonlee): Consider making pooling method a flag. 
-    )
-
-    regularizer = tf.keras.regularizers.L2(
-        l2=model_params.l2_regularization_factor)
-    for layer in self.resnet_model.layers:
-      layer.trainable = True
-      if model_params.use_pytorch_style_resnet:
-        if hasattr(layer, 'kernel_regularizer'):
-          setattr(layer, 'kernel_regularizer', regularizer)
-        if isinstance(layer, tf.keras.layers.Conv2D):
-          layer.use_bias = False
-          initializer = tf.keras.initializers.HeNormal()
-          layer.kernel_initializer = initializer
-        if isinstance(layer, tf.keras.layers.BatchNormalization):
-          layer.momentum = 0.9
-
-    self.output_main = tf.keras.layers.Dense(
-        model_params.num_classes,
-        activation='softmax',
-        name='main',
-        kernel_regularizer=regularizer)
-
-    self.output_bias = tf.keras.layers.Dense(
-        model_params.num_classes,
-        trainable=model_params.train_bias,
-        activation='softmax',
-        name='bias',
-        kernel_regularizer=regularizer)
-
-  def get_config(self):
-    config = super(ResNet50v1, self).get_config()
-    config.update({'model_params': self.model_params.asdict(),
-                   'resnet_model': self.resnet_model,
-                   'output_main': self.output_main,
-                   'output_bias': self.output_bias})
-    return config
-
-  @classmethod
-  def from_config(cls, config):
-    return cls(ModelTrainingParameters.from_dict(config['model_params']))
-
-  def call(self, inputs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
-    x = self.resnet_model(inputs['large_image'])
-    out_main = self.output_main(x)
-    out_bias = self.output_bias(x)
-    return {
-        'main': out_main,
-        'bias': out_bias
-    }
-
-
-@register_model('resnet50v2')
-@tf.keras.saving.register_keras_serializable('resnet50v2')
-class ResNet50v2(tf.keras.Model):
-  """Defines a ResNet50v2 model class with two output heads.
-
-  One output head is for the main training task, while the other is an optional
-  head to train on bias labels. Inputs are images.
-  """
-
-  def __init__(self,
-               model_params: ModelTrainingParameters):
-    super(ResNet50v2, self).__init__(name=model_params.model_name)
-
-    self.model_params = model_params
-    self.resnet_model = tf.keras.applications.resnet_v2.ResNet50V2(
-        include_top=False,
-        weights='imagenet' if model_params.load_pretrained_weights else None,
-        input_shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE,
-                     model_params.num_channels),
-        classes=model_params.num_classes,
-        pooling='avg'
-        # TODO(jihyeonlee): Consider making pooling method a flag.
-    )
-
-    regularizer = tf.keras.regularizers.L2(
-        l2=model_params.l2_regularization_factor)
-    for layer in self.resnet_model.layers:
-      layer.trainable = True
-      if model_params.use_pytorch_style_resnet:
-        if hasattr(layer, 'kernel_regularizer'):
-          setattr(layer, 'kernel_regularizer', regularizer)
-        if isinstance(layer, tf.keras.layers.Conv2D):
-          layer.use_bias = False
-          initializer = tf.keras.initializers.HeNormal()
-          layer.kernel_initializer = initializer
-        if isinstance(layer, tf.keras.layers.BatchNormalization):
-          layer.momentum = 0.9
-
-    self.output_main = tf.keras.layers.Dense(
-        model_params.num_classes,
-        activation='softmax',
-        name='main',
-        kernel_regularizer=regularizer)
-
-    self.output_bias = tf.keras.layers.Dense(
-        model_params.num_classes,
-        trainable=model_params.train_bias,
-        activation='softmax',
-        name='bias',
-        kernel_regularizer=regularizer)
-
-  def get_config(self):
-    config = super(ResNet50v2, self).get_config()
-    config.update({'model_params': self.model_params.asdict(),
-                   'resnet_model': self.resnet_model,
-                   'output_main': self.output_main,
-                   'output_bias': self.output_bias})
-    return config
-
-  @classmethod
-  def from_config(cls, config):
-    return cls(ModelTrainingParameters.from_dict(config['model_params']))
-
-  def call(self, inputs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
-    x = self.resnet_model(inputs['large_image'])
-    out_main = self.output_main(x)
-    out_bias = self.output_bias(x)
-    return {
-        'main': out_main,
-        'bias': out_bias
-    }
-
-
-@register_model('two_tower')
-@tf.keras.saving.register_keras_serializable('two_tower')
-class TwoTower(tf.keras.Model):
-  """Defines a TwoTower model class with two output heads.
-
-  One output head is for the main training task, while the other is an optional
-  head to train on bias labels. Inputs are images.
-  """
-
-  def __init__(self,
-               model_params: ModelTrainingParameters):
-    super(TwoTower, self).__init__(name=model_params.model_name)
-
-    self.model_params = model_params
-    backbone = tf.keras.applications.resnet_v2.ResNet50V2(
-        include_top=False,
-        weights='imagenet' if model_params.load_pretrained_weights else None,
-        input_shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3),
-        classes=model_params.num_classes,
-        pooling='avg'
-        # TODO(jihyeonlee): Consider making pooling method a flag.
-    )
-
-    if model_params.load_pretrained_weights:
-      backbone.trainable = False
-
-    dense = tf.keras.Sequential([
-        # TODO(melfatih): Add a hyperparameter for dropout.
-        tf.keras.layers.Dropout(0.5),
-        # TODO(melfatih): Add a hyperparameter for embedding size.
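-        # The 256 -> 64 funnel below fixes the tower embedding size until the
-        # TODO above is addressed; each Dense layer is followed by dropout.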
- tf.keras.layers.Dense(units=256, activation='relu'), - tf.keras.layers.Dropout(0.5), - tf.keras.layers.Dense(units=64, activation='relu'), - tf.keras.layers.Dropout(0.5), - ]) - self.backbone = tf.keras.Sequential([backbone, dense]) - self.output_main = ( - tf.keras.layers.Dense( - units=model_params.num_classes, activation='sigmoid' - ) - ) - self.output_bias = tf.keras.layers.Dense( - model_params.num_classes, - trainable=model_params.train_bias, - activation='softmax', - name='bias', - ) - - def get_config(self): - config = super(TwoTower, self).get_config() - config.update({'model_params': self.model_params.asdict(), - 'backbone': self.backbone, - 'output_main': self.output_main, - 'output_bias': self.output_bias}) - return config - - @classmethod - def from_config(cls, config): - return cls(ModelTrainingParameters.from_dict(config['model_params'])) - - def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks - large_image, small_image = inputs['large_image'], inputs['small_image'] - - if self.model_params.num_channels == 3: - after_embed = self.backbone(large_image) - after_crop_embed = self.backbone(small_image) - combined = tf.concat([after_embed, after_crop_embed], axis=-1) - elif self.model_params.num_channels == 6: - after_embed = self.backbone(large_image[:, :, :, 3:]) - after_crop_embed = self.backbone(small_image[:, :, :, 3:]) - before_embed = self.backbone(large_image[:, :, :, :3]) - combined = tf.concat( - [before_embed, after_embed, after_crop_embed], axis=-1 - ) - - out_main = self.output_main(combined) - out_bias = self.output_bias(combined) - return {'main': out_main, 'bias': out_bias} diff --git a/experimental/shoshin/read_predictions.py b/experimental/shoshin/read_predictions.py deleted file mode 100644 index 644b39892..000000000 --- a/experimental/shoshin/read_predictions.py +++ /dev/null @@ -1,38 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Read predictions from experiment and compute bias label.""" - -import pandas as pd - -from google3.learning.deepmind.researchdata import datatables - - -def read_predictions(xid): - """Read predictions from experiment.""" - reader = datatables.Reader(f'/datatable/xid/{xid}/predictions') - df = reader.read() - df = pd.DataFrame(df, columns=df.keys()) - df_agg = df.groupby(by=['id', 'in_sample']).agg('mean').reset_index() - df_insample = df_agg.loc[df_agg['in_sample'], :] - df_outsample = df_agg.loc[~df_agg['in_sample'], :] - df_result = pd.merge(df_insample, df_outsample, on=['id']) - df_result = df_result.loc[:, ['id', 'prediction_x', 'prediction_y']] - df_result = df_result.rename(columns={ - 'prediction_x': 'prediction_insample', - 'prediction_y': 'prediction_outsample' - }) - df_result = df_result.set_index('id') - return df_result diff --git a/experimental/shoshin/sample_ids.py b/experimental/shoshin/sample_ids.py deleted file mode 100644 index 21cb5ef3a..000000000 --- a/experimental/shoshin/sample_ids.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Binary executable for generating ids to sample in next round. - -This file serves as a binary to compute the ids of samples to be included in -next round of training in an active learning loop. - - -Note: In output_dir, models trained on different splits of data must already -exist and be present in directory. -""" - -import os - -from absl import app -from absl import flags -from ml_collections import config_flags -import pandas as pd -import tensorflow as tf -import sampling_policies # local file import from experimental.shoshin -from configs import base_config # local file import from experimental.shoshin - - -FLAGS = flags.FLAGS -config_flags.DEFINE_config_file('config') - - -def main(_) -> None: - - config = FLAGS.config - base_config.check_flags(config) - bias_table = pd.read_csv(os.path.join(config.output_dir, 'bias_table.csv')) - predictions_table = pd.read_csv(os.path.join(config.output_dir, - 'predictions_table.csv')) - tf.io.gfile.makedirs(config.ids_dir) - _ = sampling_policies.sample_and_split_ids( - bias_table['example_id'].to_numpy(), - predictions_table, - config.active_sampling.sampling_score, - config.active_sampling.num_samples_per_round, - config.data.num_splits, - config.ids_dir, - True) - -if __name__ == '__main__': - app.run(main) diff --git a/experimental/shoshin/sampling_policies.py b/experimental/shoshin/sampling_policies.py deleted file mode 100644 index 0eb5b7984..000000000 --- a/experimental/shoshin/sampling_policies.py +++ /dev/null @@ -1,143 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Utilities for Introspective Active Sampling.
-
-Library of utilities for the Introspective Active Sampling method. Includes a
-function to generate a table mapping example ID to bias label, which can be
-used to train the bias output head.
-"""
-
-import os
-from typing import List
-
-import numpy as np
-import pandas as pd
-import tensorflow as tf
-
-
-def compute_ids_to_sample(
-    sampling_score: str,
-    predictions_df: pd.DataFrame,
-    num_samples: int,
-) -> List[str]:
-  """Computes ids to actively sample new labels for.
-
-  Args:
-    sampling_score: Which score to use for sampling. Currently supported
-      options are 'ensemble_uncertainty', 'ensemble_variance', 'bias',
-      'tracin', and 'random'.
-    predictions_df: Dataframe with columns `example_id`,
-      `predictions_label_{i}`, `predictions_bias_{i}`, and
-      `predictions_tracin_{i}` (the latter only if predicted tracin scores are
-      used) for i in range(k).
-    num_samples: Number of examples to sample.
-
-  Returns:
-    List of example ids to sample, ranked by the sampling score.
-  """
-  prediction_label_cols = list(
-      filter(lambda x: 'label' in x, predictions_df.columns))
-  prediction_bias_cols = list(
-      filter(lambda x: 'bias' in x, predictions_df.columns))
-  prediction_tracin_cols = list(
-      filter(lambda x: 'tracin' in x, predictions_df.columns))
-  # Lower sampling scores are sampled first (ascending sort below).
-  if sampling_score == 'ensemble_uncertainty':
-    sample_avg = predictions_df[prediction_label_cols].mean(axis=1).to_numpy()
-    uncertainty = np.abs(sample_avg - .5)
-    predictions_df['sampling_score'] = uncertainty
-  elif sampling_score == 'ensemble_variance':
-    sample_std = predictions_df[prediction_label_cols].std(axis=1).to_numpy()
-    predictions_df['sampling_score'] = 1 - sample_std
-  elif sampling_score == 'bias':
-    sample_avg = predictions_df[prediction_bias_cols].mean(axis=1).to_numpy()
-    predictions_df['sampling_score'] = 1 - sample_avg
-  elif sampling_score == 'tracin':
-    sample_avg = predictions_df[prediction_tracin_cols].mean(axis=1).to_numpy()
-    predictions_df['sampling_score'] = 1 - sample_avg
-  elif sampling_score == 'random':
-    predictions_df['sampling_score'] = np.random.random(
-        size=len(predictions_df))
-  predictions_df = predictions_df.sort_values(
-      by='sampling_score', ascending=True)
-  return predictions_df.head(num_samples)['example_id'].to_numpy()
-
-
-def sample_and_split_ids(
-    ids_train: List[str],
-    predictions_df: pd.DataFrame,
-    sampling_score: str,
-    num_samples_per_round: int,
-    num_splits: int,
-    save_dir: str,
-    save_ids: bool,
-) -> List[pd.DataFrame]:
-  """Computes ids to sample for next round and generates new training splits.
-
-  Args:
-    ids_train: ids of examples used for training so far.
-    predictions_df: A dataframe containing the predictions of the two-head
-      models for all the training samples.
-    sampling_score: The score used to rank candidates for active learning.
-    num_samples_per_round: Number of new samples to add in each round of
-      active learning.
-    num_splits: Number of splits to generate after active sampling.
-    save_dir: The directory where the splits are to be saved.
-    save_ids: A boolean indicating whether to save the ids.
-
-  Returns:
-    A list of pandas dataframes, each containing a list of example ids to be
-    included in a split for the next round of training.
-  """
-  predictions_df = predictions_df[~predictions_df['example_id'].isin(ids_train)]
-  ids_to_sample = compute_ids_to_sample(
-      sampling_score, predictions_df,
-      num_samples_per_round)
-  ids_to_sample = np.concatenate([ids_to_sample, ids_train], axis=0)
-  tf.io.gfile.makedirs(save_dir)
-
-  # Randomly permute and split set of ids to sample.
-  n_sample = ids_to_sample.size
-  order = np.random.permutation(n_sample)
-  split_idx = 0
-  num_data_per_split = int(n_sample / num_splits)
-  split_dfs = []
-  for i in range(num_splits):
-    ids_i = ids_to_sample[order[split_idx:min(split_idx + num_data_per_split,
-                                              n_sample)]]
-    split_idx += ids_i.size
-    df = pd.DataFrame({'example_id': ids_i})
-    split_dfs.append(df)
-    if save_ids:
-      df.to_csv(
-          os.path.join(save_dir, f'ids_{i}.csv'),
-          index=False)
-  return split_dfs
-
-
-def convert_ids_to_table(
-    ids_dir: str,
-) -> List[tf.lookup.StaticHashTable]:
-  """Gets static hash table representing ids in each file in ids_dir."""
-  ids_tables = []
-
-  # ids_dir is populated by the sample_and_split_ids function above.
-  for ids_file in tf.io.gfile.listdir(ids_dir):
-    ids_i = pd.read_csv(os.path.join(ids_dir, ids_file))['example_id']
-    ids_i = np.array([eval(x).decode('UTF-8') for x in ids_i.to_list()])  # pylint:disable=eval-used
-    keys = tf.convert_to_tensor(ids_i, dtype=tf.string)
-    values = tf.ones(shape=keys.shape, dtype=tf.int64)
-    init = tf.lookup.KeyValueTensorInitializer(
-        keys=keys,
-        values=values,
-        key_dtype=tf.string,
-        value_dtype=tf.int64)
-    ids_tables.append(tf.lookup.StaticHashTable(init, default_value=0))
-  return ids_tables
diff --git a/experimental/shoshin/train_tf.py b/experimental/shoshin/train_tf.py
deleted file mode 100644
index 0747cb5e7..000000000
--- a/experimental/shoshin/train_tf.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Binary to run training on a single model once.
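-
-Example invocation, as a sketch (the config path and flag values are
-hypothetical):
-
-  train_tf --config=experimental/shoshin/configs/base_config.py \
-    --config.output_dir=/tmp/shoshin/round_0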
-
-"""
-
-import logging as native_logging
-import os
-
-from absl import app
-from absl import flags
-from absl import logging
-from ml_collections import config_flags
-import pandas as pd
-import tensorflow as tf
-import data  # local file import from experimental.shoshin
-import generate_bias_table_lib  # local file import from experimental.shoshin
-import models  # local file import from experimental.shoshin
-import sampling_policies  # local file import from experimental.shoshin
-import train_tf_lib  # local file import from experimental.shoshin
-from configs import base_config  # local file import from experimental.shoshin
-
-
-FLAGS = flags.FLAGS
-config_flags.DEFINE_config_file('config')
-flags.DEFINE_bool('keep_logs', True, 'If True, creates a log file in output '
-                  'directory. If False, only logs to console.')
-flags.DEFINE_string('ensemble_dir', '', 'If specified, loads the models at '
                    'this directory to consider the ensemble.')
-
-
-def main(_) -> None:
-  config = FLAGS.config
-  base_config.check_flags(config)
-
-  if FLAGS.keep_logs and not config.training.log_to_xm:
-    if not tf.io.gfile.exists(config.output_dir):
-      tf.io.gfile.makedirs(config.output_dir)
-    stream = tf.io.gfile.GFile(
-        os.path.join(config.output_dir, 'log'), mode='w'
-    )
-    stream_handler = native_logging.StreamHandler(stream)
-    logging.get_absl_logger().addHandler(stream_handler)
-
-  dataset_builder = data.get_dataset(config.data.name)
-  ds_kwargs = {}
-  if config.data.name == 'waterbirds10k':
-    ds_kwargs = {'corr_strength': config.data.corr_strength}
-  elif config.data.name == 'skai':
-    ds_kwargs.update({
-        'tfds_dataset_name': config.data.tfds_dataset_name,
-        'data_dir': config.data.tfds_data_dir,
-        'labeled_train_pattern': config.data.labeled_train_pattern,
-        'unlabeled_train_pattern': config.data.unlabeled_train_pattern,
-        'validation_pattern': config.data.validation_pattern,
-        'use_post_disaster_only': config.data.use_post_disaster_only,
-        'load_small_images': config.data.load_small_images,
-    })
-    if config.data.use_post_disaster_only:
-      config.model.num_channels = 3
-  if config.upsampling.do_upsampling:
-    ds_kwargs.update({
-        'upsampling_lambda': config.upsampling.lambda_value,
-        'upsampling_signal': config.upsampling.signal,
-    })
-
-  logging.info('Running Round %d of Training.', config.round_idx)
-  get_split_config = lambda x: x if config.data.use_splits else 1
-  if config.round_idx == 0:
-    dataloader = dataset_builder(
-        num_splits=get_split_config(config.data.num_splits),
-        initial_sample_proportion=get_split_config(
-            config.data.initial_sample_proportion),
-        subgroup_ids=config.data.subgroup_ids,
-        subgroup_proportions=config.data.subgroup_proportions, **ds_kwargs)
-  else:
-    # If latter round, keep track of split generated in last round of active
-    # sampling.
-    dataloader = dataset_builder(config.data.num_splits,
-                                 initial_sample_proportion=1,
-                                 subgroup_ids=(),
-                                 subgroup_proportions=(),
-                                 **ds_kwargs)
-
-    # Filter each split to only have examples from example_ids_table.
-    dataloader.train_splits = [
-        dataloader.train_ds.filter(
-            generate_bias_table_lib.filter_ids_fn(ids_tab)) for
-        ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)]
-
-  model_params = models.ModelTrainingParameters(
-      model_name=config.model.name,
-      train_bias=config.train_bias,
-      num_classes=config.data.num_classes,
-      num_subgroups=dataloader.num_subgroups,
-      subgroup_sizes=dataloader.subgroup_sizes,
-      worst_group_label=dataloader.worst_group_label,
-      num_epochs=config.training.num_epochs,
-      num_channels=config.model.num_channels,
-      l2_regularization_factor=config.model.l2_regularization_factor,
-      optimizer=config.optimizer.type,
-      learning_rate=config.optimizer.learning_rate,
-      batch_size=config.data.batch_size,
-      load_pretrained_weights=config.model.load_pretrained_weights,
-      use_pytorch_style_resnet=config.model.use_pytorch_style_resnet,
-      do_reweighting=config.reweighting.do_reweighting,
-      reweighting_lambda=config.reweighting.lambda_value,
-      reweighting_signal=config.reweighting.signal
-  )
-  model_params.train_bias = config.train_bias
-  output_dir = config.output_dir
-
-  tf.io.gfile.makedirs(output_dir)
-  example_id_to_bias_table = None
-
-  if config.train_bias or (config.reweighting.do_reweighting and
-                           config.reweighting.signal == 'bias'):
-    # Bias head will be trained as well, so gets bias labels.
-    if config.path_to_existing_bias_table:
-      example_id_to_bias_table = (
-          generate_bias_table_lib.load_existing_bias_table(
-              config.path_to_existing_bias_table,
-              config.bias_head_prediction_signal,
-          )
-      )
-    else:
-      logging.error('Bias table not found.')
-      return
-  if config.data.use_splits:
-    # Training a single model on a combination of data splits.
-    included_splits_idx = [int(i) for i in config.data.included_splits_idx]
-    new_train_ds = data.gather_data_splits(included_splits_idx,
-                                           dataloader.train_splits)
-    val_ds = data.gather_data_splits(included_splits_idx, dataloader.val_splits)
-  elif config.data.use_filtering:
-    # Use filter tables to generate subsets. This allows better control over
-    # the number of trained models: the number of models is independent of the
-    # ood ratio. E.g., 10 splits with an ood ratio of 0.5 trains 252 models,
-    # while an ood ratio of 0.1 trains only 10. Using filtering we can train,
-    # say, 50 models for either of these ood ratios.
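-    # (252 = C(10, 5): one model per in-distribution combination of splits
-    # when 5 of the 10 splits are held out; C(10, 9) = 10 when 1 is held out.)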
- new_train_ds = data.filter_set( - dataloader=dataloader, - initial_sample_proportion=config.data.initial_sample_proportion, - initial_sample_seed=config.data.initial_sample_seed, - split_proportion=config.data.split_proportion, - split_id=config.data.split_id, - split_seed=config.data.split_seed, - training=True - ) - val_ds = data.filter_set( - dataloader=dataloader, - initial_sample_proportion=config.data.initial_sample_proportion, - initial_sample_seed=config.data.initial_sample_seed, - split_proportion=config.data.split_proportion, - split_id=config.data.split_id, - split_seed=config.data.split_seed, - training=False - ) - else: - raise ValueError( - 'In `config.data`, one of `(use_splits, use_filtering)` must be True.') - - dataloader.train_ds = new_train_ds - dataloader.eval_ds['val'] = val_ds - experiment_name = 'stage_2' if config.train_bias else 'stage_1' - - if config.save_train_ids: - table_name = 'training_ids_table' - ids = data.get_ids_from_dataset(dataloader.train_ds) - dict_values = {'example_id': ids} - df = pd.DataFrame(dict_values) - df.to_csv(os.path.join(output_dir, table_name + '.csv'), index=False) - # Apply batching (must apply batching only after filtering) - dataloader = data.apply_batch(dataloader, config.data.batch_size) - - _ = train_tf_lib.train_and_evaluate( - train_as_ensemble=config.train_stage_2_as_ensemble, - dataloader=dataloader, - model_params=model_params, - num_splits=config.data.num_splits, - ood_ratio=config.data.ood_ratio, - output_dir=output_dir, - experiment_name=experiment_name, - save_model_checkpoints=config.training.save_model_checkpoints, - save_best_model=config.training.save_best_model, - early_stopping=config.training.early_stopping, - ensemble_dir=FLAGS.ensemble_dir, - example_id_to_bias_table=example_id_to_bias_table) - - -if __name__ == '__main__': - app.run(main) diff --git a/experimental/shoshin/train_tf_lib.py b/experimental/shoshin/train_tf_lib.py deleted file mode 100644 index 86274d76e..000000000 --- a/experimental/shoshin/train_tf_lib.py +++ /dev/null @@ -1,923 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Training pipeline for a two-headed output model, where one head is for bias. - -Includes the model definition, which implements two-headed output using -custom losses and allows for any base model. Also provides training pipeline, -starting from compiling and initializing the model, fitting on training data, -and evaluating on provided eval datasets. 
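-
-Sketch of the intended two-stage use (function names are those defined in this
-module; datasets and parameters are placeholders):
-
-  stage1_model = run_train(train_ds, val_ds, model_params, 'stage_1')
-  # ...derive an example_id -> bias label table from stage-1 predictions...
-  stage2_model = run_train(train_ds, val_ds, model_params, 'stage_2',
-                           example_id_to_bias_table=bias_table)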
-""" - -import itertools -import os -from typing import Dict, List, Optional, Union - -from absl import logging -import numpy as np -import tensorflow as tf -import data # local file import from experimental.shoshin -import models # local file import from experimental.shoshin - - - -@tf.keras.saving.register_keras_serializable('two_headed_output_model') -class TwoHeadedOutputModel(tf.keras.Model): - """Defines a two-headed output model.""" - - def __init__(self, - model: tf.keras.Model, - num_subgroups: int, - subgroup_sizes: Dict[str, int], - train_bias: bool, - name: str, - worst_group_label: Optional[int] = 2, - do_reweighting: Optional[bool] = False, - reweighting_signal: Optional[str] = 'bias', - reweighting_lambda: Optional[float] = 0.5, - error_percentile_threshold: Optional[float] = 0.2, - num_classes: Optional[int] = 2): - super(TwoHeadedOutputModel, self).__init__(name=name) - self.num_classes = num_classes - self.train_bias = train_bias - if self.train_bias or do_reweighting: - self.id_to_bias_table = None - - self.do_reweighting = do_reweighting - if do_reweighting: - self.reweighting_signal = reweighting_signal - self.reweighting_lambda = reweighting_lambda - if self.reweighting_signal == 'error': - self.error_percentile_threshold = error_percentile_threshold - - self.model = model - self.num_subgroups = num_subgroups - if self.num_subgroups > 1: - self.avg_acc = tf.keras.metrics.Mean(name='avg_acc') - self.weighted_avg_acc = tf.keras.metrics.Sum(name='weighted_avg_acc') - self.subgroup_sizes = subgroup_sizes - self.worst_group_label = worst_group_label - - def get_config(self): - config = super().get_config() - config.update({ - 'model': self.model, - 'num_subgroups': self.num_subgroups, - 'subgroup_sizes': self.subgroup_sizes, - 'train_bias': self.train_bias, - 'name': self.name, - 'worst_group_label': self.worst_group_label - }) - return config - - def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks - return self.model(inputs) - - def update_id_to_bias_table(self, table): - self.id_to_bias_table = table - - def _compute_average_metrics( - self, metrics: List[tf.keras.metrics.Metric] - ) -> Dict[str, tf.keras.metrics.Metric]: - """Computes metrics as an average or weighted average of all subgroups. - - For the weighted metric, the subgroups are weighed by their proportionality. - - Args: - metrics: List of metrics to be parsed. - - Returns: - Dictionary mapping metric name to result. 
- """ - accs = [] - total_size = sum(self.subgroup_sizes.values()) - weighted_accs = [] - for m in metrics: - if 'subgroup' in m.name and 'main' in m.name: - accs.append(m.result()) - subgroup_label = m.name.split('_')[1] - weighted_accs.append( - m.result() * float(self.subgroup_sizes[subgroup_label]) / total_size - ) - self.avg_acc.reset_state() - self.avg_acc.update_state(accs) - self.weighted_avg_acc.reset_state() - self.weighted_avg_acc.update_state(weighted_accs) - return { - self.avg_acc.name: self.avg_acc.result(), - self.weighted_avg_acc.name: self.weighted_avg_acc.result(), - } - - def train_step(self, inputs): - features = inputs['input_feature'] - labels = inputs['label'] - example_ids = inputs['example_id'] - subgroup_labels = inputs['subgroup_label'] - - y_true_main = tf.one_hot(labels, depth=self.num_classes) - - with tf.GradientTape() as tape: - y_pred = self(features, training=True) - - y_true = {'main': y_true_main} - if self.train_bias or (self.do_reweighting and - self.reweighting_signal == 'bias'): - if self.id_to_bias_table is None: - raise ValueError('id_to_bias_table must not be None.') - y_true_bias = self.id_to_bias_table.lookup(example_ids) - y_true_bias_original = y_true_bias - y_true_bias = tf.one_hot(y_true_bias, depth=2) - y_true['bias'] = y_true_bias - - sample_weight = None - if self.do_reweighting: - if self.reweighting_signal == 'bias': - # Loads bias label from table, which has already been determined by - # threshold. - reweighting_labels = y_true_bias_original - elif self.reweighting_signal == 'error': # Use prediction error. - error = tf.math.subtract( - tf.ones_like(y_pred), tf.gather_nd(y_pred, y_true_main)) - threshold = np.percentile(error, self.error_percentile_threshold) - reweighting_labels = tf.math.greater(error, threshold) - else: # Give weight to worst group only. - reweighting_labels = tf.math.equal(subgroup_labels, - self.worst_group_label) - - above_threshold_example_multiplex = tf.math.multiply( - self.reweighting_lambda, - tf.ones_like(reweighting_labels, dtype=tf.float32)) - below_threshold_example_multiplex = tf.math.multiply( - 1. - self.reweighting_lambda, - tf.ones_like(reweighting_labels, dtype=tf.float32)) - sample_weight = tf.where( - reweighting_labels, - above_threshold_example_multiplex, - below_threshold_example_multiplex) - - total_loss = self.compiled_loss( - y_true, y_pred, sample_weight=sample_weight) - total_loss += sum(self.losses) # Regularization loss. 
-    gradients = tape.gradient(total_loss, self.model.trainable_variables)
-    self.optimizer.apply_gradients(
-        zip(gradients, self.model.trainable_variables))
-
-    for i in range(self.num_subgroups):
-      subgroup_idx = tf.where(tf.math.equal(subgroup_labels, i))
-      subgroup_pred = tf.gather(y_pred['main'], subgroup_idx, axis=0)
-      subgroup_true = tf.gather(y_true['main'], subgroup_idx, axis=0)
-      y_true['_'.join(['subgroup', str(i), 'main'])] = subgroup_true
-      y_pred['_'.join(['subgroup', str(i), 'main'])] = subgroup_pred
-      if self.train_bias:
-        subgroup_pred = tf.gather(y_pred['bias'], subgroup_idx, axis=0)
-        subgroup_true = tf.gather(y_true['bias'], subgroup_idx, axis=0)
-        y_true['_'.join(['subgroup', str(i), 'bias'])] = subgroup_true
-        y_pred['_'.join(['subgroup', str(i), 'bias'])] = subgroup_pred
-
-    self.compiled_metrics.update_state(y_true, y_pred)
-    results = {m.name: m.result() for m in self.metrics}
-    if self.num_subgroups > 1:
-      results.update(self._compute_average_metrics(self.metrics))
-
-    return results
-
-  def test_step(self, inputs):
-    features = inputs['input_feature']
-    labels = inputs['label']
-    example_ids = inputs['example_id']
-    subgroup_labels = inputs['subgroup_label']
-    y_true_main = tf.one_hot(labels, depth=self.num_classes)
-    y_pred = self(features, training=False)
-    y_true = {'main': y_true_main}
-    if self.train_bias:
-      if self.id_to_bias_table is None:
-        raise ValueError('id_to_bias_table must not be None.')
-      y_true_bias = self.id_to_bias_table.lookup(example_ids)
-      y_true['bias'] = tf.one_hot(y_true_bias, depth=2)
-
-    for i in range(self.num_subgroups):
-      subgroup_idx = tf.where(tf.math.equal(subgroup_labels, i))
-      subgroup_pred = tf.gather(y_pred['main'], subgroup_idx, axis=0)
-      subgroup_true = tf.gather(y_true['main'], subgroup_idx, axis=0)
-      y_true['_'.join(['subgroup', str(i), 'main'])] = subgroup_true
-      y_pred['_'.join(['subgroup', str(i), 'main'])] = subgroup_pred
-      if self.train_bias:
-        subgroup_pred = tf.gather(y_pred['bias'], subgroup_idx, axis=0)
-        subgroup_true = tf.gather(y_true['bias'], subgroup_idx, axis=0)
-        y_true['_'.join(['subgroup', str(i), 'bias'])] = subgroup_true
-        y_pred['_'.join(['subgroup', str(i), 'bias'])] = subgroup_pred
-
-    self.compiled_metrics.update_state(y_true, y_pred)
-    results = {m.name: m.result() for m in self.metrics}
-    if self.num_subgroups > 1:
-      results.update(self._compute_average_metrics(self.metrics))
-    return results
-
-
-def compile_model(
-    model: tf.keras.Model, model_params: models.ModelTrainingParameters
-):
-  """Compiles model with optimizer, custom loss functions, and metrics."""
-  if model_params.optimizer == 'adam':
-    optimizer = tf.keras.optimizers.Adam(
-        learning_rate=model_params.learning_rate
-    )
-  else:  # sgd
-    optimizer = tf.keras.optimizers.SGD(
-        learning_rate=model_params.learning_rate, momentum=0.9
-    )
-  loss = {
-      'main': tf.keras.losses.CategoricalCrossentropy(
-          from_logits=False, name='main'
-      )
-  }
-  loss_weights = {'main': 1}
-
-  main_metrics = [
-      tf.keras.metrics.CategoricalAccuracy(name='acc'),
-      tf.keras.metrics.AUC(name='auc'),
-  ]
-  for i in range(model_params.num_classes):
-    main_metrics.append(
-        metrics_lib.OneVsRest(
-            tf.keras.metrics.AUC(name=f'auroc_{i}_vs_rest'), i
-        )
-    )
-    main_metrics.append(
-        metrics_lib.OneVsRest(
-            tf.keras.metrics.AUC(name=f'aucpr_{i}_vs_rest', curve='PR'),
-            i,
-        )
-    )
-  metrics = {'main': main_metrics}
-  if model_params.train_bias:
-    metrics['bias'] = [
-        tf.keras.metrics.CategoricalAccuracy(name='acc'),
-        tf.keras.metrics.AUC(name='auc'),
-    ]
-    loss['bias'] = tf.keras.losses.CategoricalCrossentropy(
-        from_logits=False, name='bias'
-    )
-    loss_weights['bias'] = 1
-  for i in range(model_params.num_subgroups):
-    metrics.update(
-        {
-            '_'.join(['subgroup', str(i), 'main']): [
-                tf.keras.metrics.CategoricalAccuracy(name='acc'),
-            ]
-        }
-    )
-    if model.train_bias:
-      metrics.update(
-          {
-              '_'.join(['subgroup', str(i), 'bias']): [
-                  tf.keras.metrics.CategoricalAccuracy(name='acc'),
-                  tf.keras.metrics.AUC(name='auc'),
-              ]
-          }
-      )
-  model.compile(
-      optimizer=optimizer, loss=loss, loss_weights=loss_weights, metrics=metrics
-  )
-  return model
-
-
-def evaluate_model(
-    model: tf.keras.Model,
-    output_dir: str,
-    eval_ds: Dict[str, tf.data.Dataset],
-    save_model_checkpoints: bool = False,
-    save_best_model: bool = True,
-):
-  """Evaluates model on given validation and/or test datasets.
-
-  Args:
-    model: Keras model to be evaluated.
-    output_dir: Path to directory where model is saved.
-    eval_ds: Dictionary mapping evaluation dataset name to the dataset.
-    save_model_checkpoints: Boolean for saving checkpoints during training.
-    save_best_model: Boolean for saving best model during training.
-  """
-  checkpoint_dir = os.path.join(output_dir, 'checkpoints')
-  if save_model_checkpoints and tf.io.gfile.listdir(checkpoint_dir):
-    best_latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
-    load_status = model.load_weights(best_latest_checkpoint)
-    load_status.assert_consumed()
-    for ds_name in eval_ds.keys():
-      results = model.evaluate(
-          eval_ds[ds_name], return_dict=True)
-      logging.info('Evaluation Dataset Name: %s', ds_name)
-      logging.info('Main Acc: %f', results['main_acc'])
-      logging.info('Main AUC: %f', results['main_auc'])
-      if model.train_bias:
-        logging.info('Bias Acc: %f', results['bias_acc'])
-        logging.info('Bias AUC: %f', results['bias_auc'])
-      if model.num_subgroups > 1:
-        for i in range(model.num_subgroups):
-          logging.info('Subgroup %d Acc: %f', i,
-                       results[f'subgroup_{i}_main_acc'])
-        logging.info('Average Acc: %f', results['avg_acc'])
-        logging.info('Weighted Average Acc: %f', results['weighted_avg_acc'])
-  if save_best_model:
-    model_dir = os.path.join(output_dir, 'model')
-    loaded_model = tf.keras.models.load_model(model_dir)
-    compiled_model = compile_model(
-        loaded_model, loaded_model.model.model_params
-    )
-    results = compiled_model.evaluate(
-        eval_ds['val'],
-        return_dict=True,
-    )
-    logging.info(results)
-
-
-def init_model(
-    model_params: models.ModelTrainingParameters,
-    experiment_name: str,
-    example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None,
-) -> tf.keras.Model:
-  """Initializes a TwoHeadedOutputModel with a base model.
-
-  Args:
-    model_params: Dataclass object containing model and training parameters.
-    experiment_name: String describing experiment, used as the model name.
-    example_id_to_bias_table: Hash table mapping example ID to bias label.
-
-  Returns:
-    Initialized TwoHeadedOutputModel model.
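-
-  Example (a sketch; `params` stands in for a populated
-  ModelTrainingParameters):
-
-    model = init_model(model_params=params, experiment_name='stage_1')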
- """ - model_class = models.get_model(model_params.model_name) - base_model = model_class(model_params=model_params) - - two_head_model = TwoHeadedOutputModel( - model=base_model, - num_subgroups=model_params.num_subgroups, - subgroup_sizes=model_params.subgroup_sizes, - worst_group_label=model_params.worst_group_label, - train_bias=model_params.train_bias, - name=experiment_name, - do_reweighting=model_params.do_reweighting, - reweighting_signal=model_params.reweighting_signal, - reweighting_lambda=model_params.reweighting_lambda, - error_percentile_threshold=model_params - .reweighting_error_percentile_threshold, - num_classes=model_params.num_classes) - - if model_params.train_bias or model_params.do_reweighting: - if example_id_to_bias_table: - two_head_model.update_id_to_bias_table(example_id_to_bias_table) - - two_head_model = compile_model(two_head_model, model_params) - return two_head_model - - -def create_callbacks( - output_dir: str, - save_model_checkpoints: bool = False, - save_best_model: bool = True, - early_stopping: bool = True, - batch_size: Optional[int] = 64, - num_train_examples: Optional[int] = None, -) -> List[tf.keras.callbacks.Callback]: - """Creates callbacks, such as saving model checkpoints, for training. - - Args: - output_dir: Directory where model will be saved. - save_model_checkpoints: Boolean for whether or not to save checkpoints. - save_best_model: Boolean for whether or not to save best model. - early_stopping: Boolean for whether or not to use early stopping during - training. - batch_size: Optional integer for batch size. - num_train_examples: Optional integer for total number of training examples. - - Returns: - List of callbacks. - """ - callbacks = [] - if save_model_checkpoints: - checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - filepath=os.path.join( - os.path.join(output_dir, 'checkpoints'), - 'epoch-{epoch:02d}-val_auc-{val_main_auc:.2f}.ckpt'), - monitor='val_main_auc', - mode='max', - save_weights_only=True, - save_best_only=True) - callbacks.append(checkpoint_callback) - if save_best_model: - model_dir = os.path.join(output_dir, 'model') - # TODO(jihyeonlee,melfatih): Update to AUPRC. - model_callback = tf.keras.callbacks.ModelCheckpoint( - filepath=os.path.join( - model_dir, - 'aucpr-{val_main_aucpr_1_vs_rest:.2f}'), - monitor='val_main_aucpr_1_vs_rest', - mode='max', - save_weights_only=False, - save_best_only=True, - save_traces=True) - callbacks.append(model_callback) - if early_stopping: - early_stopping_callback = tf.keras.callbacks.EarlyStopping( - monitor='val_main_aucpr_1_vs_rest', - min_delta=0.001, - patience=30, - verbose=1, - mode='max', - baseline=None, - restore_best_weights=True - ) - callbacks.append(early_stopping_callback) - return callbacks - - -def run_train( - train_ds: tf.data.Dataset, - val_ds: tf.data.Dataset, - model_params: models.ModelTrainingParameters, - experiment_name: str, - callbacks: Optional[List[tf.keras.callbacks.Callback]] = None, - example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None -) -> tf.keras.Model: - """Initializes and trains model on given training and validation data. - - Args: - train_ds: Training dataset. - val_ds: Evaluation dataset. - model_params: Dataclass object containing model and training parameters. - experiment_name: String to describe model being trained. - callbacks: Keras Callbacks, like saving checkpoints or early stopping. - example_id_to_bias_table: Hash table mapping example ID to bias label. - - Returns: - Trained model. 
- """ - two_head_model = init_model( - model_params=model_params, - experiment_name=experiment_name, - example_id_to_bias_table=example_id_to_bias_table - ) - - two_head_model.fit( - train_ds, - validation_data=val_ds, - epochs=model_params.num_epochs, - callbacks=callbacks) - return two_head_model - - -def train_ensemble( - dataloader: data.Dataloader, - model_params: models.ModelTrainingParameters, - num_splits: int, - ood_ratio: float, - output_dir: str, - save_model_checkpoints: bool = True, - early_stopping: bool = True, - example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None -) -> List[tf.keras.Model]: - """Trains an ensemble of models, locally. See xm_launch.py for parallelized. - - Args: - dataloader: Dataclass object containing training and validation data. - model_params: Dataclass object containing model and training parameters. - num_splits: Integer number for total slices of dataset. - ood_ratio: Float for the ratio of slices that will be considered - out-of-distribution. - output_dir: String for directory path where checkpoints will be saved. - save_model_checkpoints: Boolean for saving checkpoints during training. - early_stopping: Boolean for early stopping during training. - example_id_to_bias_table: Hash table mapping example ID to bias label. - - Returns: - List of trained models and, optionally, predictions. - """ - num_ood_splits = int(num_splits * ood_ratio) - num_id_splits = num_splits - num_ood_splits - train_idx_combos = [ - list(c) for c in list( - itertools.combinations(range(num_splits), num_id_splits)) - ] - ensemble = [] - for combo in train_idx_combos: - combo_name = '_'.join(map(str, combo)) - combo_train = data.gather_data_splits(combo, dataloader.train_splits) - combo_val = data.gather_data_splits(combo, dataloader.val_splits) - combo_ckpt_dir = os.path.join(output_dir, combo_name, 'checkpoints') - combo_callbacks = create_callbacks(combo_ckpt_dir, save_model_checkpoints, - early_stopping) - combo_model = run_train( - combo_train, - combo_val, - model_params=model_params, - experiment_name=combo_name, - callbacks=combo_callbacks, - example_id_to_bias_table=example_id_to_bias_table) - ensemble.append(combo_model) - return ensemble - - -def find_epoch_ckpt_path(epoch: int, - ckpt_dir: str, - metric_name: str = 'val_auc', - mode: str = 'highest') -> Union[str, List[str]]: - r"""Finds the checkpoints for a given epoch. - - This function extracts the checkpoints corresponding to a given epoch under - `ckpt_dir`. It assumes the checkpoints follows the naming convention: - - `{ckpt_dir}\epoch-{epoch}-{metric_name}-{metric_val}.ckpt` - - If a checkpoint for a given epoch is not found, it will issue a warning and - return the checkpoint for the nearest epoch instead. - - Args: - epoch: The epoch to exact checkpoints for. - ckpt_dir: The directory of checkpoints. - metric_name: The name of the performance metric. - mode: The return mode. One of ('highest', 'lowest', 'all'). Here, 'highest' - / 'lowest' means if there are multiple checkpoint for the required epoch, - return the checkpoint with the highest / lowest value for the metric. - - Returns: - Strings for checkpoint directories. - """ - if mode not in ('highest', 'lowest', 'all'): - raise ValueError( - f'mode `{mode}` not supported. Should be one of ("best", "all").') - - # Collects checkpoint names. 
-
-
-def find_epoch_ckpt_path(epoch: int,
-                         ckpt_dir: str,
-                         metric_name: str = 'val_auc',
-                         mode: str = 'highest') -> Union[str, List[str]]:
-  """Finds the checkpoints for a given epoch.
-
-  This function extracts the checkpoints corresponding to a given epoch under
-  `ckpt_dir`. It assumes the checkpoints follow the naming convention:
-
-  `{ckpt_dir}/epoch-{epoch}-{metric_name}-{metric_val}.ckpt`
-
-  If a checkpoint for a given epoch is not found, it will issue a warning and
-  return the checkpoint for the nearest epoch instead.
-
-  Args:
-    epoch: The epoch to extract checkpoints for.
-    ckpt_dir: The directory of checkpoints.
-    metric_name: The name of the performance metric.
-    mode: The return mode. One of ('highest', 'lowest', 'all'). Here, 'highest'
-      / 'lowest' means if there are multiple checkpoints for the required
-      epoch, return the checkpoint with the highest / lowest value for the
-      metric.
-
-  Returns:
-    A checkpoint path, or a list of paths when mode is 'all'.
-  """
-  if mode not in ('highest', 'lowest', 'all'):
-    raise ValueError(
-        f'mode `{mode}` not supported. Should be one of '
-        '("highest", "lowest", "all").')
-
-  # Collects checkpoint names.
-  ckpt_names = [
-      f_name.split('.ckpt')[0]
-      for f_name in tf.io.gfile.listdir(ckpt_dir)
-      if '.ckpt.index' in f_name
-  ]
-
-  if not ckpt_names:
-    raise ValueError(f'No valid checkpoint under the directory {ckpt_dir}.')
-
-  # Extracts epoch numbers and metric values.
-  ckpt_epochs = np.array(
-      [int(f_name.split('epoch-')[1].split('-')[0]) for f_name in ckpt_names])
-  ckpt_metric = np.array([
-      float(f_name.split(f'{metric_name}-')[1].split('-')[0])
-      for f_name in ckpt_names
-  ])
-
-  if epoch not in ckpt_epochs:
-    # Uses the nearest available epoch in `ckpt_epochs`.
-    nearest_epoch_id = np.argmin(np.abs(ckpt_epochs - epoch))
-    nearest_epoch = ckpt_epochs[nearest_epoch_id]
-    tf.compat.v1.logging.warn(
-        'Required epoch (%s) not in list of available epochs `%s`. '
-        'Using nearest epoch `%s`.', epoch, np.unique(ckpt_epochs),
-        nearest_epoch)
-    epoch = nearest_epoch
-
-  make_ckpt_path = lambda name: os.path.join(ckpt_dir, name + '.ckpt')
-  # Masks out checkpoints from other epochs so that the argmax / argmin below
-  # can only select checkpoints of the required epoch.
-  if mode == 'highest':
-    # Returns the checkpoint with the highest metric value.
-    masked_metric = np.where(ckpt_epochs == epoch, ckpt_metric, -np.inf)
-    return make_ckpt_path(ckpt_names[np.argmax(masked_metric)])
-  elif mode == 'lowest':
-    # Returns the checkpoint with the lowest metric value.
-    masked_metric = np.where(ckpt_epochs == epoch, ckpt_metric, np.inf)
-    return make_ckpt_path(ckpt_names[np.argmin(masked_metric)])
-  else:
-    # Returns all the checkpoints for the epoch.
-    ckpt_ids = np.where(ckpt_epochs == epoch)[0]
-    return [make_ckpt_path(ckpt_names[ckpt_id]) for ckpt_id in ckpt_ids]
-
-
-def load_trained_models(combos_dir: str,
-                        model_params: models.ModelTrainingParameters,
-                        ckpt_epoch: int = -1):
-  """Loads models trained on different combinations of data splits.
-
-  Args:
-    combos_dir: Path to the checkpoints trained on different data splits.
-    model_params: Model config.
-    ckpt_epoch: The epoch to load the checkpoint from. If negative, load the
-      latest checkpoint.
-
-  Returns:
-    The list of loaded models for different combinations of data splits.
-  """
-  trained_models = []
-  for combo_name in tf.io.gfile.listdir(combos_dir):
-    ckpt_dir = os.path.join(combos_dir, combo_name, 'checkpoints')
-    if ckpt_epoch < 0:
-      # Loads the latest checkpoint.
-      checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
-      tf.compat.v1.logging.info(f'Loading best model from `{checkpoint_path}`')
-    else:
-      # Loads the requested checkpoint. By default, selects the checkpoint
-      # with the highest validation AUC.
-      checkpoint_path = find_epoch_ckpt_path(
-          ckpt_epoch, ckpt_dir, metric_name='val_auc', mode='highest')
-      tf.compat.v1.logging.info(
-          f'Loading model for epoch {ckpt_epoch} from `{checkpoint_path}`')
-
-    combo_model = load_one_checkpoint(checkpoint_path=checkpoint_path,
-                                      model_params=model_params,
-                                      experiment_name=combo_name)
-    trained_models.append(combo_model)
-  return trained_models
-
-
-def load_one_checkpoint(
-    checkpoint_path: str,
-    model_params: models.ModelTrainingParameters,
-    experiment_name: str,
-) -> tf.keras.Model:
-  """Loads a model checkpoint.
-
-  Args:
-    checkpoint_path: Path to checkpoint.
-    model_params: Model training parameters.
-    experiment_name: Name of experiment.
-
-  Returns:
-    A model checkpoint.
-  """
-  if not tf.io.gfile.exists(checkpoint_path + '.index'):
-    raise ValueError(
-        f'Required checkpoint file `{checkpoint_path}` does not exist.')
-
-  model = init_model(
-      model_params=model_params,
-      experiment_name=experiment_name)
-  load_status = model.load_weights(checkpoint_path)
-  # Optimizer will not be loaded (https://b.corp.google.com/issues/124099628),
-  # so expect only a partial load. This is not currently an issue because the
-  # model is only used for inference.
-  load_status.expect_partial()
-  load_status.assert_existing_objects_matched()
-  return model
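The loaders above recover epoch numbers and metric values by string-splitting checkpoint file names, so a small worked example of that convention may help. The file name here is illustrative, not from the patch:

# Illustrative parse of the naming convention assumed by find_epoch_ckpt_path,
# matching the ModelCheckpoint template 'epoch-{epoch:02d}-val_auc-{...:.2f}'.
ckpt_name = 'epoch-07-val_auc-0.92'  # listdir entry, '.ckpt.index' stripped

epoch = int(ckpt_name.split('epoch-')[1].split('-')[0])
metric = float(ckpt_name.split('val_auc-')[1].split('-')[0])
assert (epoch, metric) == (7, 0.92)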
-
-
-# TODO(martinstrobel): Merge this function with `find_epoch_ckpt_path`.
-def generate_checkpoint_list(
-    checkpoint_dir: str,
-    checkpoint_list: Optional[List[str]] = None,
-    checkpoint_selection: Optional[str] = 'first',
-    checkpoint_number: Optional[int] = 5,
-    checkpoint_name: Optional[str] = '',
-) -> Optional[List[str]]:
-  """Creates a list of checkpoints to load.
-
-  Args:
-    checkpoint_dir: Path to the checkpoint directory.
-    checkpoint_list: List of checkpoint names (only used when
-      checkpoint_selection is 'list').
-    checkpoint_selection: Mode of how to select checkpoints.
-      'first': Select the first x checkpoints by epoch.
-      'last': Select the last x checkpoints by epoch.
-      'spread': Select x checkpoints spread out evenly over all epochs.
-      'all': Select all checkpoints.
-      'list': Select the checkpoints provided in checkpoint_list.
-      'name': Select the named checkpoint.
-    checkpoint_number: Number of checkpoints returned (only used when
-      checkpoint_selection is 'first', 'last', or 'spread').
-    checkpoint_name: Name of a single checkpoint (only used when
-      checkpoint_selection is 'name').
-
-  Returns:
-    List of checkpoints to load.
-  """
-  if checkpoint_selection != 'list':
-    ckpts_names = tf.io.gfile.listdir(checkpoint_dir)
-    ckpts_names = list(filter(lambda x: '.ckpt.index' in x, ckpts_names))
-    # Strips the '.index' suffix to recover the checkpoint prefix.
-    ckpts_names = list(map(lambda x: x[:-6], ckpts_names))
-    epochs = [int(ckpt_name.split('-')[1]) for ckpt_name in ckpts_names]
-    sorted_ckpts_names = [ckpts_names[i] for i in np.argsort(epochs)]
-    if checkpoint_selection == 'first':
-      checkpoint_list = sorted_ckpts_names[:checkpoint_number]
-    elif checkpoint_selection == 'last' and checkpoint_number:
-      checkpoint_list = sorted_ckpts_names[-checkpoint_number:]
-    elif checkpoint_selection == 'spread':
-      # Guards against a zero step when there are fewer checkpoints than
-      # checkpoint_number.
-      step = max(1, len(sorted_ckpts_names) // checkpoint_number)
-      checkpoint_list = sorted_ckpts_names[::step]
-    elif checkpoint_selection == 'all':
-      checkpoint_list = sorted_ckpts_names
-    elif checkpoint_selection == 'name':
-      checkpoint_list = [str(checkpoint_name)]
-  return checkpoint_list
-
-
-# TODO(martinstrobel): Merge this function with `load_trained_models`.
-def load_model_checkpoints(checkpoint_dir: str,
-                           model_params: models.ModelTrainingParameters,
-                           checkpoint_list: Optional[List[str]],
-                           checkpoint_selection: Optional[str] = 'first',
-                           checkpoint_number: Optional[int] = 5,
-                           checkpoint_name: Optional[str] = '',
-                           ) -> List[tf.keras.Model]:
-  """Loads model checkpoints from a given checkpoint directory.
-
-  Args:
-    checkpoint_dir: Path to the checkpoint directory.
-    model_params: Model training parameters.
-    checkpoint_list: List of checkpoint names (only used when
-      checkpoint_selection is 'list').
-    checkpoint_selection: Mode of how to select checkpoints.
-      'first': Select the first x checkpoints by epoch.
-      'last': Select the last x checkpoints by epoch.
-      'spread': Select x checkpoints spread out evenly over all epochs.
-      'all': Select all checkpoints.
-      'list': Select the checkpoints provided in checkpoint_list.
-      'name': Select the named checkpoint.
-    checkpoint_number: Number of checkpoints returned (only used when
-      checkpoint_selection is 'first', 'last', or 'spread').
-    checkpoint_name: Name of a single checkpoint (only used when
-      checkpoint_selection is 'name').
-
-  Returns:
-    A list of model checkpoints.
-  """
-  checkpoint_list = generate_checkpoint_list(checkpoint_dir, checkpoint_list,
-                                             checkpoint_selection,
-                                             checkpoint_number,
-                                             checkpoint_name)
-  checkpoints = []
-  for checkpoint in checkpoint_list:
-    ckpt_path = os.path.join(checkpoint_dir, checkpoint)
-    ckpt = load_one_checkpoint(checkpoint_path=ckpt_path,
-                               model_params=model_params,
-                               experiment_name='')
-    checkpoints.append(ckpt)
-  return checkpoints
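A toy sketch of the checkpoint_selection modes over plain strings (not part of the original module; fake names, no filesystem access):

ckpt_names = [f'epoch-{e:02d}-val_auc-0.90.ckpt' for e in range(1, 11)]
n = 5

first = ckpt_names[:n]       # epochs 1-5
last = ckpt_names[-n:]       # epochs 6-10
step = max(1, len(ckpt_names) // n)
spread = ckpt_names[::step]  # epochs 1, 3, 5, 7, 9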
- """ - checkpoint_list = generate_checkpoint_list(checkpoint_dir, checkpoint_list, - checkpoint_selection, - checkpoint_number, checkpoint_name) - checkpoints = [] - for checkpoint in checkpoint_list: - ckpt_path = os.path.join(checkpoint_dir, checkpoint) - ckpt = load_one_checkpoint(checkpoint_path=ckpt_path, - model_params=model_params, - experiment_name='') - checkpoints.append(ckpt) - return checkpoints - - -def eval_ensemble( - dataloader: data.Dataloader, - ensemble: List[tf.keras.Model], - example_id_to_bias_table: tf.lookup.StaticHashTable): - """Calculates the average predictions of the ensemble for evaluation. - - Args: - dataloader: Dataclass object containing training and validation data. - ensemble: List of trained models. - example_id_to_bias_table: Hash table mapping example ID to bias label. - """ - for ds_name in dataloader.eval_ds.keys(): - test_examples = dataloader.eval_ds[ds_name] - y_pred_main = [] - y_pred_bias = [] - for model in ensemble: - ensemble_prob_samples = model.predict( - test_examples.map(lambda x: x['input_feature'])) - y_pred_main.append(ensemble_prob_samples['main']) - y_pred_bias.append(ensemble_prob_samples['bias']) - y_pred_main = tf.reduce_mean(y_pred_main, axis=0) - y_pred_bias = tf.reduce_mean(y_pred_bias, axis=0) - y_true_main = list(test_examples.map( - lambda x: x['label']).as_numpy_iterator()) - y_true_main = tf.concat(y_true_main, axis=0) - y_true_main = tf.convert_to_tensor(y_true_main, dtype=tf.int64) - y_true_main = tf.one_hot(y_true_main, depth=2) - example_ids = list(test_examples.map( - lambda x: x['example_id']).as_numpy_iterator()) - example_ids = tf.concat(example_ids, axis=0) - example_ids = tf.convert_to_tensor(example_ids, dtype=tf.string) - y_true_bias = example_id_to_bias_table.lookup(example_ids) - y_true_bias = tf.one_hot(y_true_bias, depth=2) - for m in ensemble[0].metrics: - m.reset_state() - ensemble[0].compiled_metrics.update_state({ - 'main': y_true_main, - 'bias': y_true_bias - }, { - 'main': y_pred_main, - 'bias': y_pred_bias - }) - result = {m.name: m.result() for m in ensemble[0].metrics} - logging.info('Evaluation Dataset Name: %s', ds_name) - logging.info('Main Acc: %f', result['main_acc']) - logging.info('Main AUC: %f', result['main_auc']) - # TODO(jihyeonlee): Bias labels are not calculated for other evaluation - # datasets beyond validation, e.g. 'test' or 'test2'. - # Provide way to save the predictions themselves. - logging.info('Bias Acc: %f', result['bias_acc']) - logging.info('Bias AUC: %f', result['bias_auc']) - - -def run_ensemble( - dataloader: data.Dataloader, - model_params: models.ModelTrainingParameters, - num_splits: int, - ood_ratio: float, - output_dir: str, - save_model_checkpoints: bool = True, - early_stopping: bool = True, - ensemble_dir: Optional[str] = '', - example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None -) -> List[tf.keras.Model]: - """Trains an ensemble of models and optionally gets their average predictions. - - Args: - dataloader: Dataclass object containing training and validation data. - model_params: Dataclass object containing model and training parameters. - num_splits: Integer number for total slices of dataset. - ood_ratio: Float for the ratio of slices that will be considered - out-of-distribution. - output_dir: String for directory path where checkpoints will be saved. - save_model_checkpoints: Boolean for saving checkpoints during training. - early_stopping: Boolean for early stopping during training. 
-    ensemble_dir: Optional string for a directory that stores trained model
-      checkpoints. If specified, will load the models from directory.
-    example_id_to_bias_table: Hash table mapping example ID to bias label.
-
-  Returns:
-    List of trained models.
-  """
-  if ensemble_dir:
-    ensemble = load_trained_models(ensemble_dir, model_params)
-  else:
-    ensemble = train_ensemble(dataloader, model_params, num_splits, ood_ratio,
-                              output_dir, save_model_checkpoints,
-                              early_stopping, example_id_to_bias_table)
-  if dataloader.eval_ds and example_id_to_bias_table:
-    eval_ensemble(dataloader, ensemble, example_id_to_bias_table)
-
-  return ensemble
-
-
-def train_and_evaluate(
-    train_as_ensemble: bool,
-    dataloader: data.Dataloader,
-    model_params: models.ModelTrainingParameters,
-    num_splits: int,
-    ood_ratio: float,
-    output_dir: str,
-    experiment_name: str,
-    save_model_checkpoints: bool,
-    save_best_model: bool,
-    early_stopping: bool,
-    ensemble_dir: Optional[str] = '',
-    example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None):
-  """Trains, optionally as an ensemble, and evaluates the model(s).
-
-  Args:
-    train_as_ensemble: Boolean for whether or not to train an ensemble of
-      models. Also performs evaluation for ensemble.
-    dataloader: Dataclass object containing training and validation data.
-    model_params: Dataclass object containing model parameters.
-    num_splits: Integer for number of data splits.
-    ood_ratio: Float for ratio of splits to consider as out-of-distribution.
-    output_dir: Path to directory where model will be saved.
-    experiment_name: String describing experiment.
-    save_model_checkpoints: Boolean for saving checkpoints during training.
-    save_best_model: Boolean for saving best model during training.
-    early_stopping: Boolean for early stopping during training.
-    ensemble_dir: Optional string for a directory that stores trained model
-      checkpoints. If specified, will load the models from directory.
-    example_id_to_bias_table: Lookup table mapping example ID to bias label.
-
-  Returns:
-    Trained model(s).
-  """
-  if train_as_ensemble:
-    return run_ensemble(
-        dataloader=dataloader,
-        model_params=model_params,
-        num_splits=num_splits,
-        ood_ratio=ood_ratio,
-        output_dir=output_dir,
-        save_model_checkpoints=save_model_checkpoints,
-        early_stopping=early_stopping,
-        ensemble_dir=ensemble_dir,
-        example_id_to_bias_table=example_id_to_bias_table)
-  else:
-    callbacks = create_callbacks(
-        output_dir,
-        save_model_checkpoints,
-        save_best_model,
-        early_stopping,
-        model_params.batch_size,
-        dataloader.num_train_examples)
-
-    two_head_model = run_train(
-        dataloader.train_ds,
-        dataloader.eval_ds['val'],
-        model_params=model_params,
-        experiment_name=experiment_name,
-        callbacks=callbacks,
-        example_id_to_bias_table=example_id_to_bias_table)
-    evaluate_model(two_head_model, output_dir, dataloader.eval_ds,
-                   save_model_checkpoints, save_best_model)
-    return two_head_model
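For orientation before the next deleted file, a condensed sketch of a typical call into this entry point, mirroring the invocation in train_tf_sequential_active.py below; `config`, `dataloader`, `model_params`, and the bias table are assumed to be constructed as in that binary:

model = train_tf_lib.train_and_evaluate(
    train_as_ensemble=config.train_stage_2_as_ensemble,
    dataloader=dataloader,
    model_params=model_params,
    num_splits=config.data.num_splits,
    ood_ratio=config.data.ood_ratio,
    output_dir=config.output_dir,
    experiment_name='stage_2',
    save_model_checkpoints=config.training.save_model_checkpoints,
    save_best_model=config.training.save_best_model,
    early_stopping=config.training.early_stopping,
    example_id_to_bias_table=example_id_to_bias_table)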
diff --git a/experimental/shoshin/train_tf_sequential_active.py b/experimental/shoshin/train_tf_sequential_active.py
deleted file mode 100644
index d6a475943..000000000
--- a/experimental/shoshin/train_tf_sequential_active.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Binary to run active sampling locally in a sequential manner."""
-
-import logging as native_logging
-import os
-
-from absl import app
-from absl import flags
-from absl import logging
-from ml_collections import config_flags
-import tensorflow as tf
-import data  # local file import from experimental.shoshin
-import generate_bias_table_lib  # local file import from experimental.shoshin
-import models  # local file import from experimental.shoshin
-import sampling_policies  # local file import from experimental.shoshin
-import train_tf_lib  # local file import from experimental.shoshin
-from configs import base_config  # local file import from experimental.shoshin
-
-
-FLAGS = flags.FLAGS
-config_flags.DEFINE_config_file('config')
-flags.DEFINE_bool('keep_logs', True, 'If True, creates a log file in output '
-                  'directory. If False, only logs to console.')
-flags.DEFINE_string('ensemble_dir', '', 'If specified, loads the models at '
                    'this directory to consider the ensemble.')
-
-
-def main(_) -> None:
-  config = FLAGS.config
-  base_config.check_flags(config)
-
-  dataset_builder = data.get_dataset(config.data.name)
-  if FLAGS.keep_logs:
-    tf.io.gfile.makedirs(config.output_dir)
-    stream = tf.io.gfile.GFile(os.path.join(config.output_dir, 'log'),
-                               mode='w')
-    stream_handler = native_logging.StreamHandler(stream)
-    logging.get_absl_logger().addHandler(stream_handler)
-
-  logging.info(config)
-
-  output_dir = config.output_dir
-  ids_dir = ''
-  # Train only the main task, without a bias head or rounds of active
-  # learning, when a single model is requested.
-  if config.train_single_model:
-    num_rounds = 1
-  else:
-    num_rounds = config.num_rounds
-
-  # If config.round_idx < 0, go through all rounds; otherwise, only go through
-  # the config.round_idx-th round.
-  round_ids = range(num_rounds)
-  if config.round_idx >= 0:
-    round_ids = [config.round_idx]
-
-  for round_idx in round_ids:
-    logging.info('Running Round %d of Training.', round_idx)
-    if round_idx == 0:
-      # In the initial round, randomly sample initial_sample_proportion of the
-      # training data.
-      dataloader = dataset_builder(config.data.num_splits,
-                                   config.initial_sample_proportion,
-                                   config.subgroup_ids,
-                                   config.subgroup_proportions)
-    else:
-      # In later rounds, keep track of the splits generated in the last round
-      # of active sampling.
-      dataloader = dataset_builder(config.data.num_splits,
-                                   1,
-                                   None, None)
-      # Filter each split to only have examples from example_ids_table.
-      dataloader.train_splits = [
-          dataloader.train_ds.filter(
-              generate_bias_table_lib.filter_ids_fn(example_ids_tab))
-          for example_ids_tab in sampling_policies.convert_ids_to_table(
-              ids_dir)
-      ]
-    # Apply batching (must apply batching only after filtering).
-    dataloader = data.apply_batch(dataloader, config.data.batch_size)
-    model_params = models.ModelTrainingParameters(
-        model_name=config.model.name,
-        train_bias=config.train_bias,
-        num_classes=config.data.num_classes,
-        num_subgroups=dataloader.num_subgroups,
-        subgroup_sizes=dataloader.subgroup_sizes,
-        num_epochs=config.training.num_epochs,
-        learning_rate=config.optimizer.learning_rate,
-    )
-    if not config.train_single_model:
-      output_dir = os.path.join(config.output_dir, f'round_{round_idx}')
-      ids_dir = os.path.join(output_dir, 'ids')
-    tf.io.gfile.makedirs(output_dir)
-
-    # Ensure the bias table name is bound even when no bias head is trained.
-    example_id_to_bias_table = None
-    if config.train_bias:
-      # The bias head will be trained as well, so get bias labels.
-      if config.path_to_existing_bias_table:
-        example_id_to_bias_table = (
-            generate_bias_table_lib.load_existing_bias_table(
-                config.path_to_existing_bias_table
-            )
-        )
-      else:
-        logging.info(
-            'Training models on different splits of data to calculate bias...')
-        model_params.train_bias = False
-        combos_dir = os.path.join(output_dir,
-                                  generate_bias_table_lib.COMBOS_SUBDIR)
-        _ = train_tf_lib.run_ensemble(
-            dataloader=dataloader,
-            model_params=model_params,
-            num_splits=config.data.num_splits,
-            ood_ratio=config.data.ood_ratio,
-            output_dir=combos_dir,
-            save_model_checkpoints=config.training.save_model_checkpoints,
-            early_stopping=config.training.early_stopping)
-        trained_models = train_tf_lib.load_trained_models(
-            combos_dir, model_params
-        )
-        example_id_to_bias_table = (
-            generate_bias_table_lib.get_example_id_to_bias_label_table(
-                dataloader=dataloader,
-                combos_dir=combos_dir,
-                trained_models=trained_models,
-                num_splits=config.data.num_splits,
-                bias_value_threshold=config.bias_value_threshold,
-                tracin_value_threshold=config.tracin_value_threshold,
-                bias_percentile_threshold=config.bias_percentile_threshold,
-                tracin_percentile_threshold=config.tracin_percentile_threshold,
-                save_dir=output_dir,
-                save_table=config.save_bias_table,
-            )
-        )
-      model_params.train_bias = config.train_bias
-    if config.train_bias and config.data.included_splits_idx:
-      # Likely training a single model on a combination of data splits.
- included_splits_idx = [int(i) for i in config.data.included_splits_idx] - train_ds = data.gather_data_splits(included_splits_idx, - dataloader.train_splits) - val_ds = data.gather_data_splits(included_splits_idx, - dataloader.val_splits) - dataloader.train_ds = train_ds - dataloader.eval_ds['val'] = val_ds - - trained_stagetwo_models = train_tf_lib.train_and_evaluate( - train_as_ensemble=config.train_stage_2_as_ensemble, - dataloader=dataloader, - model_params=model_params, - num_splits=config.data.num_splits, - ood_ratio=config.data.ood_ratio, - output_dir=output_dir, - experiment_name='stage_2', - save_model_checkpoints=config.training.save_model_checkpoints, - save_best_model=config.training.save_best_model, - early_stopping=config.training.early_stopping, - ensemble_dir=FLAGS.ensemble_dir, - example_id_to_bias_table=example_id_to_bias_table) - - # Get all ids used for training - ids_train = data.get_ids_from_dataset(dataloader.train_ds) - # Get predictions from trained models on whole dataset - dataloader = dataset_builder( - config.data.num_splits, - 1, - config.subgroup_ids, - config.subgroup_proportions) - dataloader.train_splits = [ - split.filter( - generate_bias_table_lib.filter_ids_fn(example_id_to_bias_table, 0)) - for split in dataloader.train_splits - ] - dataloader = data.apply_batch(dataloader, config.data.batch_size) - predictions_df = ( - generate_bias_table_lib.get_example_id_to_predictions_table( - dataloader, - trained_stagetwo_models, - config.train_bias, - save_dir=output_dir, - save_table=config.save_bias_table, - ) - ) - # Compute new ids to sample and append to initial set of ids - next_ids_dir = os.path.join(config.output_dir, f'round_{round_idx}/ids') - _ = sampling_policies.sample_and_split_ids( - ids_train, predictions_df, config.active_sampling.sampling_score, - config.active_sampling.num_samples_per_round, config.data.num_splits, - next_ids_dir, True - ) - # TODO(jihyeonlee): Will add Waterbirds dataloader and ResNet model to support - # vision modality. - - -if __name__ == '__main__': - app.run(main)
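The deleted binary is driven entirely by a ConfigDict file passed via --config. A hypothetical invocation (paths illustrative, using one of the configs deleted in this patch):

python train_tf_sequential_active.py \
  --config=experimental/shoshin/configs/waterbirds_resnet_config.py \
  --config.output_dir=/tmp/shoshin/waterbirds \
  --config.round_idx=-1

Setting --config.round_idx=-1 runs every active-sampling round sequentially, per the loop in main; a non-negative value runs only that round.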