Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patrick li/button #71

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TriggerDescribeIn,
SendMessageResponse,
SendMessageIn,
ActionTriggerIn,
)

router = APIRouter()
Expand Down Expand Up @@ -64,6 +65,14 @@ async def describe(body: TriggerDescribeIn):
return {"trigger_id": trigger_id, "trigger_type": trigger_type}


@router.post("/action", response_model=TriggerResponse)
async def action(body: ActionTriggerIn):
trigger_id = body.trigger_id
trigger_type = TriggerType.action.value

taskqueue.put(trigger_id, discord.trigger_action, **body.dict())
return {"trigger_id": trigger_id, "trigger_type": trigger_type}

@router.post("/upload", response_model=UploadResponse)
async def upload_attachment(file: UploadFile):
if not file.content_type.startswith("image/"):
Expand Down
4 changes: 4 additions & 0 deletions app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class TriggerResponse(BaseModel):
trigger_id: str
trigger_type: str = ""

class ActionTriggerIn(BaseModel):
trigger_id: str = ""
custom_id: str = ""
msg_id: str = ""

class UploadResponse(BaseModel):
message: str = "success"
Expand Down
12 changes: 12 additions & 0 deletions lib/api/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TriggerType(str, Enum):
max_upscale = "max_upscale"
reset = "reset"
describe = "describe"
action = "action"


async def trigger(payload: Dict[str, Any]):
Expand Down Expand Up @@ -167,6 +168,17 @@ async def reset(msg_id: str, msg_hash: str, **kwargs):
}, **kwargs)
return await trigger(payload)

async def trigger_action(msg_id: str, custom_id: str, **kwargs):
kwargs = {
"message_flags": 0,
"message_id": msg_id,
}
payload = _trigger_payload(3, {
"component_type": 2,
"custom_id": custom_id
}, **kwargs)
return await trigger(payload)


async def describe(upload_filename: str, **kwargs):
payload = _trigger_payload(2, {
Expand Down
7 changes: 6 additions & 1 deletion task/bot/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ class Embed(TypedDict):
image: EmbedsImage


class Action(TypedDict):
label: str
custom_id: str


class CallbackData(TypedDict):
type: str
id: int
content: str
attachments: List[Attachment]
embeds: List[Embed]

actions: List[Action]
trigger_id: str
16 changes: 14 additions & 2 deletions task/bot/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import re
from typing import Dict, Union, Any

from discord import Message
from discord import Message, components, ActionRow

from app.handler import PROMPT_PREFIX, PROMPT_SUFFIX
from lib.api.callback import queue_release, callback
from task.bot._typing import CallbackData, Attachment, Embed
from task.bot._typing import CallbackData, Attachment, Embed, Action

TRIGGER_ID_PATTERN = f"{PROMPT_PREFIX}(\w+?){PROMPT_SUFFIX}" # 消息 ID 正则

Expand All @@ -33,6 +33,16 @@ def match_trigger_id(content: str) -> Union[str, None]:
match = re.findall(TRIGGER_ID_PATTERN, content)
return match[0] if match else None

def get_action(message: Message) -> list[Action]:
res = []
for component in message.components:
if isinstance(component, ActionRow):
for action in component.children:
if isinstance(action, components.Button) and action.label and action.custom_id:
res.append(Action(label = action.label, custom_id = action.custom_id))
return res



async def callback_trigger(trigger_id: str, trigger_status: str, message: Message):
await callback(CallbackData(
Expand All @@ -45,6 +55,7 @@ async def callback_trigger(trigger_id: str, trigger_status: str, message: Messag
],
embeds=[],
trigger_id=trigger_id,
actions=get_action(message)
))


Expand All @@ -61,5 +72,6 @@ async def callback_describe(trigger_status: str, message: Message, embed: Dict[s
Embed(**embed)
],
trigger_id=trigger_id,
actions=get_action(message)
))
return trigger_id
1 change: 1 addition & 0 deletions util/_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def put(
self._wait_queue.append({
_trigger_id: Task(func, *args, **kwargs)
})
logger.debug(f"Task[{_trigger_id}] added to queue. Queue size: {len(self._wait_queue)}")
while self._wait_queue and len(self._concur_queue) < self._concur_size:
self._exec()

Expand Down