Skip to content

Commit

Permalink
Merge pull request #786 from danieldekay/improved_azure_openai_config
Browse files Browse the repository at this point in the history
Improved azure openai config
  • Loading branch information
assafelovic authored Aug 21, 2024
2 parents ffa9c94 + bb87592 commit e0e8524
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 31 deletions.
9 changes: 9 additions & 0 deletions docs/docs/gpt-researcher/llms.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,18 @@ OPENAI_EMBEDDING_MODEL="custom_model"

### Azure OpenAI

See also the LangChain [AzureChatOpenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.azure.AzureChatOpenAI.html) documentation.

On Azure OpenAI you need to create a deployment for each model you want to use. Then reference those model/deployment names in your `.env` file:

```bash
LLM_PROVIDER="azure_openai"
EMBEDDING_PROVIDER="azure_openai"
AZURE_OPENAI_API_KEY="Your key"
AZURE_OPENAI_ENDPOINT="https://<your-endpoint>.openai.azure.com/"
OPENAI_API_VERSION="2024-05-01-preview"
FAST_LLM_MODEL="gpt-4o-mini"
DEFAULT_LLM_MODEL="gpt-4o-mini"
SMART_LLM_MODEL="gpt-4o"
```


Expand Down
79 changes: 50 additions & 29 deletions gpt_researcher/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,43 @@ class Config:

def __init__(self, config_file: str = None):
"""Initialize the config class."""
self.config_file = os.path.expanduser(config_file) if config_file else os.getenv('CONFIG_FILE')
self.retrievers = self.parse_retrievers(os.getenv('RETRIEVER', "tavily"))
self.embedding_provider = os.getenv('EMBEDDING_PROVIDER', 'openai')
self.similarity_threshold = int(os.getenv('SIMILARITY_THRESHOLD', 0.38))
self.llm_provider = os.getenv('LLM_PROVIDER', "openai")
self.ollama_base_url = os.getenv('OLLAMA_BASE_URL', None)
self.llm_model = "gpt-4o-mini"
self.fast_llm_model = os.getenv('FAST_LLM_MODEL', "gpt-4o-mini")
self.smart_llm_model = os.getenv('SMART_LLM_MODEL', "gpt-4o-2024-08-06")
self.fast_token_limit = int(os.getenv('FAST_TOKEN_LIMIT', 2000))
self.smart_token_limit = int(os.getenv('SMART_TOKEN_LIMIT', 4000))
self.browse_chunk_max_length = int(os.getenv('BROWSE_CHUNK_MAX_LENGTH', 8192))
self.summary_token_limit = int(os.getenv('SUMMARY_TOKEN_LIMIT', 700))
self.temperature = float(os.getenv('TEMPERATURE', 0.55))
self.llm_temperature = float(os.getenv('LLM_TEMPERATURE', 0.55)) # Add this line
self.user_agent = os.getenv('USER_AGENT', "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0")
self.max_search_results_per_query = int(os.getenv('MAX_SEARCH_RESULTS_PER_QUERY', 5))
self.memory_backend = os.getenv('MEMORY_BACKEND', "local")
self.total_words = int(os.getenv('TOTAL_WORDS', 1000))
self.report_format = os.getenv('REPORT_FORMAT', "APA")
self.max_iterations = int(os.getenv('MAX_ITERATIONS', 3))
self.agent_role = os.getenv('AGENT_ROLE', None)
self.config_file = (
os.path.expanduser(config_file) if config_file else os.getenv("CONFIG_FILE")
)
self.retrievers = self.parse_retrievers(os.getenv("RETRIEVER", "tavily"))
self.embedding_provider = os.getenv("EMBEDDING_PROVIDER", "openai")
self.similarity_threshold = int(os.getenv("SIMILARITY_THRESHOLD", 0.38))
self.llm_provider = os.getenv("LLM_PROVIDER", "openai")
self.ollama_base_url = os.getenv("OLLAMA_BASE_URL", None)
self.llm_model = os.getenv("DEFAULT_LLM_MODEL", "gpt-4o-mini")
self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-4o-mini")
self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4o-2024-08-06")
self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 2000))
self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 4000))
self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 8192))
self.summary_token_limit = int(os.getenv("SUMMARY_TOKEN_LIMIT", 700))
self.temperature = float(os.getenv("TEMPERATURE", 0.55))
self.llm_temperature = float(
os.getenv("LLM_TEMPERATURE", 0.55)
) # Add this line
self.user_agent = os.getenv(
"USER_AGENT",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
)
self.max_search_results_per_query = int(
os.getenv("MAX_SEARCH_RESULTS_PER_QUERY", 5)
)
self.memory_backend = os.getenv("MEMORY_BACKEND", "local")
self.total_words = int(os.getenv("TOTAL_WORDS", 1000))
self.report_format = os.getenv("REPORT_FORMAT", "APA")
self.max_iterations = int(os.getenv("MAX_ITERATIONS", 3))
self.agent_role = os.getenv("AGENT_ROLE", None)
self.scraper = os.getenv("SCRAPER", "bs")
self.max_subtopics = os.getenv("MAX_SUBTOPICS", 3)
self.report_source = os.getenv("REPORT_SOURCE", None)
self.doc_path = os.getenv("DOC_PATH", "")
self.llm_kwargs = {}
self.llm_kwargs = {}

self.load_config_file()
if not hasattr(self, "llm_kwargs"):
Expand All @@ -47,14 +56,26 @@ def __init__(self, config_file: str = None):
def parse_retrievers(self, retriever_str: str) -> list[str]:
    """Parse a comma-separated retriever string into a validated list.

    Args:
        retriever_str: Comma-separated retriever names, e.g. ``"tavily,google"``.
            Whitespace around each name is ignored, and empty entries
            (e.g. from a trailing comma) are dropped.

    Returns:
        The validated retriever names, in input order.

    Raises:
        ValueError: If any entry is not a recognized retriever; the message
            lists both the offending entries and the valid options.
    """
    VALID_RETRIEVERS = [
        "arxiv",
        "bing",
        "custom",
        "duckduckgo",
        "exa",
        "google",
        "searx",
        "semantic_scholar",
        "serpapi",
        "serper",
        "tavily",
        "pubmed_central",
    ]
    # Strip whitespace and drop empty entries so inputs like "tavily, google,"
    # parse cleanly instead of raising a confusing error on the empty item.
    retrievers = [r.strip() for r in retriever_str.split(",") if r.strip()]
    invalid_retrievers = [r for r in retrievers if r not in VALID_RETRIEVERS]
    if invalid_retrievers:
        raise ValueError(
            f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
            f"Valid options are: {', '.join(VALID_RETRIEVERS)}."
        )
    return retrievers

def validate_doc_path(self):
Expand Down
6 changes: 4 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dotenv import load_dotenv

# Load .env into os.environ BEFORE importing the app: backend.server builds
# its Config at import time, so the environment must already be populated
# when the import below runs (this ordering is the fix this script carries).
load_dotenv()

from backend.server import app  # noqa: E402  deliberately after load_dotenv()

if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces so the server is reachable inside containers.
    uvicorn.run(app, host="0.0.0.0", port=8000)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ markdown
langchain>=0.2,<0.3
langchain_community>=0.2,<0.3
langchain-openai>=0.1,<0.2
langgraph
tiktoken
gpt-researcher
arxiv
Expand Down

0 comments on commit e0e8524

Please sign in to comment.