checkpoint_summary_filter.py · 818 lines (669 loc) · 27.4 KB
"""
title: Checkpoint Summary Filter
author: projectmoon
author_url: https://git.agnos.is/projectmoon/open-webui-filters
version: 0.2.2
license: AGPL-3.0+
required_open_webui_version: 0.3.32
"""
# Documentation: https://git.agnos.is/projectmoon/open-webui-filters
# System imports
import asyncio
import hashlib
import uuid
import json
import re
import logging
from typing import Optional, List, Dict, Callable, Any, NewType, Tuple, Awaitable, ClassVar
from typing_extensions import TypedDict, NotRequired
from collections import deque
# Libraries available to OpenWebUI
from pydantic import BaseModel as PydanticBaseModel, Field
import chromadb
from chromadb import Collection as ChromaCollection
from chromadb.api.types import Document as ChromaDocument
# OpenWebUI imports
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.main import app as rag_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import show_model_info, ModelNameForm
from open_webui.utils.misc import get_last_user_message, get_last_assistant_message
from open_webui.main import generate_chat_completions
from open_webui.apps.webui.models.chats import Chats
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import Users
# Why refactor when you can janky monkey patch? This will be fixed at
# some point.
CHROMA_CLIENT = VECTOR_DB_CLIENT.client
# Embedding (not yet used)
EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
EmbeddingFunc = NewType('EmbeddingFunc', Callable[[str], List[Any]])
# Prompts
SUMMARIZER_PROMPT = """
You are a chat conversation summarizer. Your task is to summarize the given
portion of an ongoing conversation. First, determine if the conversation is
a regular chat between the user and the assistant, or if the conversation is
part of a story or role-playing session.
Summarize the important parts of the given chat between the user and the
assistant. Limit your summary to one paragraph. Make sure your summary is
detailed. Write the summary as if you are summarizing part of a larger
conversation. Do not refer to "you" or "me" in the summary. Write in the
third person perspective.
If the conversation is a regular chat, write your summary referring to the
ongoing conversation as a chat. If the conversation is a regular chat, refer
to the user and the assistant as user and assistant. If the conversation is
a regular chat, do not refer to yourself as the assistant. Do not make up a
name for the user. If the conversation is a regular chat, summarize all
important parts of the chat.
If the conversation is a story or role-playing session, write your summary
referring to the conversation as an ongoing story. If the conversation is a
story or roleplaying session, do not refer to the user or assistant in your
summary. If the conversation is a story or roleplaying session, only use the
names of the characters, places, and events in the story.
""".replace("\n", " ").strip()
# yoinked from stack overflow. hack to get user into
# generate_chat_completions. this is used to turn the __user__ dict
# given to the filter into a thing that the main OpenWebUI code can
# understand for calling its chat completion endpoint internally.
class BlackMagicDictionary(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
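# Illustrative usage (hypothetical values): attribute access mirrors key
# access, and missing keys fall back to dict.get's None instead of raising.
#   user = BlackMagicDictionary({"id": "abc123", "role": "user"})
#   user.id       # -> "abc123"
#   user.missing  # -> None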
class Message(TypedDict):
id: NotRequired[str]
role: str
content: str
class MessageInsertMetadata(TypedDict):
role: str
chapter: str
class MessageInsert(TypedDict):
message_id: str
content: str
metadata: MessageInsertMetadata
embeddings: List[Any]
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class SummarizerResponse(BaseModel):
summary: str
class Summarizer(BaseModel):
messages: List[dict]
model: str
user: Any
prompt: str = SUMMARIZER_PROMPT
async def summarize(self) -> Optional[SummarizerResponse]:
sys_message: Message = { "role": "system", "content": SUMMARIZER_PROMPT }
user_message: Message = {
"role": "user",
"content": "Make a detailed summary of the conversation up to this point."
}
messages = [sys_message] + self.messages + [user_message]
request = {
"model": self.model,
"messages": messages,
"stream": False,
"keep_alive": "10s"
}
resp = await generate_chat_completions(request, user=self.user)
if "choices" in resp and len(resp["choices"]) > 0:
content: str = resp["choices"][0]["message"]["content"]
return SummarizerResponse(summary=content)
else:
return None
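# Rough sketch of the payload summarize() hands to generate_chat_completions
# (the model id is illustrative, not something this filter defines):
#   {"model": "llama3:8b",
#    "messages": [system_prompt] + history + [summary_instruction],
#    "stream": False, "keep_alive": "10s"}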
class Checkpoint(BaseModel):
# chat id
chat_id: str
# the message ID this checkpoint was created from.
message_id: str
# index of the message in the message input array. in the inlet
# function, we do not have access to incoming message ids for some
# reason. used as a fallback to drop old context when message IDs are not available.
message_index: int = 0
# the "slug", or chain of messages, that led to this point.
slug: str
# actual summary of messages.
summary: str
# if we try to put a type hint on this, it gets mad.
@staticmethod
def from_json(obj: dict):
try:
return Checkpoint(
chat_id=obj["chat_id"],
message_id=obj["message_id"],
message_index=obj["message_index"],
slug=obj["slug"],
summary=obj["summary"]
)
except:
return None
def to_json(self) -> str:
return self.model_dump_json()
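# A serialized checkpoint document looks roughly like (values illustrative):
#   {"chat_id": "...", "message_id": "...", "message_index": 42,
#    "slug": "<sha256 of the message-id chain>", "summary": "..."}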
class Checkpointer(BaseModel):
"""Manages summary checkpoints in a single chat."""
chat_id: str
summarizer_model: str = ""
chroma_client: chromadb.ClientAPI
messages: List[dict]=[] # stripped set of messages
full_messages: List[dict]=[] # all the messages
embedding_func: EmbeddingFunc=(lambda a: 0)
user: Optional[Any] = None
collection_name: ClassVar[str] = "chat_checkpoints"
def _get_collection(self) -> ChromaCollection:
return self.chroma_client.get_or_create_collection(
name=Checkpointer.collection_name
)
def _insert_checkpoint(self, checkpoint: Checkpoint):
coll = self._get_collection()
checkpoint_doc = checkpoint.to_json()
# Insert the checkpoint itself with slug as ID.
coll.upsert(
ids=[checkpoint.slug],
documents=[checkpoint_doc],
metadatas=[{ "chat_id": self.chat_id, "type": "checkpoint" }],
embeddings=[self.embedding_func(checkpoint_doc)]
)
# Update the chat info doc for this chat.
coll.upsert(
ids=[self.chat_id],
documents=[json.dumps({ "current_checkpoint": checkpoint.slug })],
embeddings=[self.embedding_func(self.chat_id)]
)
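# Storage layout sketch: the chat_checkpoints collection holds one document
# per checkpoint (id = slug, document = checkpoint JSON) plus one pointer
# document per chat (id = chat_id, document = {"current_checkpoint": <slug>}).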
def _calculate_slug(self) -> Optional[str]:
if len(self.messages) == 0:
return None
message_ids = [msg["id"] for msg in reversed(self.messages)]
slug = "|".join(message_ids)
return hashlib.sha256(slug.encode()).hexdigest()
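# e.g. messages with ids ["a", "b", "c"] (oldest first) produce
# sha256("c|b|a") -- a stable fingerprint of the current message chain.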
def _get_state(self):
resp = self._get_collection().get(ids=[self.chat_id], include=["documents"])
state: dict = (json.loads(resp["documents"][0])
if resp["documents"] and len(resp["documents"]) > 0
else { "current_checkpoint": None })
return state
def _find_message_index(self, message_id: str) -> Optional[int]:
for idx, message in enumerate(self.full_messages):
if message["id"] == message_id:
return idx
return None
def nuke_checkpoints(self):
"""Delete all checkpoints for this chat."""
coll = self._get_collection()
checkpoints = coll.get(
include=["documents"],
where={"chat_id": self.chat_id}
)
self._get_collection().delete(
ids=[self.chat_id] + checkpoints["ids"]
)
async def create_checkpoint(self) -> str:
summarizer = Summarizer(model=self.summarizer_model, messages=self.messages, user=self.user)
resp = await summarizer.summarize()
if resp:
slug = self._calculate_slug()
checkpoint_message = self.messages[-1]
checkpoint_index = self._find_message_index(checkpoint_message["id"])
checkpoint = Checkpoint(
chat_id = self.chat_id,
slug = self._calculate_slug(),
message_id = checkpoint_message["id"],
message_index = checkpoint_index,
summary = resp.summary
)
self._insert_checkpoint(checkpoint)
return slug
def get_checkpoint(self, slug: Optional[str]) -> Optional[Checkpoint]:
if not slug:
return None
resp = self._get_collection().get(ids=[slug], include=["documents"])
checkpoint = (resp["documents"][0]
if resp["documents"] and len(resp["documents"]) > 0
else None)
if checkpoint:
return Checkpoint.from_json(json.loads(checkpoint))
else:
return None
def get_current_checkpoint(self) -> Optional[Checkpoint]:
state = self._get_state()
return self.get_checkpoint(state["current_checkpoint"])
#########################
# Utilities
#########################
class SessionInfo(BaseModel):
chat_id: str
message_id: str
session_id: str
def extract_session_info(event_emitter) -> Optional[SessionInfo]:
"""The latest innovation in hacky workarounds."""
try:
info = event_emitter.__closure__[0].cell_contents
return SessionInfo(
chat_id=info["chat_id"],
message_id=info["message_id"],
session_id=info["session_id"]
)
except:
return None
def predicted_token_use(messages) -> int:
"""Parse most recent message to calculate estimated token use."""
if len(messages) == 0:
return 0
# Naive assumptions:
# - 1 word = 1 token.
# - 1 period, comma, or colon = 1 token
# messages are assumed to be OpenWebUI message dicts here, so count
# tokens in the content string.
message = messages[-1].get("content", "")
return len(list(filter(None, re.split(r"\s|(;)|(,)|(\.)|(:)|\n", message))))
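# Illustrative count under the naive assumptions above: "Summaries are
# cheap, tokens are not." splits into 6 words plus 2 punctuation marks,
# giving an estimate of 8 tokens.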
def is_big_convo(messages, num_ctx: int=8192) -> bool:
"""
Attempt to detect large pre-existing conversation by looking at
recent eval counts from messages and comparing against given
num_ctx. We check all messages for an eval count that goes above
the context limit. It doesn't matter where in the message list; if
it's somewhere in the middle, it means that there was a context
shift.
"""
for message in messages:
if "info" in message:
eval_count = (message["info"]["eval_count"]
if "eval_count" in message["info"]
else 0)
prompt_eval_count = (message["info"]["prompt_eval_count"]
if "prompt_eval_count" in message["info"]
else 0)
tokens_used = eval_count + prompt_eval_count
else:
tokens_used = 0
if tokens_used >= num_ctx:
return True
return False
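# Example (illustrative): if any message carries
# info = {"eval_count": 900, "prompt_eval_count": 7800}, the 8700 combined
# tokens exceed the default num_ctx of 8192 and is_big_convo returns True,
# regardless of where in the list that message sits.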
def hit_context_limit(
messages,
num_ctx: int=8192,
wiggle_room: int=1000
) -> Tuple[bool, int]:
"""
Determine if we've hit the context limit, within some reasonable
estimation. We have a defined 'wiggle room' that is subtracted
from the num_ctx parameter, in order to capture near-filled
contexts. We do it this way because we're summarizing on output,
rather than before input (inlet function doesn't have enough
info).
"""
if len(messages) == 0:
return False, 0
last_message = messages[-1]
tokens_used = 0
if "info" in last_message:
eval_count = (last_message["info"]["eval_count"]
if "eval_count" in last_message["info"] else 0)
prompt_eval_count = (last_message["info"]["prompt_eval_count"]
if "prompt_eval_count" in last_message["info"] else 0)
tokens_used = eval_count + prompt_eval_count
if tokens_used >= (num_ctx - wiggle_room):
amount_over = tokens_used - num_ctx
amount_over = 0 if amount_over < 0 else amount_over
return True, amount_over
else:
return False, 0
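# Illustrative behavior with num_ctx=8192 and wiggle_room=1000: a last
# message reporting 7500 combined tokens returns (True, 0) (within the
# wiggle room but not over the limit), 8500 tokens returns (True, 308),
# and 6000 tokens returns (False, 0).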
def extract_base_model_id(model: dict) -> Optional[str]:
if "base_model_id" not in model["info"]:
return None
base_model_id = model["info"]["base_model_id"]
if not base_model_id:
base_model_id = model["id"]
return base_model_id
def extract_owu_model_param(model_obj: dict, param_name: str):
"""
Extract a parameter value from the DB definition of a model
that is based on another model.
"""
if not "params" in model_obj["info"]:
return None
params = model_obj["info"]["params"]
return params.get(param_name, None)
def extract_owu_base_model_param(base_model_id: str, param_name: str):
"""Extract a parameter value from the DB definition of an ollama base model."""
base_model = Models.get_model_by_id(base_model_id)
if not base_model:
return None
base_model.params = base_model.params.model_dump()
return base_model.params.get(param_name, None)
def extract_ollama_response_param(model: dict, param_name: str):
"""Extract a parameter value from ollama show API response."""
if "parameters" not in model:
return None
for line in model["parameters"].splitlines():
if line.startswith(param_name):
# take everything after the parameter name; lstrip() strips a character
# set, not a prefix, so slice instead.
return line[len(param_name):].strip()
return None
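# The ollama show response exposes "parameters" as a plain text block, e.g.:
#   num_ctx                        8192
#   stop                           "<|user|>"
# so extract_ollama_response_param(model, "num_ctx") yields the string "8192".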
async def get_model_from_ollama(model_id: str, user_id) -> Optional[dict]:
"""Call ollama show API and return model information."""
curr_user = Users.get_user_by_id(user_id)
try:
return await show_model_info(ModelNameForm(name=model_id), user=curr_user)
except Exception as e:
print(f"Could not get model info: {e}")
return None
async def calculate_num_ctx(chat_id: str, user_id, model: dict) -> int:
"""
Attempt to discover the current num_ctx parameter in many
different ways.
"""
# first check the open-webui chat parameters.
chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id)
if chat:
# this might look odd, but the chat field is a json blob of
# useful info.
chat = json.loads(chat.chat)
if "params" in chat and "num_ctx" in chat["params"]:
if chat["params"]["num_ctx"] is not None:
return chat["params"]["num_ctx"]
# then check open web ui model def
num_ctx = extract_owu_model_param(model, "num_ctx")
if num_ctx:
return num_ctx
# then check open web ui base model def.
base_model_id = extract_base_model_id(model)
if not base_model_id:
# fall back to default in case of weirdness.
return 2048
num_ctx = extract_owu_base_model_param(base_model_id, "num_ctx")
if num_ctx:
return num_ctx
# THEN check ollama directly.
base_model = await get_model_from_ollama(base_model_id, user_id)
num_ctx = extract_ollama_response_param(base_model, "num_ctx")
if num_ctx:
# ollama reports parameter values as strings, so coerce before returning.
return int(num_ctx)
# finally, return default.
return 2048
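# Lookup order, as implemented above: chat-level params -> OpenWebUI model
# definition -> OpenWebUI base model definition -> ollama show API -> the
# hard-coded fallback of 2048.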
class Filter:
class Valves(BaseModel):
def summarizer_model(self, body):
if self.summarizer_model_id == "":
return extract_base_model_id(body["model"])
else:
return self.summarizer_model_id
summarize_large_contexts: bool = Field(
default=False,
description=(
f"Whether or not to use a large context model to summarize large "
f"pre-existing conversations."
)
)
wiggle_room: int = Field(
default=1000,
description=(
"Amount of token 'wiggle room' for estimating when a context shift occurs. "
"Subtracted from num_ctx when checking if summarization is needed."
)
)
summarizer_model_id: str = Field(
default="",
description="Model used to summarize the conversation. Must be a base model.",
)
large_summarizer_model_id: str = Field(
default="",
description=(
"Model used to summarize large pre-existing contexts. "
"Must be a base model with a context size large enough "
"to fit the conversation."
)
)
pass
class UserValves(BaseModel):
pass
def __init__(self):
self.valves = self.Valves()
pass
def load_current_chat(self) -> dict:
# the chat property of the model is the json blob that holds
# all the interesting stuff
chat = (Chats
.get_chat_by_id_and_user_id(self.session_info.chat_id, self.user["id"])
.chat)
return json.loads(chat)
def get_messages_for_checkpointing(self, messages, num_ctx, last_checkpointed_id):
"""
Assemble list of messages to checkpoint, based on current
state and valve settings.
"""
message_chain = deque()
for message in reversed(messages):
if message["id"] == last_checkpointed_id:
break
message_chain.appendleft(message)
message_chain = list(message_chain) # the lazy way
# now we check if we are a big conversation, and if valve
# settings allow that kind of summarization.
summarizer_model = self.valves.summarizer_model
if is_big_convo(messages, num_ctx) and not self.valves.summarize_large_contexts:
# must summarize using small model. for now, drop to last
# N messages.
print((
"Dropping all but last 4 messages to summarize "
"large convo without large model."
))
message_chain = message_chain[-4:]
return message_chain
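# Example (illustrative): if the last checkpoint was taken at message "m8",
# the chain is every message after "m8". For an oversized pre-existing
# conversation with summarize_large_contexts disabled, only the final 4
# messages of that chain are kept for summarization.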
async def create_checkpoint(
self,
messages: List[dict],
last_checkpointed_id: Optional[str]=None,
num_ctx: int=8192
):
if len(messages) == 0:
return
print(f"[{self.session_info.chat_id}] Detected context shift. Summarizing.")
await self.set_summarizing_status(done=False)
last_message = messages[-1] # should check for role = assistant
curr_message_id: Optional[str] = (
last_message["id"] if last_message else None
)
if not curr_message_id:
return
# strip messages down to what is in the current checkpoint.
message_chain = self.get_messages_for_checkpointing(
messages, num_ctx, last_checkpointed_id
)
# we should now have a list of messages that is just within
# the current context limit.
summarizer_model = self.valves.summarizer_model_id
if is_big_convo(message_chain, num_ctx) and self.valves.summarize_large_contexts:
print(f"[{self.session_info.chat_id}] Summarizing LARGE context!")
summarizer_model = self.valves.large_summarizer_model_id
checkpointer = Checkpointer(
chat_id=self.session_info.chat_id,
summarizer_model=summarizer_model,
chroma_client=CHROMA_CLIENT,
full_messages=messages,
messages=message_chain,
user=BlackMagicDictionary(self.user)
)
try:
slug = await checkpointer.create_checkpoint()
await self.set_summarizing_status(done=True)
print((f"[{self.session_info.chat_id}] Summarization checkpoint created: "
f"{slug}"))
except Exception as e:
print(f"[{self.session_info.chat_id}] Error creating summary: {str(e)}")
await self.set_summarizing_status(
done=True, message=f"Error summarizing: {str(e)}"
)
def update_chat_with_checkpoint(self, messages: List[dict], checkpoint: Checkpoint):
if len(messages) < checkpoint.message_index:
# do not mess with anything if the index doesn't even
# exist anymore. need a new checkpoint.
return messages
# proceed with altering the system prompt. keep system prompt,
# if it's there, and add summary to it. summary will become
# system prompt if there is no system prompt.
convo_messages = [
message for message in messages if message.get("role") != "system"
]
system_prompt = next(
(message for message in messages if message.get("role") == "system"), None
)
summary_message = f"Summary of conversation so far:\n\n{checkpoint.summary}"
if system_prompt:
system_prompt["content"] += f"\n\n{summary_message}"
else:
system_prompt = { "role": "system", "content": summary_message }
# drop old messages, reapply system prompt.
messages = self.apply_checkpoint(checkpoint, messages)
print(f"[{self.session_info.chat_id}] Applying summary:\n\n{checkpoint.summary}")
return [system_prompt] + messages
async def send_message(self, message: str):
await self.event_emitter({
"type": "status",
"data": {
"description": message,
"done": True,
},
})
async def set_summarizing_status(self, done: bool, message: Optional[str]=None):
if not self.event_emitter:
return
if not done:
description = (
"Summarizing conversation due to reaching context limit (do not reply yet)."
)
else:
description = (
"Summarization complete (you may now reply)."
)
if message:
description = message
await self.event_emitter({
"type": "status",
"data": {
"description": description,
"done": done,
},
})
def apply_checkpoint(
self, checkpoint: Checkpoint, messages: List[dict]
) -> List[dict]:
"""
Possibly shorten the message context based on a checkpoint.
This works two ways: if the messages have IDs (outlet
filter), split by message ID (very reliable). Otherwise,
attempt to split on the recorded message index (inlet
filter; not very reliable).
"""
# first attempt to drop everything before the checkpointed
# message id.
split_point = 0
for idx, message in enumerate(messages):
if "id" in message and message["id"] == checkpoint.message_id:
split_point = idx
break
# if we can't find the ID to split on, fall back to message
# index if possible. this can happen during message
# regeneration, for example. or if we're called from the inlet
# filter, which doesn't have access to message ids.
if split_point == 0 and checkpoint.message_index <= len(messages):
split_point = checkpoint.message_index
orig = len(messages)
messages = messages[split_point:]
print((f"[{self.session_info.chat_id}] Dropped context to {len(messages)} "
f"messages (from {orig})"))
return messages
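# Example (illustrative): with a checkpoint recorded at message id "m5"
# sitting at index 5 of a 12-message history, the history is cut to
# messages[5:], i.e. the checkpointed message and everything after it; the
# dropped portion is represented by the summary added to the system prompt.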
async def handle_nuke(self, body):
checkpointer = Checkpointer(
chat_id=self.session_info.chat_id,
chroma_client=CHROMA_CLIENT
)
checkpointer.nuke_checkpoints()
await self.send_message("Deleted all checkpoint for chat.")
body["messages"][-1]["content"] = (
"Respond ony with: 'Deleted all checkpoints for chat.'"
)
body["messages"] = body["messages"][-1:]
return body
async def outlet(
self,
body: dict,
__user__: Optional[dict],
__model__: Optional[dict],
__event_emitter__: Callable[[Any], Awaitable[None]],
) -> dict:
# Useful things to have around.
self.user = __user__
self.model = __model__
self.session_info = extract_session_info(__event_emitter__)
self.event_emitter = __event_emitter__
self.summarizer_model_id = self.valves.summarizer_model(body)
# global filters apply to requests coming in through proxied
# API. If we're not an OpenWebUI chat, abort mission.
if not self.session_info:
return body
if not self.model or self.model["owned_by"] != "ollama":
return body
messages = body["messages"]
num_ctx = await calculate_num_ctx(
chat_id=self.session_info.chat_id,
user_id=self.user["id"],
model=self.model
)
# apply current checkpoint ONLY for purposes of calculating if
# we have hit num_ctx within current checkpoint.
checkpointer = Checkpointer(
chat_id=self.session_info.chat_id,
chroma_client=CHROMA_CLIENT
)
checkpoint = checkpointer.get_current_checkpoint()
messages_for_ctx_check = (self.apply_checkpoint(checkpoint, messages)
if checkpoint else messages)
hit_limit, amount_over = hit_context_limit(
messages=messages_for_ctx_check,
num_ctx=num_ctx,
wiggle_room=self.valves.wiggle_room
)
if hit_limit:
# we need the FULL message list to do proper summarizing,
# because we might be summarizing a huge context.
await self.create_checkpoint(
messages=messages,
num_ctx=num_ctx,
last_checkpointed_id=checkpoint.message_id if checkpoint else None
)
return body
async def inlet(
self,
body: dict,
__user__: Optional[dict],
__model__: Optional[dict],
__event_emitter__: Callable[[Any], Awaitable[None]]
) -> dict:
# Useful properties to have around.
self.user = __user__
self.model = __model__
self.session_info = extract_session_info(__event_emitter__)
self.event_emitter = __event_emitter__
self.summarizer_model_id = self.valves.summarizer_model(body)
# global filters apply to requests coming in through proxied
# API. If we're not an OpenWebUI chat, abort mission.
if not self.session_info:
return body
if not self.model or self.model["owned_by"] != "ollama":
return body
# super basic external command handling (delete checkpoints).
user_msg = get_last_user_message(body["messages"])
if user_msg and user_msg == "!nuke":
return await self.handle_nuke(body)
# apply current checkpoint to the chat: adds most recent
# summary to system prompt, and drops all messages before the
# checkpoint.
checkpointer = Checkpointer(
chat_id=self.session_info.chat_id,
chroma_client=CHROMA_CLIENT
)
checkpoint = checkpointer.get_current_checkpoint()
if checkpoint:
print((
f"Using checkpoint {checkpoint.slug} for "
f"conversation {self.session_info.chat_id}"
))
body["messages"] = self.update_chat_with_checkpoint(body["messages"], checkpoint)
return body