Skip to content

Commit

Permalink
Fix clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
Muennighoff committed Jul 11, 2024
1 parent 6f3a16c commit fe9e2e7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
12 changes: 8 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def load_model(self, model_name):
if model_name in MODEL_TO_CUDA_DEVICE:
device += ":" + MODEL_TO_CUDA_DEVICE[model_name]
model = mteb.get_model(
model_name,
model_name,
revision=self.model_meta[model_name].get("revision", None),
device=device,
)
Expand Down Expand Up @@ -388,7 +388,8 @@ def sts(self, txt0, txt1, txt2, model_name):

# Update layout
fig.update_layout(
title='Similarity Triangle',
# Do not put title so there is more space for the plot; does not seem to add value anyways
# title='Similarity Triangle',
xaxis=dict(
visible=False,
scaleanchor='y', # Anchor x-axis scale to y-axis
Expand All @@ -399,8 +400,11 @@ def sts(self, txt0, txt1, txt2, model_name):
scaleanchor='x', # Anchor y-axis scale to x-axis
scaleratio=1, # Ensure equal scaling
),
width=600,
height=600,
# Make it auto-resize to fit the screen (important to make single mode take the full width)
# width=1200,
# height=600,
# Add padding instead
margin=dict(l=0, r=0, b=0, t=0),
plot_bgcolor='white'
)

Expand Down
38 changes: 21 additions & 17 deletions ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def vote_last_response_single_clustering(vote_type, state, model_selector, reque
}
store_data_in_hub(data, "clustering_single_choice")

return disable_btns(4, visible=False) + disable_btns(3)
return disable_btns(3)
#return disable_btns(5, visible=False) + disable_btns(3)

def get_conv_log_filename():
t = datetime.datetime.now()
Expand Down Expand Up @@ -323,7 +324,7 @@ def build_side_by_side_ui_anon(models):
)
corpus = gr.Dropdown(
label="Corpus",
choices=["wikipedia", "stackoverflow", "arxiv"],
choices=["wikipedia", "arxiv"], #, "stackexchange"
value="wikipedia",
interactive=True,
show_label=True,
Expand Down Expand Up @@ -537,7 +538,7 @@ def build_side_by_side_ui_named(models):
)
corpus = gr.Dropdown(
label="Corpus",
choices=["wikipedia", "stackoverflow", "arxiv"],
choices=["wikipedia", "arxiv"],
value="wikipedia",
interactive=True,
show_label=True,
Expand Down Expand Up @@ -718,7 +719,7 @@ def build_single_model_ui(models):
)
corpus = gr.Dropdown(
label="Corpus",
choices=["wikipedia", "stackoverflow", "arxiv"],
choices=["wikipedia", "arxiv"],
value="wikipedia",
interactive=True,
show_label=True,
Expand Down Expand Up @@ -892,6 +893,8 @@ def clustering_side_by_side(gen_func, state0, state1, txt, ncluster, ndim, dim_m
ip = get_ip(request)
clustering_logger.info(f"Clustering. ip: {ip}")
start_tstamp = time.time()
# Remove prefixes in case it is already beyoned the 1st round.
model_name0, model_name1 = model_name0.replace("### Model A: ", ""), model_name1.replace("### Model B: ", "")
generated_image0, generated_image1, model_name0, model_name1 = gen_func(state0.prompts, model_name0, model_name1, ncluster, ndim=ndim.split(" ")[0], dim_method=dim_method, clustering_method=clustering_method)
state0.model_name, state1.model_name = model_name0, model_name1

Expand Down Expand Up @@ -927,20 +930,21 @@ def clustering_side_by_side(gen_func, state0, state1, txt, ncluster, ndim, dim_m
store_data_in_hub(data, "clustering_individual")


def clustering(gen_func, state, txt, model_name, ncluster, request: gr.Request):
if not txt: raise gr.Warning("Text cannot be empty.")
def clustering(gen_func, state, txt, ncluster, ndim, dim_method, clustering_method, model_name, request: gr.Request):
if not model_name: raise gr.Warning("Model name cannot be empty.")
if state is None:
state = ClusteringState(model_name)
ip = get_ip(request)
clustering_logger.info(f"Clustering. ip: {ip}")
start_tstamp = time.time()
if "<|SEP|>" in txt:
state.prompts.extend(txt.split("<|SEP|>"))
else:
state.prompts.append(txt)
# txt may be None if only changing the dim
if txt:
if "<|SEP|>" in txt:
state.prompts.extend(txt.split("<|SEP|>"))
else:
state.prompts.append(txt)
state.ncluster = ncluster
generated_img = gen_func(state.prompts, model_name, state.ncluster)
generated_img = gen_func(state.prompts, model_name, state.ncluster, ndim=ndim.split(" ")[0], dim_method=dim_method, clustering_method=clustering_method)
state.model_name = model_name

yield state, generated_img, None
Expand Down Expand Up @@ -1524,7 +1528,7 @@ def build_single_model_ui_clustering(models):
outputs=[textbox, ncluster, send_btn, draw_btn, dim_btn],
).then(
gen_func,
inputs=[state, textbox, ncluster, dim_method, clustering_method, model_selector],
inputs=[state, textbox, ncluster, dim_btn, dim_method, clustering_method, model_selector],
outputs=[state, chatbot, textbox],
api_name="submit_btn_single",
show_progress="full"
Expand All @@ -1544,7 +1548,7 @@ def build_single_model_ui_clustering(models):
outputs=[textbox, ncluster, send_btn, draw_btn, dim_btn],
).then(
gen_func,
inputs=[state, textbox, ncluster, dim_method, clustering_method, model_selector],
inputs=[state, textbox, ncluster, dim_btn, dim_method, clustering_method, model_selector],
outputs=[state, chatbot, textbox],
api_name="submit_btn_single",
show_progress="full"
Expand All @@ -1564,7 +1568,7 @@ def build_single_model_ui_clustering(models):
outputs=[textbox, ncluster, send_btn, draw_btn, dim_btn],
).then(
gen_func,
inputs=[state, textbox, ncluster, dim_method, clustering_method, model_selector],
inputs=[state, textbox, ncluster, dim_btn, dim_method, clustering_method, model_selector],
outputs=[state, chatbot, textbox],
api_name="send_btn_single",
show_progress="full"
Expand All @@ -1577,17 +1581,17 @@ def build_single_model_ui_clustering(models):
upvote_btn.click(
partial(vote_last_response_single_clustering, "upvote"),
inputs=[state, model_selector],
outputs=[send_btn, draw_btn, textbox, ncluster, upvote_btn, downvote_btn, flag_btn]
outputs=[upvote_btn, downvote_btn, flag_btn]
)
downvote_btn.click(
partial(vote_last_response_single_clustering, "downvote"),
inputs=[state, model_selector],
outputs=[send_btn, draw_btn, textbox, ncluster, upvote_btn, downvote_btn, flag_btn]
outputs=[upvote_btn, downvote_btn, flag_btn]
)
flag_btn.click(
partial(vote_last_response_single_clustering, "flag"),
inputs=[state, model_selector],
outputs=[send_btn, draw_btn, textbox, ncluster, upvote_btn, downvote_btn, flag_btn]
outputs=[upvote_btn, downvote_btn, flag_btn]
)
clear_btn.click(
clear_history_clustering,
Expand Down

0 comments on commit fe9e2e7

Please sign in to comment.