Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyChouSr committed Sep 16, 2024
1 parent 5e02abd commit 4debf32
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
18 changes: 13 additions & 5 deletions fastchat/serve/monitor/classify/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,30 @@ def post_process(self, judgment):
score = self.get_score(judgment=judgment)
return {"math": bool(score == "yes") if score else False}


class CategoryVisionHardPrompt(CategoryHardPrompt):
def __init__(self):
super().__init__()
self.name_tag = "criteria_vision_v0.1"

def _convert_filepath_to_base64(self, filepath):
with open(filepath, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
return base64.b64encode(image_file.read()).decode("utf-8")

def pre_process(self, prompt: str, image_list: list):
# Prompt is a list where the first element is text and the second element is a list of image in base64 format
conv = [{"role": "system", "content": self.sys_prompt}]
single_turn_content_list = []
single_turn_content_list.append({"type": "text", "text": prompt})
for image_url in image_list:
single_turn_content_list.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{self._convert_filepath_to_base64(image_url)}"}})

single_turn_content_list.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{self._convert_filepath_to_base64(image_url)}"
},
}
)

conv.append({"role": "user", "content": single_turn_content_list})
return conv
return conv
12 changes: 9 additions & 3 deletions fastchat/serve/monitor/classify/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def find_required_tasks(row):
)
]

def aggregate_entire_conversation(conversation, images_dir):

def aggregate_entire_conversation(conversation, images_dir):
final_text_content = ""
final_image_list = []

Expand All @@ -186,12 +187,15 @@ def aggregate_entire_conversation(conversation, images_dir):

return final_text_content, final_image_list


def get_prompt_from_conversation(conversation):
return conversation[0]


def get_image_list_from_conversation(conversation):
return conversation[1]


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
Expand Down Expand Up @@ -276,9 +280,11 @@ def get_image_list_from_conversation(conversation):
not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: aggregate_entire_conversation(convo, config["images_dir"])
)

if config["images_dir"]:
not_labeled["image_list"] = not_labeled.prompt.map(get_image_list_from_conversation)
not_labeled["image_list"] = not_labeled.prompt.map(
get_image_list_from_conversation
)
not_labeled = not_labeled[not_labeled.image_list.map(len) > 0]
not_labeled["prompt"] = not_labeled.prompt.map(get_prompt_from_conversation)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])
Expand Down
5 changes: 4 additions & 1 deletion fastchat/serve/monitor/elo_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,10 @@ def pretty_print_elo_rating(rating):
"long": filter_long_conv,
"chinese": lambda x: x["language"] == "Chinese",
"english": lambda x: x["language"] == "English",
"criteria_vision_v0.1": lambda x: sum(x["category_tag"]["criteria_vision_v0.1"].values()) >= 6,
"criteria_vision_v0.1": lambda x: sum(
x["category_tag"]["criteria_vision_v0.1"].values()
)
>= 6,
}
assert all(
[cat in filter_func_map for cat in args.category]
Expand Down

0 comments on commit 4debf32

Please sign in to comment.