updated agent studio
dusty-nv committed Jun 7, 2024
1 parent ac9e4e7 commit b0be327
Showing 8 changed files with 132 additions and 58 deletions.
6 changes: 4 additions & 2 deletions nano_llm/plugin.py
@@ -30,14 +30,15 @@ class Plugin(threading.Thread):
drop_inputs (bool): if true, only the most recent input in the queue will be used
threaded (bool): if true, will spawn independent thread for processing the queue.
"""
def __init__(self, name=None, inputs=1, outputs=1, relay=False,
drop_inputs=False, threaded=True, **kwargs):
def __init__(self, name=None, title=None, inputs=1, outputs=1,
relay=False, drop_inputs=False, threaded=True, **kwargs):
"""
Initialize plugin
"""
super().__init__(daemon=True)

self.name = name if name else self.__class__.__name__
self.title = title
self.relay = relay
self.drop_inputs = drop_inputs
self.threaded = threaded
@@ -437,6 +438,7 @@ def state_dict(self, config=False, **kwargs):
self.reorder_parameters()

state.update({
'title': self.title if self.title else self.name,
'inputs': self.input_names,
'outputs': self.output_names,
'connections': connections,
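The new `title` field gives a plugin a human-readable display label that falls back to its `name` when unset, as the `state_dict()` change above shows. A minimal stand-alone sketch of that fallback, using a dummy class rather than the real `Plugin` base (which also manages threads and queues):

```python
# Stand-alone illustration of the title fallback added to Plugin.state_dict();
# the real base class in nano_llm/plugin.py also handles threading and I/O queues.
class TitledPlugin:
    def __init__(self, name=None, title=None):
        self.name = name if name else self.__class__.__name__
        self.title = title

    def state_dict(self):
        # 'title' falls back to 'name' when no explicit title was given
        return {'title': self.title if self.title else self.name}

print(TitledPlugin(name='audio_input').state_dict())                  # {'title': 'audio_input'}
print(TitledPlugin(name='audio_input', title='Mic In').state_dict())  # {'title': 'Mic In'}
```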
6 changes: 4 additions & 2 deletions nano_llm/plugins/audio/audio_input.py
@@ -9,7 +9,7 @@
import torchaudio

from nano_llm import Plugin
from nano_llm.utils import convert_audio, resample_audio, find_audio_device, pyaudio_dtype
from nano_llm.utils import convert_audio, resample_audio, audio_db, find_audio_device, pyaudio_dtype


class AudioInputDevice(Plugin):
@@ -29,7 +29,7 @@ def __init__(self, audio_input_device: int = None, audio_input_channels: int = 1
or None to use the device's default sampling rate.
audio_chunk (float): The duration of time or number of audio samples captured per batch.
"""
super().__init__(input_channels=0, **kwargs)
super().__init__(inputs=0, outputs='audio', **kwargs)

self.pa = pyaudio.PyAudio()

@@ -115,6 +115,8 @@ def capture(self):
logging.warning(f"resampled input audio from device {self.device_id} has {len(samples)} samples, but expected {expected_samples} samples")
self._resample_warning = True

db = audio_db(samples)
self.send_stats(audio_db=db, summary=f"{db:.1f}dB")
#logging.debug(f"captured {len(samples)} audio samples from audio device {self.device_id} (dtype={samples.dtype})")
return samples

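The capture path now reports a signal level through `send_stats(audio_db=...)`. The `audio_db()` helper itself is not part of this diff; a common way to compute such a reading is the RMS level in decibels, sketched below as an assumption rather than the actual nano_llm implementation:

```python
import numpy as np

def audio_db_sketch(samples: np.ndarray, eps: float = 1e-9) -> float:
    """Approximate RMS level in dB for int16 or float PCM samples (illustrative only)."""
    if np.issubdtype(samples.dtype, np.integer):
        samples = samples.astype(np.float32) / 32768.0  # scale int16 PCM to [-1, 1]
    rms = np.sqrt(np.mean(np.square(samples.astype(np.float32))))
    return 20.0 * np.log10(rms + eps)  # 0 dB ~ full scale, quieter audio is negative

chunk = (np.random.randn(1600) * 0.05).astype(np.float32)  # ~100 ms of quiet noise at 16 kHz
print(f"{audio_db_sketch(chunk):.1f} dB")  # prints a level around -26 dB
```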
4 changes: 2 additions & 2 deletions nano_llm/plugins/audio/audio_output.py
@@ -23,7 +23,7 @@ def __init__(self, audio_output_device: int = None, audio_output_channels: int =
audio_output_channels (int): 1 for mono, 2 for stereo.
sample_rate_hz (int): Sample rate of any outgoing audio (typically 16000, 44100, 48000)
"""
super().__init__(output_channels=0, **kwargs)
super().__init__(outputs=0, **kwargs)

self.pa = pyaudio.PyAudio()

@@ -134,7 +134,7 @@ def __init__(self, audio_output_file: str = '/data/audio/output.wav', audio_outp
audio_output_channels (int): 1 for mono, 2 for stereo.
sample_rate_hz (int): Sample rate of any outgoing audio (typically 16000, 44100, 48000)
"""
super().__init__(output_channels=0, **kwargs)
super().__init__(outputs=0, **kwargs)

self.pa = pyaudio.PyAudio()

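Both audio output plugins are sinks, which is why they now pass `outputs=0` to the base class in place of the old `output_channels=0` keyword. A usage sketch based on the constructor arguments documented above; the class name `AudioOutputDevice` and the device index are assumptions here, so verify them against the module before copying:

```python
from nano_llm.plugins.audio.audio_output import AudioOutputDevice  # class name assumed

# Play 16 kHz mono audio on a soundcard; the device index is a placeholder.
speaker = AudioOutputDevice(
    audio_output_device=0,     # PyAudio output device index (None = system default)
    audio_output_channels=1,   # 1 = mono, 2 = stereo
    sample_rate_hz=16000,      # sample rate of the audio being played
)
```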
20 changes: 15 additions & 5 deletions nano_llm/plugins/auto_prompt.py
@@ -15,7 +15,7 @@ class AutoPrompt(Plugin):
Apply prompting templates to incoming streams that form a kind of script.
For example, "<img> Describe the image" will insert the most recent image.
"""
def __init__(self, template : str = '<image> Describe the image concisely.', **kwargs):
def __init__(self, template : str = '<image> Describe the image concisely. <reset>', **kwargs):
"""
Apply prompting templates to incoming streams that form a kind of script.
For example, "<img> Describe the image" will insert the most recent image.
@@ -38,6 +38,9 @@ def template(self):

@template.setter
def template(self, template):
template = template.replace('/reset', '<reset>')
template = template.replace('/pop', '<pop>')

for tag, aliases in self.tags.items():
for alias in aliases:
template = template.replace(f"<{alias}>", f"<{tag}>")
@@ -85,13 +88,20 @@ def check_depth(tag):

for i, text in enumerate(template.split('<image>')):
if text:
msg.append(text)
cmd_split = text.split('<reset>')
if len(cmd_split) > 1:
for cmd_text in cmd_split:
if cmd_text:
msg.append(cmd_text)
msg.append('<reset>')
else:
msg.append(text)
if i < len(self.vars['image'].queue):
msg.append(self.vars['image'].queue[i])

from pprint import pprint
print('AUTOPROMPT')
pprint(msg, indent=2)
#from pprint import pprint
#print('AUTOPROMPT')
#pprint(msg, indent=2)

self.output(msg)
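The loop above interleaves template text with queued images, and now also emits `<reset>` as its own message so a downstream chat plugin can clear its history each cycle (matching the `<reset>`/`<clear>` commands handled in `ChatSession` below). A stand-alone re-creation of that splitting, with a placeholder string standing in for the real image queue:

```python
def expand_template_sketch(template: str, images: list) -> list:
    """Illustrative re-creation of the AutoPrompt splitting shown above (not the actual class)."""
    template = template.replace('/reset', '<reset>').replace('/pop', '<pop>')
    msg = []
    for i, text in enumerate(template.split('<image>')):
        if text:
            cmd_split = text.split('<reset>')
            if len(cmd_split) > 1:
                # keep the surrounding text and emit '<reset>' as its own message
                for cmd_text in cmd_split:
                    if cmd_text:
                        msg.append(cmd_text)
                msg.append('<reset>')
            else:
                msg.append(text)
        if i < len(images):
            msg.append(images[i])
    return msg

print(expand_template_sketch('<image> Describe the image concisely. <reset>', ['<frame-0>']))
# ['<frame-0>', ' Describe the image concisely. ', '<reset>']
```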

23 changes: 15 additions & 8 deletions nano_llm/plugins/chat_session.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
import os
import time
import logging

@@ -30,8 +31,8 @@ class ChatSession(Plugin):

def __init__(self, model : str = "princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT",
api : str = "mlc", quantization : str = "q4f16_ft",
max_context_len : int = None, chat_template : str = None,
system_prompt : str = None, **kwargs):
max_context_len : int = None, drop_inputs : bool = False,
chat_template : str = None, system_prompt : str = None, **kwargs):
"""
Load an LLM and run generation on chat requests.
@@ -40,10 +41,11 @@ def __init__(self, model : str = "princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT",
api (str): The model backend API to use: 'mlc', 'awq', or 'hf' (by default, it will attempt to be automatically determined)
quantization (str): For MLC, 'q4f16_ft', 'q4f16_1', 'q8f16_ft', 'q8f16_1'. For AWQ, the path to the fully-quantized AWQ weights.
max_context_len (str): The maximum chat length in tokens (by default, inherited from the model)
drop_inputs (bool): If true, only the latest message from the input queue will be used (older messages dropped)
chat_template (str|dict): The chat template (by default, will attempt to determine from model type)
system_prompt (str): Set the system prompt (changing this will reset the chat)
"""
super().__init__(outputs=['delta', 'partial', 'final', 'words', 'embed'], **kwargs)
super().__init__(outputs=['delta', 'partial', 'final', 'words', 'embed'], drop_inputs=drop_inputs, **kwargs)

load_time = time.perf_counter()

@@ -59,18 +61,20 @@ def __init__(self, model : str = "princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT",
self.model = model
self.model_name = self.config.name

self.history = ChatHistory(self.model, chat_template=chat_template, system_prompt=system_prompt, **kwargs)

self.functions = None
self.title = os.path.basename(self.model_name)
self.stream = None
self.functions = None

self.history = ChatHistory(self.model, chat_template=chat_template, system_prompt=system_prompt, **kwargs)

self.add_parameter('max_new_tokens', type=int, default=kwargs.get('max_new_tokens', 128), help="The number of tokens to output in addition to the prompt.")
self.add_parameter('min_new_tokens', type=int, default=kwargs.get('min_new_tokens', -1), help="Force the model to generate a set number of output tokens (<0 to disable)")
self.add_parameter('do_sample', type=bool, default=kwargs.get('do_sample', False), help="If true, temperature/top_p sampling will be used over the logits.")
self.add_parameter('temperature', type=float, default=kwargs.get('temperature', 0.7), help="Randomness token sampling parameter (only used if do_sample=true)")
self.add_parameter('top_p', type=float, default=kwargs.get('top_p', 0.95), help="If set to < 1 and do_sample=True, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept.")
self.add_parameter('repetition_penalty', type=float, default=kwargs.get('repetition_penalty', 1.0), help="The parameter for repetition penalty. 1.0 means no penalty")
self.add_parameter('system_prompt', name='System Prompt', default=system_prompt)
self.add_parameter('drop_inputs', default=drop_inputs)
self.add_parameter('system_prompt', default=system_prompt)

self.max_context_len = self.model.config.max_length
self.wrap_tokens = kwargs.get('wrap_tokens', 512)
@@ -117,6 +121,8 @@ def type_hints(cls):
"meta-llama/Meta-Llama-3-8B-Instruct",
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT",
"Efficient-Large-Model/VILA1.5-3b",
"Efficient-Large-Model/Llama-3-VILA1.5-8B",
]
},
'system_prompt': {
@@ -161,8 +167,9 @@ def process(self, input, **kwargs):
# handle some special commands
if isinstance(input, str):
x = input.lower()
if any([x == y for y in ('/reset', '/clear', 'reset', 'clear')]):
if any([x == y for y in ('/reset', '/clear', '<reset>', '<clear>')]):
self.history.reset()
self.send_state()
return

# add prompt to chat history
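For context, a hedged construction sketch using the defaults and parameters documented in this diff; `process()` is normally driven by the plugin's own queue thread, so calling it directly here is only for illustration:

```python
from nano_llm.plugins.chat_session import ChatSession

# Model, api, and quantization are the defaults shown above; drop_inputs=True keeps
# only the newest request when prompts arrive faster than the model can generate.
llm = ChatSession(
    model="princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT",
    api="mlc",
    quantization="q4f16_ft",
    drop_inputs=True,
    system_prompt="You are a concise assistant.",
)

llm.process("<reset>")                         # special command: clears the chat history
llm.process("What is the capital of France?")  # handled as a normal chat prompt
```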
42 changes: 33 additions & 9 deletions nano_llm/plugins/rate_limit.py
@@ -14,22 +14,31 @@ class RateLimit(Plugin):
It can also chunk indexable outputs into smaller amounts of data at a time.
"""
def __init__(self, rate : int = None, chunk : int = None, **kwargs):
def __init__(self, rate : float = None, chunk : int = None,
drop : bool = False, on_demand : bool = False, **kwargs):
"""
Rate limiter plugin with the ability to pause/resume from the queue.
Args:
rate (int): The number of items per second that can be transmitted.
rate (float): The number of items per second that can be transmitted (or the playback factor for audio)
chunk (int): For indexable inputs, the maximum number of items
that can be in each output message (if None, no chunking)
that can be in each output message.
drop (bool): If true, only the most recent inputs will be transmitted, with older inputs being dropped.
Otherwise, the queue will continue to grow and be throttled to the given rate.
on_demand (bool): If true, outputs will only be sent when the receiver's input queues
are empty and ready for more data. This will effectively rate limit to the
downstream processing speed.
"""
super().__init__(**kwargs)
super().__init__(outputs='items', drop_inputs=drop, **kwargs)

self.paused = -1

self.tx_rate = 0
self.last_time = time.perf_counter()

self.add_parameter('rate', default=rate)
self.add_parameter('chunk', default=chunk)

self.add_parameter('drop_inputs', name='Drop', default=drop, kwarg='drop')
self.add_parameter('on_demand', default=on_demand)

def process(self, input, sample_rate=None, **kwargs):
"""
@@ -49,29 +58,44 @@ def process(self, input, sample_rate=None, **kwargs):
time.sleep(pause_duration)
continue

if isinstance(self.rate, float) and sample_rate is not None:
if self.rate < 16 and sample_rate is not None:
rate = self.rate * sample_rate
else:
rate = self.rate

if self.chunk > 0:
if self.chunk is not None and self.chunk > 0:
#logging.debug(f"RateLimit chunk {len(input)} {self.chunk} {time.perf_counter()}")
if len(input) > self.chunk:
self.output(input[:self.chunk], sample_rate=sample_rate, **kwargs)
self.update_stats()
input = input[self.chunk:]
time.sleep(self.chunk/rate*0.95)
new=False
continue
else:
self.output(input, sample_rate=sample_rate, **kwargs)
self.update_stats()
time.sleep(len(input)/rate*0.95)
return
else:
self.output(input, sample_rate=sample_rate, **kwargs)
self.update_stats()
if self.rate > 0:
time.sleep(1.0/self.rate)
return


def update_stats(self):
"""
Calculate and send the throughput statistics when new outputs are transmitted.
"""
curr_time = time.perf_counter()
elapsed_time = curr_time - self.last_time
self.tx_rate = (self.tx_rate * 0.5) + ((1.0 / elapsed_time) * 0.5)
self.last_time = curr_time
self.send_stats(
summary=[f"{self.tx_rate:.1f} tx/sec"],
)

def pause(self, duration=None, until=None):
"""
Pause audio playback for `duration` number of seconds, or until the end time.
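Two details of the new behavior are worth spelling out: a `rate` below 16 is treated as a playback-speed factor and multiplied by the incoming `sample_rate`, and each transmitted chunk sleeps for roughly 95% of its real-time duration so playback stays slightly ahead of the clock. A stand-alone sketch of that pacing arithmetic, without the plugin queue or the `update_stats()` throughput average:

```python
def chunk_pacing_sketch(num_samples: int, sample_rate: int, rate: float, chunk: int):
    """Illustrate RateLimit pacing: rate < 16 acts as a playback factor on sample_rate."""
    effective_rate = rate * sample_rate if rate < 16 else rate
    sent = 0
    while sent < num_samples:
        n = min(chunk, num_samples - sent)
        sleep_s = n / effective_rate * 0.95  # sleep ~95% of the chunk's real-time duration
        print(f"emit {n} samples, then sleep {sleep_s:.3f}s")
        sent += n

# One second of 16 kHz audio played back at 1x speed in 4800-sample chunks:
chunk_pacing_sketch(num_samples=16000, sample_rate=16000, rate=1.0, chunk=4800)
```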