Skip to content

Commit

Permalink
Convert EE agent to use chat API to avoid having to pass message history manually.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 626629003
  • Loading branch information
The Google Earth Engine Community Authors authored and copybara-github committed Apr 26, 2024
1 parent 0842996 commit 74c4c4b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 57 deletions.
69 changes: 36 additions & 33 deletions experimental/ee_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def show_image(url: str):
class LLM:
"""Parent LLM class."""

def ask(self, question: str, temperature: Optional[float] = None) -> str:
raise NotImplementedError("Subclasses must implement the 'ask' method.")
def chat(self, question: str, temperature: Optional[float] = None) -> str:
raise NotImplementedError("Subclasses must implement the 'chat' method.")

def analyze_image(self, url: str) -> str:
raise NotImplementedError(
Expand All @@ -144,34 +144,42 @@ class Gemini(LLM):
"""Gemini LLM."""

def __init__(self):
self.text_model = genai.GenerativeModel('gemini-1.5-pro-latest')
self.image_model = genai.GenerativeModel('gemini-pro-vision')
self._text_model = genai.GenerativeModel('gemini-1.5-pro-latest')
self._image_model = genai.GenerativeModel('gemini-pro-vision')
self._chat_proxy = self._text_model.start_chat(history=[])

def ask(self, question: str, temperature: Optional[float] = 0.1) -> str:
def chat(self, question: str, temperature: Optional[float] = None) -> str:
"""Adds a question to the ongoing chat session."""
time.sleep(1)
sleep_duration = 10
while True:
response = ''
try:
response = self.text_model.generate_content(
brief + question,
generation_config={
'temperature': temperature,
# Use a generous but limited output size to encourage in-depth
# replies.
'max_output_tokens': 5000,
},
response = self._chat_proxy.send_message(
brief + question,
generation_config={
'temperature': temperature,
# Use a generous but limited output size to encourage in-depth
# replies.
'max_output_tokens': 5000,
}
)
if not response.parts:
raise ValueError(
'Cannot get analysis with reason'
f' {response.candidates[0].finish_reason.name}, terminating'
)
return response.text
except google.api_core.exceptions.TooManyRequests:
except (
google.api_core.exceptions.TooManyRequests,
google.api_core.exceptions.DeadlineExceeded
):
print(
'Got a TooManyRequests error, sleeping for'
'Got a rate limit or timeout error, sleeping for'
f' {sleep_duration} seconds'
)
time.sleep(sleep_duration)
continue
return response.text

def analyze_image(self, url: str) -> str:
image = PIL.Image.open(io.BytesIO(get_image(url)))
Expand Down Expand Up @@ -208,7 +216,7 @@ def analyze_image(self, url: str) -> str:
{'inline_data': image},
]
}
image_response = self.image_model.generate_content(req)
image_response = self._image_model.generate_content(req)
try:
return image_response.text
except ValueError as e:
Expand Down Expand Up @@ -283,17 +291,15 @@ def run_llm_code(
error_count = 0

while True:
answer = llm.ask(prompt, temperature=old_context.temperature)
answer = llm.chat(prompt, temperature=old_context.temperature)
print(f'\nANSWER:\n{answer}\n')

new_context = CodeContext()
new_context.code = extract_code(answer)
if new_context.code == old_context.code:
prompt = (
f"""You are an expert in reasoning about code errors and fixing them. I
asked the following question: ***{question}***\n\nYou generated
this code again:\n{new_context.code}\n\nThis code produced an error:
{old_context.error}\n\nPlease try something different."""
"""This is the same code you suggested before, which still
generates the same error. Please try something different."""
)
new_context.temperature = 1
old_context = new_context
Expand All @@ -315,9 +321,7 @@ def run_llm_code(
new_context.error = str(e)
print(f'ERROR:\n{new_context.error}')
prompt = (
f"""You are an expert in reasoning about code errors and fixing them. I
asked the following question ***{prompt}*** and ran this code: {new_context.code}
which produced an error ***{new_context.error}***"""
f"""This code produced an error, please fix it. ***{new_context.error}***"""
)

if old_context.error == new_context.error:
Expand All @@ -344,8 +348,8 @@ def run_llm_code(


def run_agent(
text_llm: LLM, image_llm: LLM, topic: str, question: str,
recommendation: str = '') -> None:
text_llm: LLM, image_llm: LLM, question: str, recommendation: str = ''
) -> None:
"""Outer loop running the agent until a high score is reached."""
codes: list[str] = []
evals: list[str] = []
Expand All @@ -360,7 +364,7 @@ def run_agent(
print(f'\nIMAGE ANALYSIS:\n{analysis}')

eval_question = (
f"""As a response to a user question to show {topic} an Earth Engine map
f"""An Earth Engine map
tile was produced containing the following: ***{analysis}***. Start
your answer with a number between 0 and 1 indicating how relevant
the image is as an answer to the user question. Discuss whether it
Expand All @@ -387,7 +391,7 @@ def run_agent(
shows Bay Area Cons: * The color palette does not indicate NDVI even
though an NDVI image of the Bay Area was requested. """
)
evaluation = text_llm.ask(eval_question)
evaluation = text_llm.chat(eval_question)
print(f'\nEVALUATION:\n{evaluation}')

score_match = re.search(
Expand Down Expand Up @@ -415,14 +419,13 @@ def run_agent(
importance. Limit output to five suggestions.\n\nThis is tile
fretrieval round {round_num}. The higher the round, the more you
should consider trying a different approach or a different
geometry than earlier.\n\nThe history is below in the
format:\nCODE 1\n...\nEVAL 1\n...\nCODE 2\n...\nEVAL 2\n...\n\n"""
geometry than earlier."""
)
for i, (old_code, old_eval) in enumerate(zip(codes, evals)):
recommendation_question += (
f'CODE {i+1}\n{old_code}\n\nEVAL {i+1}\n{old_eval}\n\n'
)
recommendation = text_llm.ask(recommendation_question)
recommendation = text_llm.chat(recommendation_question)
print(f'\nRECOMMENDATIONS:\n{recommendation}')
else:
print(url)
Expand Down Expand Up @@ -463,7 +466,7 @@ def main(argv: list[str]) -> None:
'region': geometry, 'dimensions': 512})."""
)

run_agent(text_llm, image_llm, topic, question)
run_agent(text_llm, image_llm, question)


if __name__ == '__main__':
Expand Down
50 changes: 26 additions & 24 deletions experimental/ee_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_image(self, mock_get):
def test_get_tile_url_success(self, mock_exec):
mock_exec.return_value = 'https://earthengine.googleapis.com/tile_url'
mock_llm = mock.MagicMock()
mock_llm.ask.return_value = (
mock_llm.chat.return_value = (
'```python\nprint("https://earthengine.googleapis.com/tile_url")\n```'
)
url, code = ee_agent.get_tile_url_and_code(
Expand All @@ -52,7 +52,7 @@ def test_get_tile_url_success(self, mock_exec):
def test_get_tile_url_bad_url(self, mock_exec):
mock_exec.side_effect = ee_agent.AgentError('BAD URL')
mock_llm = mock.MagicMock()
mock_llm.ask.return_value = '```python\nprint("invalid_url")\n```'
mock_llm.chat.return_value = '```python\nprint("invalid_url")\n```'

with self.assertRaisesRegex(ee_agent.AgentError, 'BAD URL'):
ee_agent.get_tile_url_and_code(mock_llm, 'question', 'recommendation')
Expand All @@ -63,7 +63,7 @@ def test_get_tile_url_bad_url(self, mock_exec):
def test_get_tile_url_code_error(self, mock_exec):
del mock_exec # unused
mock_llm = mock.MagicMock()
mock_llm.ask.side_effect = [
mock_llm.chat.side_effect = [
'```python\nprint("url1")\n```',
'```python\nprint("url2")\n```',
'```python\nprint("https://earthengine.googleapis.com/tile_url")\n```',
Expand All @@ -77,7 +77,7 @@ def test_get_tile_url_repeated_code_error(self):
question = 'What is the Earth Engine URL for a Landsat 8 image?'
recommendation = None

mock_llm.ask.side_effect = [
mock_llm.chat.side_effect = [
'Here is the code to get the URL:\n```python\nraise ValueError("Test error")\n```',
'Here is the code to get the URL:\n```python\nraise ValueError("Test error")\n```',
'Here is the revised code:\n```python\nprint("https://earthengine.googleapis.com/api/thumb?thumbid=abc123")\n```'
Expand All @@ -93,12 +93,12 @@ def test_get_tile_url_repeated_code_error(self):
expected_prompts = [
question,
'Revise the code and output a new version. Think about the broader\n context of the question and how well the code matches this\n context.\'\nThis is error number 1. The higher the error count,\n the more you should revise the code, possibly starting from\n scratch.',
'You are an expert in reasoning about code errors and fixing them. I\n asked the following question: ***What is the Earth Engine URL for a Landsat 8 image?***\n\nYou generated\n this code again:\nraise ValueError("Test error")\n\nThis code produced an error:\n Test error\n\nPlease try something different.'
'This is the same code you suggested before, which still\n generates the same error. Please try something different.'
]
expected_temps = [0, 0.5, 1]

for call, expected_prompt, expected_temp in zip(
mock_llm.ask.call_args_list, expected_prompts, expected_temps):
mock_llm.chat.call_args_list, expected_prompts, expected_temps):
self.assertEqual(expected_prompt, call[0][0])
self.assertEqual(expected_temp, call[1]['temperature'])

Expand All @@ -107,7 +107,7 @@ def test_get_tile_url_repeated_error(self):
question = 'What is the Earth Engine URL for a Landsat 8 image?'
recommendation = None

mock_llm.ask.side_effect = [
mock_llm.chat.side_effect = [
'Here is the code to get the URL:\n```python\nraise ValueError("Test error 1")\n```',
'Here is the revised code:\n```python\nraise ValueError("Test error 2")\n```',
'Here is the revised code:\n```python\nprint("https://earthengine.googleapis.com/api/thumb?thumbid=abc123")\n```'
Expand All @@ -128,7 +128,7 @@ def test_get_tile_url_repeated_error(self):
expected_temps = [0, 0.5, 0.5]

for call, expected_prompt, expected_temp in zip(
mock_llm.ask.call_args_list, expected_prompts, expected_temps):
mock_llm.chat.call_args_list, expected_prompts, expected_temps):
self.assertEqual(expected_prompt, call[0][0])
self.assertEqual(expected_temp, call[1]['temperature'])

Expand All @@ -145,7 +145,7 @@ def test_analyze_image(self, mock_get):
'Image analysis result'
)
gemini = ee_agent.Gemini()
gemini.image_model = mock_image_model
gemini._image_model = mock_image_model

analysis = gemini.analyze_image('https://example.com/image.jpg')
self.assertEqual(analysis, 'Image analysis result')
Expand All @@ -163,33 +163,35 @@ def test_analyze_image_uniform_color(self, mock_get):
'Image analysis result'
)
gemini = ee_agent.Gemini()
gemini.image_model = mock_image_model
gemini._image_model = mock_image_model

analysis = gemini.analyze_image('https://example.com/image.jpg')
self.assertEqual(
analysis, 'The image tile has a single uniform color with value 0.')

@mock.patch('google.generativeai.GenerativeModel')
def test_ask(self, mock_generative_model):
mock_text_model = mock.MagicMock()
mock_text_model.generate_content.return_value.text = 'Generated answer'
mock_generative_model.return_value = mock_text_model
def test_chat(self, mock_generative_model):
mock_chat_proxy = mock.MagicMock()
mock_chat_proxy.send_message.return_value.text = 'Generated answer'

mock_model = mock.MagicMock()
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat_proxy

gemini = ee_agent.Gemini()
answer = gemini.ask('question')
answer = gemini.chat('question')
self.assertEqual(answer, 'Generated answer')

@mock.patch.object(ee_agent.Gemini, 'analyze_image')
@mock.patch(
'ee_agent.get_tile_url_and_code',
return_value=('https://earthengine.googleapis.com/tile_url', 'code'),
)
@mock.patch('ee_agent.Gemini.ask')
@mock.patch('ee_agent.Gemini.chat')
def test_run_agent_success(
self, mock_ask, mock_get_tile_url, mock_analyze_image
):
self, mock_chat, mock_get_tile_url, mock_analyze_image):
mock_analyze_image.return_value = 'Image analysis result'
mock_ask.return_value = '0.95 Evaluation result'
mock_chat.return_value = '0.95 Evaluation result'

ee_agent.run_agent(
ee_agent.Gemini(), ee_agent.Gemini(), 'topic', 'question'
Expand All @@ -202,15 +204,15 @@ def test_run_agent_success(
'ee_agent.get_tile_url_and_code',
return_value=('code', 'https://earthengine.googleapis.com/tile_url'),
)
@mock.patch('ee_agent.Gemini.ask')
@mock.patch('ee_agent.Gemini.chat')
def test_run_agent_multiple_attempts(
self, mock_ask, mock_get_tile_url, mock_analyze_image
):
self, mock_chat, mock_get_tile_url, mock_analyze_image):
mock_analyze_image.return_value = 'Image analysis result'
mock_ask.side_effect = [

mock_chat.side_effect = [
'0.5 Evaluation result',
'recommendation',
'0.95 Evaluation result',
'0.95 Evaluation result'
]

ee_agent.run_agent(
Expand Down

0 comments on commit 74c4c4b

Please sign in to comment.