Skip to content

Commit

Permalink
up speed for summarization model
Browse files Browse the repository at this point in the history
  • Loading branch information
mrzaizai2k committed Jun 10, 2024
1 parent ca5925d commit 01e2937
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/stock_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def ask_for_question(message):
bot.register_next_step_handler(message, masterquest)

def masterquest(message):

query = message.text
masterquest_url = data.get('masterquest_url')
try:
Expand Down
27 changes: 16 additions & 11 deletions src/summarize_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Literal
from transformers import pipeline
from bs4 import BeautifulSoup
from src.Utils.utils import check_path, take_device
from src.Utils.utils import check_path, take_device, timeit

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
Expand All @@ -34,6 +34,7 @@
from datetime import datetime



class SeperateTaskPrompt:
def __init__(self, template_path:str = 'config/seperate_task_template.txt', ):

Expand Down Expand Up @@ -303,31 +304,35 @@ def take_text_from_link(self, news_url:str) -> str :

class NewsSummarizer:
def __init__(self, summarizer = pipeline("summarization",
model="Falconsai/text_summarization", device = take_device()),
model="Falconsai/text_summarization",
torch_dtype=torch.bfloat16,
device = take_device()),
translator = GoogleTranslator(),
max_length:int=230,
max_length:int=200,
min_length:int=30,
):
self.summarizer = summarizer
self.translator = translator
self.max_length = max_length
self.min_length = min_length

def summary_text(self,text:str)->str:
def summary_text(self,text):
'''Summary short text'''
sum_text = self.summarizer(text, max_length=self.max_length,
min_length=self.min_length, do_sample=False)[0]['summary_text']
sum_text= f''
for model_output in self.summarizer(text, batch_size=8, truncation="only_first"):
text = model_output['summary_text']
sum_text += f'\n{text}'
return sum_text

@timeit
def summary_news(self, news:str, chunk_overlap:str = 0)->str:

text_splitter = TokenTextSplitter(chunk_size=self.max_length * 2,
chunk_overlap=chunk_overlap)

trans_news = self.translator.translate(text=news, to_lang='en')
text_chunks = text_splitter.split_text(trans_news)
summary_documents = [self.summary_text(chunk) for chunk in text_chunks]
summary_text = '\n'.join(summary_documents)
summary_text = self.summary_text(text_chunks)

summary_text = self.translator.translate(text=summary_text, to_lang='vi')
return summary_text
Expand Down Expand Up @@ -441,9 +446,9 @@ def _save_summary_data(self, summary_data):
print(f"Summary data saved to {self.summary_news_data_path}")

if __name__ == "__main__":
speech_to_text = SpeechSummaryProcessor(audio_path='sample_voice.m4a')
text = speech_to_text.generate_speech_to_text()
print ('Text', text)
# speech_to_text = SpeechSummaryProcessor(audio_path='sample_voice.m4a')
# text = speech_to_text.generate_speech_to_text()
# print ('Text', text)

symbol = 'SSI'
date_format='year'
Expand Down
2 changes: 1 addition & 1 deletion src/trading_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,4 +609,4 @@ def scrape_report(report_type):
TRADE_USER= os.getenv('TRADE_USER')
TRADE_PASS= os.getenv('TRADE_PASS')
# scrape_trading_data(user_name=TRADE_USER, password=TRADE_PASS)
scrape_trading_data_async(user_name=TRADE_USER, password=TRADE_PASS)
scrape_trading_data(user_name=TRADE_USER, password=TRADE_PASS)

0 comments on commit 01e2937

Please sign in to comment.