Skip to content

Commit

Permalink
add spell check based on to do lists
Browse files Browse the repository at this point in the history
  • Loading branch information
mrzaizai2k committed Jul 11, 2024
1 parent a37ea45 commit f158260
Show file tree
Hide file tree
Showing 7 changed files with 1,102 additions and 246 deletions.
443 changes: 443 additions & 0 deletions notebook/Final_Version.ipynb

Large diffs are not rendered by default.

424 changes: 310 additions & 114 deletions notebook/langchain_test.ipynb

Large diffs are not rendered by default.

77 changes: 72 additions & 5 deletions src/Microsofttodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from todocli.utils.datetime_util import datetime_to_api_timestamp
import json
from requests_oauthlib import OAuth2Session


# Oauth settings
import os
import pickle
Expand Down Expand Up @@ -203,7 +205,11 @@ def rename_list(self, old_title: str, new_title: str):
return True if response.ok else response.raise_for_status()


def get_tasks(self, list_name: str = None, list_id: str = None, num_tasks: int = 100):
def get_tasks(self, list_name: str = None,
list_id: str = None,
num_tasks: int = 100,
get_completed:bool = False):

assert (list_name is not None) or (
list_id is not None
), "You must provide list_name or list_id"
Expand All @@ -212,9 +218,15 @@ def get_tasks(self, list_name: str = None, list_id: str = None, num_tasks: int =
if list_id is None:
list_id = self.get_list_id_by_name(list_name)

endpoint = (
f"{BASE_URL}/{list_id}/tasks?$filter=status ne 'completed'&$top={num_tasks}"
)
if get_completed:
endpoint = (
f"{BASE_URL}/{list_id}/tasks?$top={num_tasks}"
)
else:
endpoint = (
f"{BASE_URL}/{list_id}/tasks?$filter=status ne 'completed'&$top={num_tasks}"
)

session = get_oauth_session()
response = session.get(endpoint)
response_value = self.parse_response(response)
Expand Down Expand Up @@ -351,4 +363,59 @@ def create_task(self,
}
session = get_oauth_session()
response = session.post(endpoint, json=request_body)
return True if response.ok else response.raise_for_status()
return True if response.ok else response.raise_for_status()


def _filter_task_titles(self, task_titles):
"""
Filters the task titles based on specific criteria.
"""
import re
url_pattern = re.compile(r'(https?://|file://)')
key_pattern = re.compile(r'KEY|password|Error|\\\\DESKTOP')
number_pattern = re.compile(r'^\d+$')
number_comma_pattern = re.compile(r'^\d+,\d+,\d+$')

filtered_names = []
for name in task_titles:
if url_pattern.search(name):
continue
if key_pattern.search(name):
continue
if len(name.split()) > 20:
continue
if len(name) == 1:
continue
if number_pattern.fullmatch(name):
continue
if number_comma_pattern.fullmatch(name):
continue
filtered_names.append(name)
return filtered_names

def get_all_tasks(self, num_tasks: int = 100, get_completed: bool = False):
    """
    Collect the titles of the tasks in every list of the account.

    Might be used to create a tasks dictionary.

    Args:
        num_tasks: forwarded to ``get_tasks`` for each list.
        get_completed: forwarded to ``get_tasks`` for each list.

    Returns:
        list of task titles after ``_filter_task_titles`` has been applied.
    """
    task_titles = []
    # Use ``self`` rather than a module-level instance so the method also
    # works when this module is imported instead of run as a script.
    for task_list in self.get_lists():
        list_name = task_list['displayName']
        try:
            todo_tasks = self.get_tasks(
                list_name=list_name,
                num_tasks=num_tasks,
                get_completed=get_completed,
            )
            names = [item['title'] for item in todo_tasks]
        except Exception as e:
            # Deliberately best-effort: one failing list must not prevent
            # the remaining lists from being collected.
            print(f"Error: {e}")
            print(f"Task list: {task_list}")
        else:
            task_titles.extend(names)
    return self._filter_task_titles(task_titles)

if __name__ == "__main__":
    # Manual smoke test: print the titles of the tasks returned for the
    # 'Tasks' list using the default arguments of get_tasks.
    todo = MicrosoftToDo()
    todo_tasks = todo.get_tasks(list_name='Tasks')
    names = [item['title'] for item in todo_tasks]
    print(names)

4 changes: 2 additions & 2 deletions src/Utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import logging
from logging.handlers import TimedRotatingFileHandler

def create_logger(logfile='logs/facesvc.log'):
def create_logger(logfile='logging/stock_bot.log'):
logger = logging.getLogger()
# Set the logging level
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
rotate_handler = TimedRotatingFileHandler(filename=logfile, when="midnight", backupCount=30)
rotate_handler = TimedRotatingFileHandler(filename=logfile, when="midnight", backupCount=5)
rotate_handler.setLevel(logging.DEBUG)
rotate_handler.suffix = "%Y%m%d"

Expand Down
136 changes: 127 additions & 9 deletions src/Utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,131 @@
from dotenv import load_dotenv
load_dotenv()

import cv2
import subprocess
import schedule
import time
import yaml
from functools import wraps
import torch
from pydub import AudioSegment

import re
import phunspell
from rank_bm25 import BM25Okapi
import editdistance
from typing import Literal

from src.Microsofttodo import *

from src.Utils.logger import create_logger
logger = create_logger()


class SpellCheck:
    """Correct misspelled words in a string using previously seen task titles.

    Words flagged by the phunspell dictionary are replaced by the closest
    n-gram found in the history file, ranked by BM25, by edit distance, or
    by BM25 followed by edit distance.
    """

    # Ranking strategies accepted by ``spell_check_and_correct``.
    _METHODS = ("BM25", "editdistance", "BM25_EditDistance")

    def __init__(self, history_tasks_path: str, n_grams: int):
        """
        Args:
            history_tasks_path: text file with one historical task per line.
            n_grams: largest n-gram size generated from input and history.
        """
        self.history_tasks_path = history_tasks_path
        self.n_grams = n_grams
        self.history_tasks = self.load_history_tasks(history_tasks_path)

    def preprocess_text(self, input_string):
        """Strip every character that is not a word char, whitespace or apostrophe."""
        return re.sub(r'[^\w\s\']', '', input_string)

    def load_history_tasks(self, file_path: str):
        """Read the history file and return its lines with surrounding whitespace removed.

        The file is decoded as UTF-8 explicitly so that non-ASCII task
        titles load the same way regardless of the platform default encoding.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def generate_ngrams(self, text: str, n_grams: int = 1):
        """Return all word n-grams of size 1..n_grams, smaller sizes first.

        Unigrams come first and in text order; ``spell_check_and_correct``
        relies on that ordering when it pairs words with corrections.
        """
        words = text.split()
        ngrams = []
        for size in range(1, n_grams + 1):
            ngrams.extend(' '.join(words[start:start + size])
                          for start in range(len(words) - size + 1))
        return ngrams

    def get_best_match_bm25(self, token, history_tokens_flat, bm25,
                            verbose: bool = False):
        """Return the top BM25 candidate for ``token``, or ``token`` if there is none."""
        if bm25 is None or not history_tokens_flat:
            # No history to rank against: leave the token untouched.
            return token
        token_candidates = bm25.get_top_n(token, history_tokens_flat, n=5)
        if verbose:
            print('wrong token', token)
            print('token candidates', token_candidates[:5])
        return token_candidates[0] if token_candidates else token

    def get_best_match_editdistance(self, token, history_tokens_flat,
                                    verbose: bool = False):
        """Return the history token with the smallest edit distance to ``token``."""
        distances = [(history_token, editdistance.eval(token, history_token))
                     for history_token in history_tokens_flat]
        # list.sort is stable, so ties keep their order of appearance in the history.
        distances.sort(key=lambda pair: pair[1])
        if verbose:
            print('wrong token', token)
            print('token candidates', distances[:5])  # Top 5 candidates based on edit distance
        return distances[0][0] if distances else token

    def get_best_match_bm25_editdistance(self, token, history_tokens_flat, bm25,
                                         verbose: bool = False, top_n: int = 10):
        """Shortlist ``top_n`` candidates with BM25, then pick the closest by edit distance."""
        if bm25 is None or not history_tokens_flat:
            # No history to rank against: leave the token untouched.
            return token
        token_candidates = bm25.get_top_n(token, history_tokens_flat, top_n)
        if verbose:
            print('wrong token', token)
            print('BM25 top 10 candidates', token_candidates)

        if not token_candidates:
            return token

        distances = [(candidate, editdistance.eval(token, candidate))
                     for candidate in token_candidates]
        distances.sort(key=lambda pair: pair[1])
        if verbose:
            print('Edit Distance candidates', distances[:5])  # Top 5 candidates based on edit distance

        return distances[0][0] if distances else token

    def spell_check_and_correct(self, input_string,
                                method: Literal["BM25", "editdistance", "BM25_EditDistance"],
                                loc_lang: Literal['en_US', 'vi_VN'] = 'en_US',
                                verbose: bool = False):
        """Return ``input_string`` with misspelled words replaced from the history.

        Args:
            input_string: text to correct.
            method: ranking strategy used to choose the replacement.
            loc_lang: dictionary locale handed to phunspell.
            verbose: print the candidates considered for each flagged token.

        Returns:
            The corrected string, passed through ``str.capitalize()``.

        Raises:
            ValueError: if ``method`` is not one of the supported strategies.
        """
        # Fail fast: an unknown method would otherwise only surface as an
        # UnboundLocalError on the first misspelled token.
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}"
            )

        input_tokens = self.preprocess_text(input_string.lower())
        input_ngrams = self.generate_ngrams(input_tokens, self.n_grams)

        history_tokens_flat = [
            token
            for task in self.history_tasks
            for token in self.generate_ngrams(self.preprocess_text(task.lower()), self.n_grams)
        ]

        pspell = phunspell.Phunspell(loc_lang=loc_lang)

        # BM25Okapi cannot be built from an empty corpus, so skip it when the
        # history is empty; the matchers then return the token unchanged.
        # NOTE(review): corpus entries and queries are plain strings, so
        # rank_bm25 presumably scores them character by character - confirm
        # that this is intended.
        bm25 = None
        if method != "editdistance" and history_tokens_flat:
            bm25 = BM25Okapi(history_tokens_flat)

        corrected_tokens = []
        for token in input_ngrams:
            # NOTE(review): lookup_list appears to return the words phunspell
            # does not recognise, so only single-word tokens can ever be
            # flagged here - confirm against the phunspell documentation.
            if token not in pspell.lookup_list(token.split(" ")):
                corrected_tokens.append(token)
                continue

            if method == "BM25":
                best_match = self.get_best_match_bm25(token, history_tokens_flat, bm25, verbose)
            elif method == "editdistance":
                best_match = self.get_best_match_editdistance(token, history_tokens_flat, verbose)
            else:  # "BM25_EditDistance"
                best_match = self.get_best_match_bm25_editdistance(token, history_tokens_flat, bm25, verbose)
            corrected_tokens.append(best_match)

        # zip stops at the number of words, so only the leading unigram
        # entries of corrected_tokens are paired with the original words.
        corrected_string = input_string
        for original_token, corrected_token in zip(input_tokens.split(), corrected_tokens):
            if original_token != corrected_token:
                corrected_string = re.sub(r'\b{}\b'.format(re.escape(original_token)), corrected_token, corrected_string, count=1, flags=re.IGNORECASE)

        return corrected_string.capitalize()


def convert_m4a_to_mp3(m4a_file_path:str, mp3_file_path:str):
    """Convert an ``.m4a`` audio file into an ``.mp3`` file.

    Args:
        m4a_file_path: path of the source file, read with format ``m4a``.
        mp3_file_path: path the converted audio is exported to.
    """
    # Decode the source and write it straight back out in the target format.
    AudioSegment.from_file(m4a_file_path, format="m4a").export(mp3_file_path, format="mp3")
    print(f"Conversion complete: {m4a_file_path} to {mp3_file_path}")


def timeit(func):
def wrapper(*args, **kwargs):
start_time = time.time()
Expand Down Expand Up @@ -234,17 +351,18 @@ def get_all_watchlist(self)-> list:

def main():
print('Hi')
capture_image_from_camera()
# check_path("data/data1")
# check_path("data/data2/note.txt")
user_db = UserDatabase()
data_config_path = 'config/config.yaml'
with open(data_config_path, 'r') as file:
data = yaml.safe_load(file)

watchlist = data.get('my_watchlist', [])
USER_ID = os.getenv('USER_ID')
user_db.save_watch_list(user_id=USER_ID, watch_list=watchlist)
watch_list = user_db.get_watch_list(user_id=USER_ID)
# user_db = UserDatabase()
# data_config_path = 'config/config.yaml'
# with open(data_config_path, 'r') as file:
# data = yaml.safe_load(file)

# watchlist = data.get('my_watchlist', [])
# USER_ID = os.getenv('USER_ID')
# user_db.save_watch_list(user_id=USER_ID, watch_list=watchlist)
# watch_list = user_db.get_watch_list(user_id=USER_ID)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions src/stock_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def masterquest(message):
# print(f'Result: {response.json()}')
logger.debug(msg=f"Result: {response.json()}")
bot.reply_to(message, f"The answer from {response.json()['model_type']}: \n{response.json()['result']}")
bot.send_message(message.chat.id, f"The source: {response.json()['source_documents']}")
# bot.send_message(message.chat.id, f"The source: {response.json()['source_documents'][0]}")

except Exception as e:
# print(f'Error: {e}')
logger.debug(msg=f"Error on LLM and RAG system: {e}")
Expand Down Expand Up @@ -359,7 +360,6 @@ def process_remove_stock(message):
logger.debug(msg = f"{symbol} not found in your watchlist.")



@bot.message_handler(commands=['remote'])
def open_vscode_tunnel(message):
if not validate_mrzaizai2k_user(message.chat.id):
Expand Down
Loading

0 comments on commit f158260

Please sign in to comment.