diff --git a/LLMCode/utils/completion.py b/LLMCode/utils/completion.py index 55664fe..6165120 100644 --- a/LLMCode/utils/completion.py +++ b/LLMCode/utils/completion.py @@ -1,39 +1,40 @@ -import multiprocessing +import threading import openai import LLMCode.cfg.completion_params as completion_params from . import ANSI_CODE from .logger import LOGGER -def run_with_timeout(target_function, args=None, timeout=30): - pool = multiprocessing.Pool(processes=1) - if args is None: - result = pool.apply_async(target_function) - else: - result = pool.apply_async(target_function, args) - try: - result_value = result.get(timeout=timeout) - return result_value - except Exception as e: - if isinstance(e, multiprocessing.TimeoutError): - LOGGER.info( - "%s\r⚠ The completion could not be done. %s response lasted more than %s seconds, which is the limit.", - ANSI_CODE["yellow"], - target_function, - timeout, - ) - else: - LOGGER.info( - "%s\r⚠ The completion could not be done. %s raised the exception %s", - ANSI_CODE["yellow"], - target_function, - e, - ) - pool.terminate() +def run_with_timeout(target_function, args=(), kwargs={}, timeout=30): + result = [None] # A mutable container to store the function result + exception = [None] # A container to capture exceptions + + def target_wrapper(): + try: + result[0] = target_function(*args, **kwargs) + except Exception as e: + exception[0] = e + + thread = threading.Thread(target=target_wrapper) + thread.start() + thread.join(timeout) + if thread.is_alive(): + LOGGER.info( + "%s⚠ The completion could not be done. %s response lasted more than %s seconds, which is the limit.", + ANSI_CODE["yellow"], + target_function.__name__, + timeout, + ) + return None + if exception[0]: + LOGGER.info( + "%s⚠ The completion could not be done. %s raised the exception %s", + ANSI_CODE["yellow"], + target_function.__name__, + exception[0], + ) return None - finally: - pool.close() - pool.join() + return result[0] # Completion using openai API diff --git a/LLMCode/utils/document.py b/LLMCode/utils/document.py index 18dbeec..f0d33fe 100644 --- a/LLMCode/utils/document.py +++ b/LLMCode/utils/document.py @@ -137,5 +137,5 @@ def add_msg(msg, where, element, script_content): script_content, ) pass - with open(script, "w") as python_file: + with open(script, "w", encoding="utf-8") as python_file: python_file.write(script_content) diff --git a/requirements.txt b/requirements.txt index f0dd0ae..360d2ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -openai \ No newline at end of file +openai==0.28 \ No newline at end of file