-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGemini.py
47 lines (40 loc) · 1.46 KB
/
Gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import os
from datetime import datetime

import google.generativeai as genai
import torch
from google.ai.generativelanguage import Part, Content
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from LLM import LLM
class Gemini(LLM):
    """``LLM`` backend that answers prompts with Google's ``gemini-pro`` model.

    Call ``load_model`` first: it configures ``google.generativeai`` and stores
    the model handle on ``self.model``, which ``generate`` then uses.
    """

    # Environment variable read by ``load_model`` when no key is passed in.
    _API_KEY_ENV_VAR = "GOOGLE_API_KEY"

    # Instruction placed as the first chat-history turn of every request.
    _SYSTEM_PROMPT = "You are an AI assistant that answers Place related MCQ questions."

    def load_model(self, api_key=None):
        """Configure the Gemini SDK and create the ``gemini-pro`` model.

        Generation is deterministic-leaning (temperature 0, top_k 1), capped at
        500 output tokens, and all four safety categories are set to
        ``BLOCK_NONE``.

        Args:
            api_key: Google AI API key. When omitted (or empty), the key is
                read from the ``GOOGLE_API_KEY`` environment variable.

        Raises:
            RuntimeError: if no API key is passed and the environment variable
                is unset or empty.
        """
        # Backend identifier; its meaning is defined outside this file — TODO confirm.
        self.id = 7

        # Never hard-code the key: a key committed to source is effectively
        # public and must be revoked. Take it from the caller or the environment.
        key = api_key or os.environ.get(self._API_KEY_ENV_VAR)
        if not key:
            raise RuntimeError(
                "No Gemini API key provided: pass api_key=... to load_model() "
                f"or set the {self._API_KEY_ENV_VAR} environment variable."
            )
        genai.configure(api_key=key)

        generation_config = {
            "temperature": 0,
            "top_p": 1,
            "top_k": 1,
            "max_output_tokens": 500,
        }
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }
        self.model = genai.GenerativeModel(
            'gemini-pro',
            generation_config=generation_config,
            safety_settings=safety_settings,
        )

    def generate(self, prompt: str) -> str:
        """Send *prompt* to Gemini in a fresh chat and return the reply text.

        Each call starts a new chat session, so calls do not share history. The
        session is seeded with one turn carrying the assistant instruction. The
        reply text is also printed to stdout.

        Args:
            prompt: The user message to send.

        Returns:
            The text of the model's reply.

        Note:
            Requires ``load_model`` to have been called (it sets ``self.model``).
            NOTE(review): the SDK's ``response.text`` presumably raises when the
            reply was blocked or has no candidates — verify and decide whether
            callers should handle that.
        """
        # NOTE(review): the instruction is sent as a "model"-role turn rather
        # than a user turn or system instruction; confirm the API accepts a
        # history that begins with the model role.
        messages = [
            Content(
                parts=[Part(text=self._SYSTEM_PROMPT)],
                role="model",
            )
        ]
        chat = self.model.start_chat(history=messages)
        response = chat.send_message(prompt)
        text = response.text  # read once; the property is computed from candidates
        print(text)
        return text