This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

update readme to include known issues
chentingpc committed Jul 6, 2021
1 parent 3ad6700 commit dec99a8
Showing 6 changed files with 54 additions and 37 deletions.
16 changes: 11 additions & 5 deletions README.md
@@ -1,8 +1,8 @@
# SimCLR - A Simple Framework for Contrastive Learning of Visual Representations

<span style="color: red"><strong>News! </strong></span> We have released a TF2 implementation of SimCLR (along with converted checkpoints in TF2), they are in <a href="https://github.com/google-research/simclr/tree/master/tf2">tf2/ folder</a>.
<span style="color: red"><strong>News! </strong></span> We have released a TF2 implementation of SimCLR (along with converted checkpoints in TF2), they are in <a href="tf2/">tf2/ folder</a>.

<span style="color: red"><strong>News! </strong></span> Colabs for <a href="https://arxiv.org/abs/2011.02803">Intriguing Properties of Contrastive Losses</a> are added, see <a href="https://github.com/google-research/simclr/tree/master/colabs/intriguing_properties">here</a>.
<span style="color: red"><strong>News! </strong></span> Colabs for <a href="https://arxiv.org/abs/2011.02803">Intriguing Properties of Contrastive Losses</a> are added, see <a href="colabs/intriguing_properties/">here</a>.

<div align="center">
<img width="50%" alt="SimCLR Illustration" src="https://1.bp.blogspot.com/--vH4PKpE9Yo/Xo4a2BYervI/AAAAAAAAFpM/vaFDwPXOyAokAC8Xh852DzOgEs22NhbXwCLcBGAsYHQ/s1600/image4.gif">
@@ -12,7 +12,7 @@
</div>

## Pre-trained models for SimCLRv2
<a href="https://colab.research.google.com/github/google-research/simclr/blob/master/colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<a href="colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

We have open-sourced a total of 65 pretrained models here, corresponding to those in Table 1 of the <a href="https://arxiv.org/abs/2006.10029">SimCLRv2</a> paper:

@@ -170,7 +170,7 @@ Set the `checkpoint` to those that are only pre-trained but not fine-tuned. Give

## Other resources

### Model convertion to Pytorch format
### Model conversion to Pytorch format

This [repo](https://github.com/tonylins/simclr-converter) provides a solution for converting the pretrained SimCLRv1 Tensorflow checkpoints into Pytorch ones.

@@ -187,10 +187,16 @@ Implementations in PyTorch:
* [Spijkervet](https://github.com/Spijkervet/SimCLR)
* [williamFalcon](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)

Implementations in Tensorflow 2 / Keras (official TF2 implementation was added in <a href="https://github.com/google-research/simclr/tree/master/tf2">tf2/ folder</a>):
Implementations in Tensorflow 2 / Keras (official TF2 implementation was added in <a href="tf2/">tf2/ folder</a>):
* [sayakpaul](https://github.com/sayakpaul/SimCLR-in-TensorFlow-2)
* [mwdhont](https://github.com/mwdhont/SimCLRv1-keras-tensorflow)

## Known issues

* **Batch size**: the original SimCLR results were tuned for a large batch size (i.e. 4096), so training with a much smaller batch size can give suboptimal results. However, with a good set of hyper-parameters (mainly learning rate, temperature, and projection head depth), small batch sizes can yield results on par with large ones (e.g., see Table 2 in [this paper](https://arxiv.org/pdf/2011.02803.pdf)).

* **Pretrained models / Checkpoints**: SimCLRv1 and SimCLRv2 are pretrained with different weight decays, so the pretrained models from the two versions have very different weight norm scales (convolutional weights in the SimCLRv1 ResNet-50 are on average 16.8X those in SimCLRv2). Fine-tuning the pretrained models from either version works fine with the LARS optimizer, but requires very different hyperparameters (e.g. learning rate, weight decay) with the momentum optimizer. In the latter case, you may want to either search for very different hparams depending on which version is used, or re-scale the weights (i.e. the conv `kernel` parameters of `base_model` in the checkpoints) so that they are roughly on the same scale (a sketch of such a rescaling follows below).
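
A minimal checkpoint-surgery sketch of that rescaling, which is not part of this repo's tooling: it assumes a TF1-style checkpoint whose convolutional kernel variables contain `base_model` and `kernel` in their names, and it uses the 16.8X average ratio quoted above as a uniform scale factor. Verify the variable names and the scale against the actual checkpoints before relying on it.

```python
import tensorflow.compat.v1 as tf

def rescale_conv_kernels(ckpt_path, output_path, scale=1.0 / 16.8):
  """Writes a copy of a checkpoint with conv kernels under `base_model` rescaled."""
  reader = tf.train.load_checkpoint(ckpt_path)
  with tf.Graph().as_default():
    new_vars = []
    for name in reader.get_variable_to_shape_map():
      value = reader.get_tensor(name)
      if 'base_model' in name and 'kernel' in name:
        value = value * scale  # e.g. bring SimCLRv1-scale weights closer to SimCLRv2
      new_vars.append(tf.Variable(value, name=name))
    saver = tf.train.Saver(new_vars)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      saver.save(sess, output_path)
```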

## Cite

[SimCLR paper](https://arxiv.org/abs/2002.05709):
2 changes: 1 addition & 1 deletion tf2/README.md
@@ -12,7 +12,7 @@ This implementation is based on TensorFlow 2.x. We use `tf.keras` layers for bui
<br/><br/>

## Pre-trained models for SimCLRv2
<a href="https://colab.research.google.com/github/google-research/simclr/blob/master/tf2/colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<a href="tf2/colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

We have converted the checkpoints for the TF1 models of SimCLR v1 and v2 to TF2 [SavedModel](https://www.tensorflow.org/guide/saved_model):

24 changes: 15 additions & 9 deletions tf2/data.py
@@ -21,6 +21,7 @@

import data_util
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

FLAGS = flags.FLAGS

@@ -60,27 +61,32 @@ def map_fn(image, label):
label = tf.one_hot(label, num_classes)
return image, label

logging.info('num_input_pipelines: %d', input_context.num_input_pipelines)
dataset = builder.as_dataset(
split=FLAGS.train_split if is_training else FLAGS.eval_split,
shuffle_files=is_training,
as_supervised=True)
logging.info('num_input_pipelines: %d', input_context.num_input_pipelines)
# The dataset is always sharded by number of hosts.
# num_input_pipelines is the number of hosts rather than number of cores.
if input_context.num_input_pipelines > 1:
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
as_supervised=True,
# Passing the input_context to TFDS makes TFDS read different parts
# of the dataset on different workers. We also adjust the interleave
# parameters to achieve better performance.
read_config=tfds.ReadConfig(
interleave_cycle_length=32,
interleave_block_length=1,
input_context=input_context))
if FLAGS.cache_dataset:
dataset = dataset.cache()
if is_training:
options = tf.data.Options()
options.experimental_deterministic = False
options.experimental_slack = True
dataset = dataset.with_options(options)
buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10
dataset = dataset.shuffle(batch_size * buffer_multiplier)
dataset = dataset.repeat(-1)
dataset = dataset.map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
prefetch_buffer_size = 2 * topology.num_tpus_per_task if topology else 2
dataset = dataset.prefetch(prefetch_buffer_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset

return _input_fn
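
A self-contained sketch of the sharding pattern used by `_input_fn` above; the dataset name, global batch size, and strategy are illustrative rather than taken from this repo. `tf.distribute` passes an `InputContext` to the input function, and forwarding it to TFDS through `tfds.ReadConfig` is what makes each worker read a different part of the data.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def dataset_fn(input_context):
  # Forwarding the InputContext to TFDS shards file reads across workers,
  # mirroring what the updated _input_fn does.
  read_config = tfds.ReadConfig(
      interleave_cycle_length=32,
      interleave_block_length=1,
      input_context=input_context)
  ds = tfds.load('cifar10', split='train', as_supervised=True,
                 read_config=read_config)
  per_replica_batch = input_context.get_per_replica_batch_size(128)
  return ds.batch(per_replica_batch)

strategy = tf.distribute.MirroredStrategy()
# TF 2.4+; older versions use experimental_distribute_datasets_from_function.
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```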
8 changes: 5 additions & 3 deletions tf2/data_util.py
@@ -382,11 +382,12 @@ def _transform(image): # pylint: disable=missing-docstring
return random_apply(_transform, p=p, x=image)


def random_color_jitter(image, p=1.0, impl='simclrv2'):
def random_color_jitter(image, p=1.0, strength=1.0,
impl='simclrv2'):

def _transform(image):
color_jitter_t = functools.partial(
color_jitter, strength=FLAGS.color_jitter_strength, impl=impl)
color_jitter, strength=strength, impl=impl)
image = random_apply(color_jitter_t, p=0.8, x=image)
return random_apply(to_grayscale, p=0.2, x=image)
return random_apply(_transform, p=p, x=image)
@@ -469,7 +470,8 @@ def preprocess_for_train(image,
if flip:
image = tf.image.random_flip_left_right(image)
if color_distort:
image = random_color_jitter(image, impl=impl)
image = random_color_jitter(image, strength=FLAGS.color_jitter_strength,
impl=impl)
image = tf.reshape(image, [height, width, 3])
image = tf.clip_by_value(image, 0., 1.)
return image
26 changes: 13 additions & 13 deletions tf2/model.py
@@ -128,23 +128,22 @@ def __init__(self,
# However, it is still used for batch norm.
super(LinearLayer, self).__init__(**kwargs)
self.num_classes = num_classes
self.use_bias = use_bias
self.use_bn = use_bn
self._name = name
if callable(self.num_classes):
num_classes = -1
else:
num_classes = self.num_classes
self.dense = tf.keras.layers.Dense(
num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
use_bias=use_bias and not self.use_bn)
if self.use_bn:
self.bn_relu = resnet.BatchNormRelu(relu=False, center=use_bias)

def build(self, input_shape):
# TODO(srbs): Add a new SquareDense layer.
if callable(self.num_classes):
self.dense.units = self.num_classes(input_shape)
num_classes = self.num_classes(input_shape)
else:
num_classes = self.num_classes
self.dense = tf.keras.layers.Dense(
num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
use_bias=self.use_bias and not self.use_bn)
super(LinearLayer, self).build(input_shape)

def call(self, inputs, training):
@@ -242,13 +241,14 @@ def __init__(self, num_classes, **kwargs):
def __call__(self, inputs, training):
features = inputs
if training and FLAGS.train_mode == 'pretrain':
num_transforms = 2
if FLAGS.fine_tune_after_block > -1:
raise ValueError('Does not support layer freezing during pretraining,'
'should set fine_tune_after_block<=-1 for safety.')
else:
num_transforms = 1

if inputs.shape[3] is None:
raise ValueError('The input channels dimension must be statically known '
f'(got input shape {inputs.shape})')
num_transforms = inputs.shape[3] // 3
num_transforms = tf.repeat(3, num_transforms)
# Split channels, and optionally apply extra batched augmentation.
features_list = tf.split(
features, num_or_size_splits=num_transforms, axis=-1)
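
For reference, a small standalone sketch of the channel-splitting logic that the updated `__call__` performs; the batch size and image size are arbitrary. During pretraining the input pipeline concatenates two augmented views along the channel axis, and the model splits them back apart.

```python
import tensorflow as tf

# Two augmented "views" of a batch, concatenated along channels.
view1 = tf.random.uniform([4, 32, 32, 3])
view2 = tf.random.uniform([4, 32, 32, 3])
inputs = tf.concat([view1, view2], axis=-1)    # shape [4, 32, 32, 6]

num_transforms = inputs.shape[3] // 3          # -> 2
split_sizes = tf.repeat(3, num_transforms)     # -> [3, 3]
features_list = tf.split(inputs, num_or_size_splits=split_sizes, axis=-1)
assert len(features_list) == 2                 # one tensor per view
```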
15 changes: 9 additions & 6 deletions tf2/run.py
@@ -238,7 +238,7 @@
'Whether or not to use Gaussian blur for augmentation during pretraining.')


def get_salient_tensors_dict():
def get_salient_tensors_dict(include_projection_head):
"""Returns a dictionary of tensors."""
graph = tf.compat.v1.get_default_graph()
result = {}
@@ -252,11 +252,15 @@ def get_salient_tensors_dict():
result['final_avg_pool'] = graph.get_tensor_by_name('resnet/final_avg_pool:0')
result['logits_sup'] = graph.get_tensor_by_name(
'head_supervised/logits_sup:0')

if include_projection_head:
result['proj_head_input'] = graph.get_tensor_by_name(
'projection_head/proj_head_input:0')
result['proj_head_output'] = graph.get_tensor_by_name(
'projection_head/proj_head_output:0')
return result


def build_saved_model(model):
def build_saved_model(model, include_projection_head=True):
"""Returns a tf.Module for saving to SavedModel."""

class SimCLRModel(tf.Module):
@@ -271,7 +275,7 @@ def __init__(self, model):
@tf.function
def __call__(self, inputs, trainable):
self.model(inputs, training=trainable)
return get_salient_tensors_dict()
return get_salient_tensors_dict(include_projection_head)

module = SimCLRModel(model)
input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
@@ -423,8 +427,7 @@ def run_single_step(iterator):
json.dump(serializable_flags, f)

# Export as SavedModel for finetuning and inference.
if FLAGS.train_mode == 'finetune':
save(model, global_step=result['global_step'])
save(model, global_step=result['global_step'])

return result
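
A hedged usage sketch for the exported SavedModel; the path and image shape are placeholders, and it assumes the `__call__(inputs, trainable)` signature defined above is restored as-is by `tf.saved_model.load`. With `include_projection_head=True` (the default), the projection-head tensors from `get_salient_tensors_dict` appear in the output dictionary.

```python
import tensorflow as tf

module = tf.saved_model.load('/tmp/simclr_saved_model')  # placeholder path

images = tf.random.uniform([2, 224, 224, 3])
outputs = module(images, trainable=False)

features = outputs['final_avg_pool']      # backbone features
projection = outputs['proj_head_output']  # included when include_projection_head=True
logits = outputs['logits_sup']            # supervised-head logits
```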

