From ffe2845c73e8a6bb9ca4181b4e50ea91e0c48ea6 Mon Sep 17 00:00:00 2001
From: cetaceanw
Date: Wed, 22 Mar 2023 15:50:34 +0100
Subject: [PATCH] Save images to the local device and display the important
 ones. Support coarse search. Avoid re-summarizing already-downloaded papers
 to save tokens. Always overwrite existing export files.

---
 chat_paper.py         | 75 +++++++++++++++++++++++++++++++----------
 get_paper_from_pdf.py | 77 +++++++++++++++++++++++++++----------------
 2 files changed, 105 insertions(+), 47 deletions(-)

diff --git a/chat_paper.py b/chat_paper.py
index d05e59f..f711c39 100644
--- a/chat_paper.py
+++ b/chat_paper.py
@@ -39,10 +39,10 @@ def __init__(self, key_word, query, filter_keys,
         self.chat_api_list = [api.strip() for api in self.chat_api_list if len(api) > 5]
         self.cur_api = 0
         self.file_format = args.file_format
-        if args.save_image:
-            self.gitee_key = self.config.get('Gitee', 'api')
-        else:
-            self.gitee_key = ''
+        if args.save_image:
+            # Both 'local' and 'gitee' modes first write figures to this
+            # directory; 'gitee' additionally uploads them in summary_with_chat.
+            self.image_path = root_path + 'export/' + 'attachments/'
+        else:
+            self.image_path = ''
         self.max_token_num = 4096
         self.encoding = tiktoken.get_encoding("gpt2")
@@ -64,16 +64,26 @@ def filter_arxiv(self, max_results=30):
         filter_keys = self.filter_keys

         print("filter_keys:", self.filter_keys)
-        # A paper only counts as a target if every keyword can be found in its abstract
-        for index, result in enumerate(search.results()):
-            abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
-            meet_num = 0
-            for f_key in filter_keys.split(" "):
-                if f_key.lower() in abs_text.lower():
-                    meet_num += 1
-            if meet_num == len(filter_keys.split(" ")):
-                filter_results.append(result)
+        # Exact match: a paper only counts as a target if every keyword can
+        # be found in its abstract
+        if not args.coarse:
+            print("Exact match")
+            for index, result in enumerate(search.results()):
+                abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
+                meet_num = 0
+                for f_key in filter_keys.split(" "):
+                    if f_key.lower() in abs_text.lower():
+                        meet_num += 1
+                if meet_num == len(filter_keys.split(" ")):
+                    filter_results.append(result)
                 # break
+        # Coarse match: keep the paper as soon as any single keyword appears
+        else:
+            print("Coarse match")
+            for index, result in enumerate(search.results()):
+                abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
+                for f_key in filter_keys.split(" "):
+                    if f_key.lower() in abs_text.lower():
+                        filter_results.append(result)
+                        break
         print("Number of papers left after filtering:")
         print("filter_results:", len(filter_results))
         print("filter_papers:")
@@ -103,6 +113,13 @@ def download_pdf(self, filter_results):
             try:
                 title_str = self.validateTitle(result.title)
                 pdf_name = title_str+'.pdf'
+                # Skip papers whose PDF was already downloaded, so that no
+                # tokens are spent on summarizing them a second time
+                for dir_name in os.listdir(os.path.join(self.root_path, 'pdf_files')):
+                    for existing_file in os.listdir(os.path.join(self.root_path, 'pdf_files', dir_name)):
+                        if pdf_name == existing_file and not args.repeat:
+                            raise Exception('\033[91m' + pdf_name + " already exists; skipping the summary to save your tokens. Pass --repeat to force re-summarizing." + '\033[0m')
+                        elif pdf_name == existing_file and args.repeat:
+                            print('\033[93m' + pdf_name + " already exists, re-summarizing anyway" + '\033[0m')
                 # result.download_pdf(path, filename=pdf_name)
                 self.try_download_pdf(result, path, pdf_name)
                 paper_path = os.path.join(path, pdf_name)
@@ -143,7 +160,7 @@ def upload_gitee(self, image_path, image_name='', ext='png'):
         path = image_name + '-' + date_str

         payload = {
-            "access_token": self.gitee_key,
+            "access_token": self.config.get('Gitee', 'api'),
             "owner": self.config.get('Gitee', 'owner'),
             "repo": self.config.get('Gitee', 'repo'),
             "path": self.config.get('Gitee', 'path'),
@@ -244,7 +261,27 @@ def summary_with_chat(self, paper_list):
             chat_conclusion_text = self.chat_conclusion(text=text, conclusion_prompt_token=conclusion_prompt_token)
             htmls.append(chat_conclusion_text)
             htmls.append("\n"*4)
-
+
+            # Step 4: supplementary materials. Figures that appear before the
+            # experiment/results sections are the valuable ones, since the
+            # method section usually contains conceptual diagrams.
+            htmls.append("**Supplementary Materials:**\n")
+            if args.save_image:
+                img_list, ext = paper.get_image_path(self.image_path)
+            else:
+                img_list, ext = None, None
+            if img_list is None:
+                pass
+            elif args.save_image == 'local':
+                for i_page in range(len(img_list)):
+                    for i_image in range(len(img_list[i_page])):
+                        htmls.append("\n")
+                        htmls.append("![Fig](" + img_list[i_page][i_image].replace(' ', '%20').replace('./export', '.') + ")")
+                        htmls.append("\n")
+            elif args.save_image == 'gitee':
+                for i_page in range(len(img_list)):
+                    for i_image in range(len(img_list[i_page])):
+                        image_title = self.validateTitle(paper.title)
+                        image_url = self.upload_gitee(image_path=img_list[i_page][i_image], image_name=image_title, ext=ext[i_page][i_image])
+                        htmls.append("\n")
+                        htmls.append("![Fig](" + image_url + ")")
+                        htmls.append("\n")
+
             # Merge everything into a single file and save it.
             date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
             try:
@@ -252,7 +289,7 @@ def summary_with_chat(self, paper_list):
                 os.makedirs(export_path)
             except:
                 pass
-            mode = 'w' if paper_index == 0 else 'a'
+            mode = 'w'  # each paper gets its own file (named after its title), so always overwrite
             file_name = os.path.join(export_path, date_str+'-'+self.validateTitle(paper.title[:80])+"."+self.file_format)
             self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)

@@ -464,11 +501,13 @@ def main(args):
     parser.add_argument("--pdf_path", type=str, default='', help="if none, the bot will download from arxiv with query")
     parser.add_argument("--query", type=str, default='all: ChatGPT robot', help="the query string, ti: xx, au: xx, all: xx,")
     parser.add_argument("--key_word", type=str, default='reinforcement learning', help="the key word of user research fields")
-    parser.add_argument("--filter_keys", type=str, default='ChatGPT robot', help="the filter keywords; every word must appear in the abstract for a paper to be selected")
+    parser.add_argument("--filter_keys", type=str, default='ChatGPT robot', help="space-separated filter keywords; by default a paper is kept only if every keyword appears in its abstract (see --coarse)")
+    parser.add_argument("--coarse", action='store_true', help="keep a paper if any single filter keyword appears in its abstract, instead of requiring all of them")
+    parser.add_argument("--repeat", action='store_true', help="force re-summarizing papers whose PDFs already exist (by default they are skipped to save tokens)")
     parser.add_argument("--max_results", type=int, default=1, help="the maximum number of results")
     # arxiv.SortCriterion.Relevance
     parser.add_argument("--sort", type=str, default="Relevance", help="another option is LastUpdatedDate")
-    parser.add_argument("--save_image", default=False, help="save image? It takes a minute or two to save a picture! But pretty")
+    parser.add_argument("--save_image", type=str, default='', help="where to save figures: 'local', 'gitee', or '' to skip; saving a picture takes a minute or two, but it looks nice")
     parser.add_argument("--file_format", type=str, default='md', help="export file format; md is best if images are saved, otherwise txt will not be garbled")
     parser.add_argument("--language", type=str, default='zh', help="output language, 'zh' or 'en'")
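
Reviewer note: the two filtering modes added in filter_arxiv reduce to an all-vs-any test over the space-separated filter keys. A minimal standalone sketch of that logic, for reference only (the helper names matches_exact and matches_coarse are illustrative and do not exist in chat_paper.py):

    # Sketch, not part of the patch: exact vs. coarse keyword matching.
    def matches_exact(abstract: str, filter_keys: str) -> bool:
        # Exact mode keeps a paper only if every key occurs in the abstract.
        text = abstract.lower()
        return all(key.lower() in text for key in filter_keys.split(" "))

    def matches_coarse(abstract: str, filter_keys: str) -> bool:
        # Coarse mode keeps a paper as soon as any single key occurs.
        text = abstract.lower()
        return any(key.lower() in text for key in filter_keys.split(" "))
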
diff --git a/get_paper_from_pdf.py b/get_paper_from_pdf.py
index 2117a58..9ccff5e 100644
--- a/get_paper_from_pdf.py
+++ b/get_paper_from_pdf.py
@@ -44,23 +44,45 @@ def get_paper_info(self):
     def get_image_path(self, image_path=''):
         """
-        Save the first image in the PDF as image.png in a local directory and return the filename for gitee to read
-        :param filename: path of the PDF, e.g. "C:\\Users\\Administrator\\Desktop\\nwd.pdf"
-        :param image_path: directory the extracted image is saved to
-        :return:
+        Save every figure that appears before the Experiment/Evaluation
+        section; the method section usually contains conceptual diagrams.
+
+        parameter:
+        - image_path: directory in which the images are saved
+
+        return:
+        - img_path: list of lists, the image paths on each page
+        - ext: the matching extensions
         """
+        # Create the output folder if it does not exist yet
+        os.makedirs(image_path, exist_ok=True)
+
         # open file
-        max_size = 0
         image_list = []
+        ext = []
+        # Page index of the first experiment-like section; only pages before
+        # it are scanned for images. If no such section is found, stop_index
+        # stays 0 and no images are collected.
+        stop_index = 0
+        exp_key = ["Materials and Methods", "Experiment Settings",
+                   "Experiment", "Experimental Results", "Evaluation",
+                   "Experiments", "Results", "Findings", "Data Analysis"]
+
+        for key in self.section_page_dict:
+            if key in exp_key:
+                stop_index = self.section_page_dict[key]
+                break
+
         with fitz.Document(self.path) as my_pdf_file:
-            # iterate over all pages
-            for page_number in range(1, len(my_pdf_file) + 1):
+            # iterate over all pages before the experiment section
+            for page_number in range(1, stop_index + 1):
                 # access the single page
                 page = my_pdf_file[page_number - 1]
                 # all images on the current page
                 images = page.get_images()
+                image_in_page = []
+                ext_in_page = []
                 # iterate over the images on this page
-                for image_number, image in enumerate(page.get_images(), start=1):
+                for image_number, image in enumerate(images, start=1):
                     # the image's xref
                     xref_value = image[0]
                     # extract the image info
@@ -68,33 +90,30 @@ def get_image_path(self, image_path=''):
                     base_image = my_pdf_file.extract_image(xref_value)
                     # the raw image bytes
                     image_bytes = base_image["image"]
                     # the image extension
-                    ext = base_image["ext"]
+                    ext_in_page.append(base_image["ext"])
                     # load the image
                     image = Image.open(io.BytesIO(image_bytes))
-                    image_size = image.size[0] * image.size[1]
-                    if image_size > max_size:
-                        max_size = image_size
-                        image_list.append(image)
-        for image in image_list:
-            image_size = image.size[0] * image.size[1]
-            if image_size == max_size:
-                image_name = f"image.{ext}"
-                im_path = os.path.join(image_path, image_name)
-                print("im_path:", im_path)
-
-                max_pix = 480
-                origin_min_pix = min(image.size[0], image.size[1])
-
-                if image.size[0] > image.size[1]:
-                    min_pix = int(image.size[1] * (max_pix/image.size[0]))
-                    newsize = (max_pix, min_pix)
-                else:
-                    min_pix = int(image.size[0] * (max_pix/image.size[1]))
-                    newsize = (min_pix, max_pix)
-                image = image.resize(newsize)
-
-                image.save(open(im_path, "wb"))
-                return im_path, ext
+                    image_in_page.append(image)
+                image_list.append(image_in_page)
+                ext.append(ext_in_page)
+
+        # Write every collected image to disk and record its path
+        img_path = []
+        for i_page in range(len(image_list)):
+            im_page_path = []
+            for i_image in range(len(image_list[i_page])):
+                image = image_list[i_page][i_image]
+                image_name = f"{self.title}_{i_page}_{i_image}.{ext[i_page][i_image]}"
+                path = os.path.join(image_path, image_name)
+                im_page_path.append(path)
+                image.save(path)
+            img_path.append(im_page_path)
+
+        if img_path:
+            return img_path, ext
         return None, None

 # Define a function that identifies each section title by its font size and returns them as a list
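
Reviewer note: get_image_path now only harvests figures from pages before the first experiment-like section. A minimal standalone sketch of that page-selection step, assuming section_page_dict maps section titles to page indices as in get_paper_from_pdf.py (the function name first_experiment_page is illustrative and does not exist in the repo):

    # Sketch, not part of the patch: where image harvesting should stop.
    EXP_KEYS = ["Materials and Methods", "Experiment Settings", "Experiment",
                "Experimental Results", "Evaluation", "Experiments",
                "Results", "Findings", "Data Analysis"]

    def first_experiment_page(section_page_dict: dict) -> int:
        # Return the page index of the first experiment-like section title,
        # or 0 if none is found (in which case no pages are scanned).
        for title, page in section_page_dict.items():
            if title in EXP_KEYS:
                return page
        return 0

Example invocations that exercise the new flags (the query and filter keys are placeholders):

    python chat_paper.py --query "all: ChatGPT robot" --filter_keys "ChatGPT robot" --coarse --save_image local
    python chat_paper.py --query "all: ChatGPT robot" --repeat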