-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdraw.py
109 lines (89 loc) · 3.25 KB
/
draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import pandas as pd
import plotly.graph_objects as go
import wandb
from plotly.subplots import make_subplots
def plot_by_mode(source, adv_mode, output_folder="figs"):
    """Draw a grid of metric curves for one ``adv_mode`` and save it as a PNG.

    The grid has one row per backbone (``vgg16``, ``resnet50``, ``vit``) and
    one column per ``(dataset, metric)`` pair. Within each cell, one
    line+marker trace is drawn per distinct value of the ``model`` column,
    using the first matching row of ``source``.

    Args:
        source: DataFrame with at least the columns ``adv_mode``, ``dataset``,
            ``base``, ``model`` and the metric columns ``Acc@1`` / ``ASR@1``.
            Each metric cell is presumably a sequence of values (one per
            x position) -- TODO confirm against the logged run summaries.
        adv_mode: Value of the ``adv_mode`` column to keep; also used in the
            figure title and in the output file name.
        output_folder: Directory that receives ``pic_<adv_mode>.png``; it is
            created if it does not exist.

    Side effects:
        Writes the image to disk and opens the figure with ``fig.show()``.
    """
    source = source[source['adv_mode'] == adv_mode]

    datasets_metrics = [("CUB", "Acc@1"), ("CUB", "ASR@1"), ("AwA", "Acc@1"), ("AwA", "ASR@1")]
    base_models = ["vgg16", "resnet50", "vit"]

    # Titles are listed row by row, matching the order in which
    # make_subplots assigns them to cells.
    subplot_titles = [
        f"{base} - {dataset} - {metric}"
        for base in base_models
        for dataset, metric in datasets_metrics
    ]

    # Grid dimensions follow the two lists above instead of being hard-coded,
    # so adding a backbone or a metric does not require touching the layout.
    fig = make_subplots(
        rows=len(base_models),
        cols=len(datasets_metrics),
        subplot_titles=subplot_titles,
        vertical_spacing=0.1,
        horizontal_spacing=0.1,
    )

    fixed_colors = ["#5760A8", "#F6B917", "#DE4849"]
    # Colour chosen the first time a model is encountered; reused in every
    # later subplot so a model keeps one colour across the whole figure.
    model_colors = {}

    for row, base in enumerate(base_models, start=1):
        for col, (dataset, metric) in enumerate(datasets_metrics, start=1):
            in_cell = (source['dataset'] == dataset) & (source['base'] == base)
            subset = source[in_cell].dropna(subset=[metric])

            for idx, run in enumerate(subset['model'].unique()):
                # Only the first row per model is plotted.
                data = subset[subset['model'] == run].iloc[0]

                # A model gets a legend entry exactly once: on the trace
                # that also assigns its colour.
                first_seen = run not in model_colors
                if first_seen:
                    model_colors[run] = fixed_colors[idx % len(fixed_colors)]

                values = data[metric]
                fig.add_trace(
                    go.Scatter(
                        # x positions are the indices of the metric series,
                        # whatever its length (previously fixed at 5 points).
                        x=list(range(len(values))),
                        y=values,
                        mode='lines+markers',
                        name=run,
                        legendgroup=run,
                        marker=dict(color=model_colors[run]),
                        showlegend=first_seen,
                    ),
                    row=row,
                    col=col,
                )

    fig.update_layout(
        height=1500,
        width=2000,
        title_text=f"adv_mode = {adv_mode}",
        showlegend=True,
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(t=100, b=100, l=100, r=100),
    )
    # Apply the boxed, grid-free axis style to every subplot axis at once
    # rather than enumerating xaxis1..xaxis12 by hand.
    axis_style = dict(showline=True, linecolor="black", showgrid=False, mirror=True)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, f"pic_{adv_mode}.png")
    # NOTE(review): the "orca" engine needs the external orca executable and
    # is deprecated in recent plotly releases in favour of kaleido -- confirm
    # the target environment before changing it.
    fig.write_image(output_path, engine="orca")
    fig.show()
# --- Script entry: pull sweep results from Weights & Biases and plot them ---

# SECURITY: the API key must not live in source control. It is read from the
# WANDB_API_KEY environment variable; when that is unset, key=None lets
# wandb.login fall back to its own stored credentials / prompt.
# The key that was previously committed here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

api = wandb.Api()
sweep = api.sweep("matrix72c-jesse/RobustCBM/iaykwf0h")
sweep_runs = sweep.runs

# One flat record per run: hyper-parameters (config) overlaid with the final
# logged values (summary). The config is copied first so the run object's own
# dict is not mutated.
tb = []
for r in sweep_runs:
    cfg = dict(r.config)
    cfg.update(r.summary)
    tb.append(cfg)

df = pd.DataFrame(tb)

# Work on a copy so `df` keeps the raw table; `name` is a readable run label
# built from the four grouping columns.
res = df.copy()
res['name'] = res['adv_mode'] + '_' + res['base'] + '_' + res['dataset'] + '_' + res['model']

plot_by_mode(res, adv_mode='adv')
plot_by_mode(res, adv_mode='std')