diff --git a/docs/docs/gpt-researcher/llms.md b/docs/docs/gpt-researcher/llms.md
index fb8909e6e..13fcd7b15 100644
--- a/docs/docs/gpt-researcher/llms.md
+++ b/docs/docs/gpt-researcher/llms.md
@@ -46,9 +46,18 @@ OPENAI_EMBEDDING_MODEL="custom_model"
 
 ### Azure OpenAI
 
+See also the documentation in the LangChain [Azure OpenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.azure.AzureChatOpenAI.html) page.
+
+On Azure OpenAI you will need to create deployments for each model you want to use. Please also specify the model names/deployment names in your `.env` file:
+
 ```bash
 EMBEDDING_PROVIDER="azure_openai"
 AZURE_OPENAI_API_KEY="Your key"
+AZURE_OPENAI_ENDPOINT="https://<your-endpoint>.openai.azure.com/"
+OPENAI_API_VERSION="2024-05-01-preview"
+FAST_LLM_MODEL="gpt-4o-mini"
+DEFAULT_LLM_MODEL="gpt-4o-mini"
+SMART_LLM_MODEL="gpt-4o"
 ```
 
 
diff --git a/gpt_researcher/config/config.py b/gpt_researcher/config/config.py
index b9dccdb7e..5967fd402 100644
--- a/gpt_researcher/config/config.py
+++ b/gpt_researcher/config/config.py
@@ -8,34 +8,43 @@ class Config:
     def __init__(self, config_file: str = None):
         """Initialize the config class."""
-        self.config_file = os.path.expanduser(config_file) if config_file else os.getenv('CONFIG_FILE')
-        self.retrievers = self.parse_retrievers(os.getenv('RETRIEVER', "tavily"))
-        self.embedding_provider = os.getenv('EMBEDDING_PROVIDER', 'openai')
-        self.similarity_threshold = int(os.getenv('SIMILARITY_THRESHOLD', 0.38))
-        self.llm_provider = os.getenv('LLM_PROVIDER', "openai")
-        self.ollama_base_url = os.getenv('OLLAMA_BASE_URL', None)
-        self.llm_model = "gpt-4o-mini"
-        self.fast_llm_model = os.getenv('FAST_LLM_MODEL', "gpt-4o-mini")
-        self.smart_llm_model = os.getenv('SMART_LLM_MODEL', "gpt-4o-2024-08-06")
-        self.fast_token_limit = int(os.getenv('FAST_TOKEN_LIMIT', 2000))
-        self.smart_token_limit = int(os.getenv('SMART_TOKEN_LIMIT', 4000))
-        self.browse_chunk_max_length = int(os.getenv('BROWSE_CHUNK_MAX_LENGTH', 8192))
-        self.summary_token_limit = int(os.getenv('SUMMARY_TOKEN_LIMIT', 700))
-        self.temperature = float(os.getenv('TEMPERATURE', 0.55))
-        self.llm_temperature = float(os.getenv('LLM_TEMPERATURE', 0.55))  # Add this line
-        self.user_agent = os.getenv('USER_AGENT', "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
-                                                  "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0")
-        self.max_search_results_per_query = int(os.getenv('MAX_SEARCH_RESULTS_PER_QUERY', 5))
-        self.memory_backend = os.getenv('MEMORY_BACKEND', "local")
-        self.total_words = int(os.getenv('TOTAL_WORDS', 1000))
-        self.report_format = os.getenv('REPORT_FORMAT', "APA")
-        self.max_iterations = int(os.getenv('MAX_ITERATIONS', 3))
-        self.agent_role = os.getenv('AGENT_ROLE', None)
+        self.config_file = (
+            os.path.expanduser(config_file) if config_file else os.getenv("CONFIG_FILE")
+        )
+        self.retrievers = self.parse_retrievers(os.getenv("RETRIEVER", "tavily"))
+        self.embedding_provider = os.getenv("EMBEDDING_PROVIDER", "openai")
+        self.similarity_threshold = int(os.getenv("SIMILARITY_THRESHOLD", 0.38))
+        self.llm_provider = os.getenv("LLM_PROVIDER", "openai")
+        self.ollama_base_url = os.getenv("OLLAMA_BASE_URL", None)
+        self.llm_model = os.getenv("DEFAULT_LLM_MODEL", "gpt-4o-mini")
+        self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-4o-mini")
+        self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4o-2024-08-06")
+        self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 2000))
+        self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 4000))
+        self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 8192))
+        self.summary_token_limit = int(os.getenv("SUMMARY_TOKEN_LIMIT", 700))
+        self.temperature = float(os.getenv("TEMPERATURE", 0.55))
+        self.llm_temperature = float(
+            os.getenv("LLM_TEMPERATURE", 0.55)
+        )  # Add this line
+        self.user_agent = os.getenv(
+            "USER_AGENT",
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
+            "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
+        )
+        self.max_search_results_per_query = int(
+            os.getenv("MAX_SEARCH_RESULTS_PER_QUERY", 5)
+        )
+        self.memory_backend = os.getenv("MEMORY_BACKEND", "local")
+        self.total_words = int(os.getenv("TOTAL_WORDS", 1000))
+        self.report_format = os.getenv("REPORT_FORMAT", "APA")
+        self.max_iterations = int(os.getenv("MAX_ITERATIONS", 3))
+        self.agent_role = os.getenv("AGENT_ROLE", None)
         self.scraper = os.getenv("SCRAPER", "bs")
         self.max_subtopics = os.getenv("MAX_SUBTOPICS", 3)
         self.report_source = os.getenv("REPORT_SOURCE", None)
         self.doc_path = os.getenv("DOC_PATH", "")
-        self.llm_kwargs = {} 
+        self.llm_kwargs = {}
 
         self.load_config_file()
 
         if not hasattr(self, "llm_kwargs"):
@@ -47,14 +56,26 @@ def __init__(self, config_file: str = None):
     def parse_retrievers(self, retriever_str: str):
         """Parse the retriever string into a list of retrievers and validate them."""
         VALID_RETRIEVERS = [
-            "arxiv", "bing", "custom", "duckduckgo", "exa", "google", "searx",
-            "semantic_scholar", "serpapi", "serper", "tavily", "pubmed_central"
+            "arxiv",
+            "bing",
+            "custom",
+            "duckduckgo",
+            "exa",
+            "google",
+            "searx",
+            "semantic_scholar",
+            "serpapi",
+            "serper",
+            "tavily",
+            "pubmed_central",
         ]
-        retrievers = [retriever.strip() for retriever in retriever_str.split(',')]
+        retrievers = [retriever.strip() for retriever in retriever_str.split(",")]
         invalid_retrievers = [r for r in retrievers if r not in VALID_RETRIEVERS]
         if invalid_retrievers:
-            raise ValueError(f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
-                             f"Valid options are: {', '.join(VALID_RETRIEVERS)}.")
+            raise ValueError(
+                f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
+                f"Valid options are: {', '.join(VALID_RETRIEVERS)}."
+            )
         return retrievers
 
     def validate_doc_path(self):
diff --git a/main.py b/main.py
index 8ee069f18..85490af26 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,10 @@
-from backend.server import app
 from dotenv import load_dotenv
+
 load_dotenv()
 
+from backend.server import app
+
 if __name__ == "__main__":
     import uvicorn
 
-    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/requirements.txt b/requirements.txt
index c0b3bd19c..85c0f5ebf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@ markdown
 langchain>=0.2,<0.3
 langchain_community>=0.2,<0.3
 langchain-openai>=0.1,<0.2
+langgraph
 tiktoken
 gpt-researcher
 arxiv
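
As a quick sanity check for the Azure variables documented in the first hunk, the setup can be exercised outside GPT Researcher with a few lines of LangChain. This is a sketch, not part of the diff: it assumes your `.env` matches the example above and that the deployment named by `SMART_LLM_MODEL` actually exists on your Azure resource.

```python
import os

from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI

load_dotenv()  # expose the .env values before anything reads them

# AzureChatOpenAI picks up AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT and
# OPENAI_API_VERSION from the environment; only the deployment name needs
# to be passed explicitly, here taken from the SMART_LLM_MODEL variable.
llm = AzureChatOpenAI(azure_deployment=os.environ["SMART_LLM_MODEL"])

print(llm.invoke("Reply with one word: pong").content)
```

A missing or misnamed deployment typically surfaces here as a 404 `DeploymentNotFound` error, which is the failure mode the new docs paragraph warns about.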
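
Besides reformatting, the one behavioural change in the `config.py` hunks is that `self.llm_model` now honours a `DEFAULT_LLM_MODEL` environment variable instead of being hard-coded. A minimal illustration, on the assumption that `Config()` can be constructed with default settings and no config file:

```python
import os

# Must be set before Config() is constructed, since __init__ reads os.environ.
os.environ["DEFAULT_LLM_MODEL"] = "gpt-4o"

from gpt_researcher.config.config import Config

cfg = Config()
print(cfg.llm_model)  # -> "gpt-4o" (previously always "gpt-4o-mini")

# The reformatted parse_retrievers keeps its behaviour: split on commas,
# strip whitespace, validate against VALID_RETRIEVERS.
print(cfg.parse_retrievers("tavily, arxiv"))  # -> ['tavily', 'arxiv']
```

An unknown name such as `cfg.parse_retrievers("tavily, foo")` raises the `ValueError` built in the second hunk.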
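
The `main.py` hunk is a load-order fix, not a style change: `load_dotenv()` must run before `backend.server` is imported, presumably because that module (or something it imports) reads the environment at import time. A hypothetical two-module sketch of the pitfall being fixed:

```python
# settings.py (hypothetical) -- stands in for backend.server, which
# evidently reads configuration while it is being imported.
import os

API_KEY = os.getenv("SOME_API_KEY")  # evaluated during `import settings`
```

```python
# main.py -- correct ordering, matching the diff: load .env first,
# then import the module whose import-time code depends on it.
from dotenv import load_dotenv

load_dotenv()

import settings  # noqa: E402 -- deliberately placed after load_dotenv()

print(settings.API_KEY)  # set, because .env was loaded before the import
```

With the old ordering (`import` first, `load_dotenv()` second), `settings.API_KEY` would be `None` even though `.env` defines `SOME_API_KEY`.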