Skip to content

Commit

Permalink
Add Tavily api (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Feb 13, 2024
1 parent bc220ad commit 3bccc7e
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Support for [Databricks Foundation Models API](https://docs.databricks.com/en/machine-learning/foundation-models/index.html). Requires two environment variables to be set: `DATABRICKS_API_KEY` and `DATABRICKS_HOST` (the part of the URL before `/serving-endpoints/`)
- Experimental support for API tools to enhance your LLM workflows: `Experimental.APITools.create_websearch` function which can execute and summarize a web search (incl. filtering on specific domains). It requires `TAVILY_API_KEY` to be set in the environment. Get your own key from [Tavily](https://tavily.com/) - the free tier enables c. 1000 searches/month, which should be more than enough to get started.

### Fixed
- Added an option to reduce the "batch size" for the embedding step in building the RAG index (`build_index`, `get_embeddings`). Set `embedding_kwargs = (; target_batch_size_length=10_000, ntasks=1)` if you're having some limit issues with your provider.
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ makedocs(;
"Experimental Modules" => "reference_experimental.md",
"RAGTools" => "reference_ragtools.md",
"AgentTools" => "reference_agenttools.md",
"APITools" => "reference_apitools.md",
],
])

Expand Down
9 changes: 9 additions & 0 deletions docs/src/reference_apitools.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Reference for APITools

```@index
Modules = [PromptingTools.Experimental.APITools]
```

```@autodocs
Modules = [PromptingTools.Experimental.APITools]
```
10 changes: 10 additions & 0 deletions src/Experimental/APITools/APITools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module APITools

using HTTP, JSON3
using PromptingTools
const PT = PromptingTools

export create_websearch
include("tavily_api.jl")

end # module
81 changes: 81 additions & 0 deletions src/Experimental/APITools/tavily_api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
tavily_api(;
api_key::AbstractString,
endpoint::String = "search",
url::AbstractString = "https://api.tavily.com",
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
Sends API requests to [Tavily](https://tavily.com) and returns the response.
"""
function tavily_api(;
api_key::AbstractString,
endpoint::String = "search",
url::AbstractString = "https://api.tavily.com",
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
@assert !isempty(api_key) "Tavily `api_key` cannot be empty. Check `PT.TAVILY_API_KEY` or pass it as a keyword argument."

## api_key is sent in the POST body, not headers
input_body = Dict("api_key" => api_key, kwargs...)

# eg, https://api.tavily.com/search
api_url = string(url, "/", endpoint)
headers = PT.auth_header(nothing) # no API key provided
resp = HTTP.post(api_url, headers,
JSON3.write(input_body); http_kwargs...)
body = JSON3.read(resp.body)
return (; response = body, status = resp.status)
end

"""
create_websearch(query::AbstractString;
api_key::AbstractString,
search_depth::AbstractString = "basic")
# Arguments
- `query::AbstractString`: The query to search for.
- `api_key::AbstractString`: The API key to use for the search. Get an API key from [Tavily](https://tavily.com).
- `search_depth::AbstractString`: The depth of the search. Can be either "basic" or "advanced". Default is "basic". Advanced search calls equal to 2 requests.
- `include_answer::Bool`: Whether to include the answer in the search results. Default is `false`.
- `include_raw_content::Bool`: Whether to include the raw content in the search results. Default is `false`.
- `max_results::Integer`: The maximum number of results to return. Default is 5.
- `include_images::Bool`: Whether to include images in the search results. Default is `false`.
- `include_domains::AbstractVector{<:AbstractString}`: A list of domains to include in the search results. Default is an empty list.
- `exclude_domains::AbstractVector{<:AbstractString}`: A list of domains to exclude from the search results. Default is an empty list.
# Example
```julia
r = create_websearch("Who is King Charles?")
```
Even better, you can get not just the results but also the answer:
```julia
r = create_websearch("Who is King Charles?"; include_answer = true)
```
See [Rest API documentation](https://docs.tavily.com/docs/tavily-api/rest_api) for more information.
"""
function create_websearch(query::AbstractString;
api_key::AbstractString = PT.TAVILY_API_KEY,
search_depth::AbstractString = "basic",
include_answer::Bool = false,
include_raw_content::Bool = false,
max_results::Integer = 5,
include_images::Bool = false,
include_domains::AbstractVector{<:AbstractString} = String[],
exclude_domains::AbstractVector{<:AbstractString} = String[],)
@assert search_depth in ["basic", "advanced"] "Search depth must be either 'basic' or 'advanced'"
@assert max_results>0 "Max results must be a positive integer"

tavily_api(; api_key, endpoint = "search",
query,
search_depth,
include_answer,
include_raw_content,
max_results,
include_images,
include_domains,
exclude_domains)
end
4 changes: 4 additions & 0 deletions src/Experimental/Experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ It is not included in the main module, so it must be explicitly imported.
Contains:
- `RAGTools`: Retrieval-Augmented Generation (RAG) functionality.
- `AgentTools`: Agentic functionality - lazy calls for building pipelines (eg, `AIGenerate`) and `AICodeFixer`.
- `APITools`: APIs to complement GenAI workflows (eg, Tavily Search API).
"""
module Experimental

Expand All @@ -16,4 +17,7 @@ include("RAGTools/RAGTools.jl")
export AgentTools
include("AgentTools/AgentTools.jl")

export APITools
include("APITools/APITools.jl")

end # module Experimental
7 changes: 7 additions & 0 deletions src/user_preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Check your preferences by calling `get_preferences(key::String)`.
- `COHERE_API_KEY`: The API key for the Cohere API. See [Cohere's documentation](https://docs.cohere.com/docs/the-cohere-platform) for more information.
- `DATABRICKS_API_KEY`: The API key for the Databricks Foundation Model API. See [Databricks' documentation](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html) for more information.
- `DATABRICKS_HOST`: The host for the Databricks API. See [Databricks' documentation](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html) for more information.
- `TAVILY_API_KEY`: The API key for the Tavily Search API. Register [here](https://tavily.com/). See more information [here](https://docs.tavily.com/docs/tavily-api/rest_api).
- `MODEL_CHAT`: The default model to use for aigenerate and most ai* calls. See `MODEL_REGISTRY` for a list of available models or define your own.
- `MODEL_EMBEDDING`: The default model to use for aiembed (embedding documents). See `MODEL_REGISTRY` for a list of available models or define your own.
- `PROMPT_SCHEMA`: The default prompt schema to use for aigenerate and most ai* calls (if not specified in `MODEL_REGISTRY`). Set as a string, eg, `"OpenAISchema"`.
Expand All @@ -37,6 +38,7 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava
- `LOCAL_SERVER`: The URL of the local server to use for `ai*` calls. Defaults to `http://localhost:10897/v1`. This server is called when you call `model="local"`
- `DATABRICKS_API_KEY`: The API key for the Databricks Foundation Model API.
- `DATABRICKS_HOST`: The host for the Databricks API.
- `TAVILY_API_KEY`: The API key for the Tavily Search API. Register [here](https://tavily.com/). See more information [here](https://docs.tavily.com/docs/tavily-api/rest_api).
Preferences.jl takes priority over ENV variables, so if you set a preference, it will override the ENV variable.
Expand Down Expand Up @@ -65,6 +67,7 @@ function set_preferences!(pairs::Pair{String, <:Any}...)
"COHERE_API_KEY",
"DATABRICKS_API_KEY",
"DATABRICKS_HOST",
"TAVILY_API_KEY",
"MODEL_CHAT",
"MODEL_EMBEDDING",
"MODEL_ALIASES",
Expand Down Expand Up @@ -103,6 +106,7 @@ function get_preferences(key::String)
"COHERE_API_KEY",
"DATABRICKS_API_KEY",
"DATABRICKS_HOST",
"TAVILY_API_KEY",
"MODEL_CHAT",
"MODEL_EMBEDDING",
"MODEL_ALIASES",
Expand Down Expand Up @@ -140,6 +144,9 @@ const DATABRICKS_API_KEY::String = @noinline @load_preference("DATABRICKS_API_KE
const DATABRICKS_HOST::String = @noinline @load_preference("DATABRICKS_HOST",
default=@noinline get(ENV, "DATABRICKS_HOST", ""));

const TAVILY_API_KEY::String = @noinline @load_preference("TAVILY_API_KEY",
default=@noinline get(ENV, "TAVILY_API_KEY", ""));

## Address of the local server
const LOCAL_SERVER::String = @noinline @load_preference("LOCAL_SERVER",
default=@noinline get(ENV, "LOCAL_SERVER", "http://127.0.0.1:10897/v1"));
Expand Down
23 changes: 16 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,24 @@ end
function preview end

"""
auth_header(api_key::String)
auth_header(api_key::Union{Nothing, AbstractString};
extra_headers::AbstractVector{Pair{String, String}} = Vector{Pair{String, String}}[],
kwargs...)
Builds an authorization header for API calls with the given API key.
Creates the authentication headers for any API request. Assumes that the communication is done in JSON format.
"""
function auth_header(api_key::String)
isempty(api_key) && throw(ArgumentError("api_key cannot be empty"))
[
"Authorization" => "Bearer $api_key",
function auth_header(api_key::Union{Nothing, AbstractString};
extra_headers::AbstractVector = Vector{
Pair{String, String},
}[],
kwargs...)
!isnothing(api_key) && isempty(api_key) &&
throw(ArgumentError("`api_key` cannot be empty"))
headers = [
"Content-Type" => "application/json",
"Accept" => "application/json",
extra_headers...,
]
end
!isnothing(api_key) && pushfirst!(headers, "Authorization" => "Bearer $api_key")
return headers
end
8 changes: 8 additions & 0 deletions test/Experimental/APITools/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Test
using PromptingTools
using PromptingTools.Experimental.APITools
const PT = PromptingTools

## @testset "APITools" begin
## include("tavily_api.jl")
## end
1 change: 1 addition & 0 deletions test/Experimental/APITools/tavily_api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO: hard to test the API itself?
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ end
@testset "Experimental" begin
include("Experimental/RAGTools/runtests.jl")
include("Experimental/AgentTools/runtests.jl")
include("Experimental/APITools/runtests.jl")
end
1 change: 1 addition & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,5 @@ end
"Accept" => "application/json",
]
@test_throws ArgumentError auth_header("")
@test length(auth_header(nothing)) == 2
end

0 comments on commit 3bccc7e

Please sign in to comment.