Skip to content

Commit

Permalink
Merge pull request #2458 from hlohaus/neww
Browse files Browse the repository at this point in the history
Fix get_models in Airforce provider
  • Loading branch information
hlohaus authored Dec 5, 2024
2 parents 636807d + c262f94 commit 9fa15c9
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re
import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from urllib.parse import quote
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
Expand Down Expand Up @@ -69,22 +68,24 @@ def fetch_completions_models(cls):
@classmethod
def fetch_imagine_models(cls):
response = requests.get(
'https://api.airforce/imagine/models',
'https://api.airforce/v1/imagine2/models',
verify=False
)
response.raise_for_status()
return response.json()

image_models = fetch_imagine_models.__get__(None, object)() + additional_models_imagine

@classmethod
def is_image_model(cls, model: str) -> bool:
return model in cls.image_models

models = list(dict.fromkeys([default_model] +
fetch_completions_models.__get__(None, object)() +
image_models))
@classmethod
def get_models(cls):
if not cls.models:
cls.image_models = cls.fetch_imagine_models() + cls.additional_models_imagine
cls.models = list(dict.fromkeys([cls.default_model] +
cls.fetch_completions_models() +
cls.image_models))
return cls.models

@classmethod
async def check_api_key(cls, api_key: str) -> bool:
Expand Down Expand Up @@ -133,7 +134,7 @@ async def generate_image(
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
if response.status == 200:
image_url = str(response.url)
yield ImageResponse(images=image_url, alt=f"Generated image: {prompt}")
yield ImageResponse(images=image_url, alt=prompt)
else:
error_text = await response.text()
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
Expand Down Expand Up @@ -215,12 +216,12 @@ async def create_async_generator(
if not await cls.check_api_key(api_key):
pass

model = cls.get_model(model)
if cls.is_image_model(model):
if prompt is None:
prompt = messages[-1]['content']
if seed is None:
seed = random.randint(0, 10000)

async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
yield result
else:
Expand Down

0 comments on commit 9fa15c9

Please sign in to comment.