Skip to content

Commit

Permalink
Merge pull request #97 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v0.0.82
  • Loading branch information
Kevin Musgrave authored Dec 1, 2022
2 parents 073c20e + 7227732 commit 3cef36f
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 75 deletions.
13 changes: 8 additions & 5 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
[flake8]

extend-ignore =
E266 # too many leading '#' for block comment
E203 # whitespace before ':'
E402 # module level import not at top of file
E501 # line too long

# too many leading '#' for block comment
E266
# whitespace before ':'
E203
# module level import not at top of file
E402
# line too long
E501
per-file-ignores =
__init__.py:F401
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
extras_require_detection = ["albumentations >= 1.2.1"]
extras_require_ignite = ["pytorch-ignite == 0.4.9"]
extras_require_lightning = ["pytorch-lightning"]
extras_require_record_keeper = ["record-keeper >= 0.9.32"]
extras_require_record_keeper = ["record-keeper >= 0.9.32", "tensorboard"]
extras_require_timm = ["timm"]
extras_require_docs = [
"mkdocs-material",
Expand Down Expand Up @@ -44,8 +44,8 @@
"numpy",
"torch",
"torchvision",
"torchmetrics >= 0.9.3",
"pytorch-metric-learning >= 1.5.2",
"torchmetrics == 0.9.3",
"pytorch-metric-learning >= 1.6.3",
],
extras_require={
"detection": extras_require_detection,
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_adapt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.81"
__version__ = "0.0.82"
20 changes: 14 additions & 6 deletions src/pytorch_adapt/hooks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
detach_features: bool = False,
f_hook: BaseHook = None,
domains=("src",),
**kwargs,
):
"""
Expand All @@ -75,21 +76,28 @@ def __init__(
self.hook = c_f.default(
f_hook,
FeaturesAndLogitsHook,
{"domains": ["src"], "detach_features": detach_features},
{"domains": domains, "detach_features": detach_features},
)

def call(self, inputs, losses):
""""""
outputs = self.hook(inputs, losses)[0]
[src_logits] = c_f.extract(
[outputs, inputs], c_f.filter(self.hook.out_keys, "_logits$")
output_losses = {}
logits = c_f.extract(
[outputs, inputs],
c_f.filter(
self.hook.out_keys, "_logits$", [f"^{d}" for d in self.hook.domains]
),
)
loss = self.loss_fn(src_logits, inputs["src_labels"])
return outputs, {self._loss_keys()[0]: loss}
for i, d in enumerate(self.hook.domains):
output_losses[self._loss_keys()[i]] = self.loss_fn(
logits[i], inputs[f"{d}_labels"]
)
return outputs, output_losses

def _loss_keys(self):
""""""
return ["c_loss"]
return [f"{d}_c_loss" for d in self.hook.domains]


class ClassifierHook(BaseWrapperHook):
Expand Down
4 changes: 3 additions & 1 deletion src/pytorch_adapt/hooks/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
):
for i in range(len(hooks) - 1):
hooks[i + 1].set_in_keys(hooks[i].out_keys)
self.domains = hooks[-1].domains
super().__init__(*hooks, **kwargs)


Expand All @@ -326,7 +327,7 @@ class FeaturesAndLogitsHook(FeaturesChainHook):

def __init__(
self,
domains: List[str] = None,
domains: List[str] = ("src", "target"),
detach_features: bool = False,
detach_logits: bool = False,
other_hooks: List[BaseHook] = None,
Expand All @@ -343,6 +344,7 @@ def __init__(
other_hooks: A list of hooks that will be called after
the features and logits hooks.
"""
self.domains = domains
features_hook = FeaturesHook(detach=detach_features, domains=domains)
logits_hook = LogitsHook(detach=detach_logits, domains=domains)
other_hooks = c_f.default(other_hooks, [])
Expand Down
2 changes: 1 addition & 1 deletion tests/adapters/run_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# log files should be a mapping from csv file name, to number of columns in file
def run_adapter(cls, test_folder, adapter, log_files=None, inference_fn=None):
checkpoint_fn = CheckpointFnCreator(dirname=test_folder)
checkpoint_fn = CheckpointFnCreator(dirname=test_folder, require_empty=False)
logger = IgniteRecordKeeperLogger(folder=test_folder)
datasets = get_datasets()
validator = ScoreHistory(EntropyValidator())
Expand Down
34 changes: 18 additions & 16 deletions tests/adapters/test_running.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def gan_log_files():
"total",
"g_src_domain_loss",
"g_target_domain_loss",
"c_loss",
"src_c_loss",
},
"engine_output_d_loss": {
"total",
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_aligner(self):
"optimizers_C_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
"features_confusion_loss",
"logits_confusion_loss",
},
Expand All @@ -172,7 +172,9 @@ def test_aligner(self):
def test_cdan(self):
models = get_gcd()
misc = Misc({"feature_combiner": RandomizedDotProduct([512, 10], 512)})
g_weighter = MeanWeighter(weights={"g_target_domain_loss": 0.5, "c_loss": 0.1})
g_weighter = MeanWeighter(
weights={"g_target_domain_loss": 0.5, "src_c_loss": 0.1}
)
adapter = CDAN(models=models, misc=misc, hook_kwargs={"g_weighter": g_weighter})
self.assertTrue(isinstance(adapter.hook, CDANHook))
log_files = gan_log_files()
Expand All @@ -186,7 +188,7 @@ def test_cdan(self):
},
"hook_8c2a74151317b9315573314fafc0d8ad6e12f72a84433739f6f0762a4ca11ab0_weights": {
"g_target_domain_loss",
"c_loss",
"src_c_loss",
},
}
)
Expand All @@ -204,7 +206,7 @@ def test_classifier(self):
"optimizers_C_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
},
"hook_ClassifierHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
"scale"
Expand All @@ -224,7 +226,7 @@ def test_dann(self):
"optimizers_D_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
"src_domain_loss",
"target_domain_loss",
},
Expand Down Expand Up @@ -267,7 +269,7 @@ def test_finetuner(self):
"optimizers_C_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
},
"hook_FinetunerHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
"scale"
Expand Down Expand Up @@ -305,7 +307,7 @@ def test_joint_aligner(self):
"optimizers_C_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
"joint_confusion_loss",
},
"hook_AlignerPlusCHook_hook_ChainHook_hooks0_OptimizerHook_weighter_MeanWeighter": {
Expand All @@ -328,7 +330,7 @@ def test_gvb(self):
"optimizers_D_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
"src_domain_loss",
"target_domain_loss",
"g_src_bridge_loss",
Expand Down Expand Up @@ -361,13 +363,13 @@ def test_mcd(self):
"optimizers_C_Adam": {"lr"},
"engine_output_x_loss": {
"total",
"c_loss0",
"c_loss1",
"src_c_loss0",
"src_c_loss1",
},
"engine_output_y_loss": {
"total",
"c_loss0",
"c_loss1",
"src_c_loss0",
"src_c_loss1",
"discrepancy_loss",
},
"engine_output_z_loss": {"total", "discrepancy_loss"},
Expand Down Expand Up @@ -399,7 +401,7 @@ def test_rtn(self):
"optimizers_residual_model_Adam": {"lr"},
"engine_output_total_loss": {
"total",
"c_loss",
"src_c_loss",
"entropy_loss",
"features_confusion_loss",
},
Expand All @@ -424,8 +426,8 @@ def test_symnets(self):
"optimizers_G_Adam": {"lr"},
"optimizers_C_Adam": {"lr"},
"engine_output_c_loss": {
"c_loss0",
"c_loss1",
"src_c_loss0",
"src_c_loss1",
"c_symnets_src_domain_loss_0",
"c_symnets_target_domain_loss_1",
"total",
Expand Down
6 changes: 3 additions & 3 deletions tests/hooks/test_aligners.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_aligner_plus_classifier_hook(self):
)

loss_keys = {
"c_loss",
"src_c_loss",
"total",
}

Expand Down Expand Up @@ -120,15 +120,15 @@ def test_aligner_plus_classifier_hook(self):
[F.softmax(target_logits, dim=1), target_features],
)
total_loss = (f_loss + c_loss) / 2
correct_losses = [c_loss, f_loss, total_loss]
correct_losses = [f_loss, c_loss, total_loss]
else:
f_loss = loss_fn()(src_features, target_features)
l_loss = loss_fn()(
F.softmax(src_logits, dim=1), F.softmax(target_logits, dim=1)
)

total_loss = (f_loss + l_loss + c_loss) / 3
correct_losses = [c_loss, f_loss, l_loss, total_loss]
correct_losses = [f_loss, l_loss, c_loss, total_loss]

computed_losses = [
losses["total_loss"][k] for k in sorted(list(loss_keys))
Expand Down
10 changes: 7 additions & 3 deletions tests/hooks/test_cdan.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_cdan_hook(self):
g_loss_keys = {
"g_src_domain_loss",
"g_target_domain_loss",
"c_loss",
"src_c_loss",
"total",
}

Expand Down Expand Up @@ -349,14 +349,18 @@ def test_cdan_hook(self):
g_losses["g_target_domain_loss"] * target_entropy_weights
)

g_losses["c_loss"] = torch.nn.functional.cross_entropy(
g_losses["src_c_loss"] = torch.nn.functional.cross_entropy(
c_logits[:bs], src_labels
)

self.assertTrue(
all(
np.isclose(losses["g_loss"][k], g_losses[k].item())
for k in ["g_src_domain_loss", "g_target_domain_loss", "c_loss"]
for k in [
"g_src_domain_loss",
"g_target_domain_loss",
"src_c_loss",
]
)
)
g_losses = list(g_losses.values())
Expand Down
60 changes: 47 additions & 13 deletions tests/hooks/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,53 @@ def test_softmax_hook(self):

def test_closs_hook(self):
torch.manual_seed(24242)
src_imgs = torch.randn(100, 32)
target_imgs = torch.randn(100, 32)
src_labels = torch.randint(0, 10, size=(100,))
target_labels = torch.randint(0, 10, size=(100,))
G = Net(32, 16)
C = Net(16, 10)

for detach_features in [True, False]:
h = CLossHook(detach_features=detach_features)
src_imgs = torch.randn(100, 32)
target_imgs = torch.randn(100, 32)
src_labels = torch.randint(0, 10, size=(100,))
G = Net(32, 16)
C = Net(16, 10)
outputs, losses = h(locals())
assertRequiresGrad(self, outputs)
base_key = "src_imgs_features"
if detach_features:
base_key += "_detached"
self.assertTrue(outputs.keys() == {base_key, f"{base_key}_logits"})
for domains in [None, ("src",), ("target",), ("src", "target")]:
if domains is None:
h = CLossHook(detach_features=detach_features)
else:
h = CLossHook(detach_features=detach_features, domains=domains)
outputs, losses = h(locals())
assertRequiresGrad(self, outputs)
base_keys = (
[f"{d}_imgs_features" for d in domains]
if domains
else ["src_imgs_features"]
)
if detach_features:
base_keys = [f"{x}_detached" for x in base_keys]
logit_keys = [f"{x}_logits" for x in base_keys]
self.assertTrue(outputs.keys() == {*base_keys, *logit_keys})

correct_loss_fn = torch.nn.functional.cross_entropy
for k, v in losses.items():
if k.startswith("src"):
self.assertTrue(
torch.equal(
v,
correct_loss_fn(
C(G(src_imgs)), src_labels, reduction="none"
),
)
)
elif k.startswith("target"):
self.assertTrue(
torch.equal(
v,
correct_loss_fn(
C(G(target_imgs)), target_labels, reduction="none"
),
)
)
else:
raise KeyError

def test_classifier_hook(self):
torch.manual_seed(53430)
Expand Down Expand Up @@ -77,4 +111,4 @@ def test_classifier_hook(self):
_, losses = h(
{"G": G, "C": C, "src_imgs": src_imgs, "src_labels": src_labels}
)
self.assertTrue(np.isclose(losses["total_loss"]["c_loss"], correct))
self.assertTrue(np.isclose(losses["total_loss"]["src_c_loss"], correct))
4 changes: 2 additions & 2 deletions tests/hooks/test_dann.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_dann(self):
loss_keys = {
"src_domain_loss",
"target_domain_loss",
"c_loss",
"src_c_loss",
"total",
}

Expand Down Expand Up @@ -254,7 +254,7 @@ def test_dann(self):

c_loss = F.cross_entropy(src_logits, src_labels)
self.assertTrue(
np.isclose(c_loss.item(), losses["total_loss"]["c_loss"])
np.isclose(c_loss.item(), losses["total_loss"]["src_c_loss"])
)

total_loss = [src_domain_loss, target_domain_loss, c_loss]
Expand Down
2 changes: 1 addition & 1 deletion tests/hooks/test_domain_confusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_domain_confusion_hook(self):
)
self.assertTrue(
losses["g_loss"].keys()
== {"g_src_domain_loss", "g_target_domain_loss", "c_loss", "total"}
== {"g_src_domain_loss", "g_target_domain_loss", "src_c_loss", "total"}
)
self.assertTrue(
losses["d_loss"].keys()
Expand Down
Loading

0 comments on commit 3cef36f

Please sign in to comment.