trigger do_sample automatically based on temperature for huggingface … #80

Merged · 1 commit · Jul 21, 2024
7 changes: 5 additions & 2 deletions alfred/fm/huggingface.py
@@ -130,7 +130,9 @@ def __init__(
 
         if torch.cuda.is_available():
             n_gpus = torch.cuda.device_count()
-            free_in_GB = sum([int(mem / 1024**3) for mem in torch.cuda.mem_get_info()])
+            free_in_GB = sum(
+                [int(mem / 1024**3) for mem in torch.cuda.mem_get_info()]
+            )
 
             logger.log(
                 logging.INFO, f"Found {n_gpus} GPUs with {free_in_GB}GB free GPU memory"
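Incidentally, `torch.cuda.mem_get_info()` returns a `(free_bytes, total_bytes)` tuple for a single device, so summing over the tuple (as the line being re-wrapped here does) folds total memory into the logged "free" figure and only looks at the current GPU. A minimal sketch that aggregates just the free portion across all visible devices (hypothetical helper, not part of this PR):

```python
import torch

def total_free_gpu_memory_gb() -> int:
    # torch.cuda.mem_get_info(device) -> (free_bytes, total_bytes) for one
    # device; keep only the free portion and sum it over every visible GPU.
    free_gb = 0
    for device in range(torch.cuda.device_count()):
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        free_gb += int(free_bytes / 1024**3)
    return free_gb
```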
@@ -398,9 +400,10 @@ def _generate_batch(
             outputs = self.model.generate(
                 inputs.input_ids.to(self.model.device),
                 max_new_tokens=max_new_tokens,
-                temperature=temprature,
+                temperature=temprature if temprature != 0 else None,
                 repetition_penalty=repetition_penalty,
                 return_dict_in_generate=True,
+                do_sample=temprature != 0,
             )
         else:
             outputs = [
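This hunk is the substance of the PR: recent versions of `transformers` warn when a non-default `temperature` is passed with `do_sample=False`, and `temperature=0` is not a valid sampling temperature, so the wrapper now treats `0` as greedy decoding and anything else as sampling. (The `temprature` spelling is the existing parameter name in the codebase.) A minimal sketch of the same convention as a standalone helper (hypothetical name, not part of the PR):

```python
from typing import Any, Dict

def generation_kwargs(temperature: float) -> Dict[str, Any]:
    # Mirrors the diff above: temperature == 0 -> greedy decoding,
    # anything else -> sampling at that temperature. Passing None rather
    # than 0 keeps transformers from flagging an out-of-range temperature
    # on the greedy path.
    do_sample = temperature != 0
    return {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
    }

# outputs = model.generate(input_ids, **generation_kwargs(0.0))  # greedy
# outputs = model.generate(input_ids, **generation_kwargs(0.7))  # sampling
```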
10 changes: 8 additions & 2 deletions alfred/fm/utils.py
@@ -526,13 +526,15 @@ def _process_batch(batch):
     return batches
 
 
-def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]:
     """
     Static Batching Utility
     Batch queries into fixed size batches
 
     :param queries: A list of queries to be batched
-    :type queries: List[Query]
+    :type queries: Union[Query, str]
     :param batch_sz: The batch size
     :type batch_sz: int
     :return: A list of batches
@@ -548,6 +550,10 @@ def static_batch(queries: Query, batch_size: int = 1024) -> List[List[Query]]:
             _q = query.load()[0]
         elif isinstance(query, RankedQuery):
             _q = query.prompt
+        elif isinstance(query, str):
+            _q = query
+        else:
+            print(f"Unknown query type {type(query)}")
         batch.append(_q)
     if len(batch) > 0:
         batches.append(batch)
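A usage sketch for the widened signature (hypothetical prompts; assumes the elided middle of the loop flushes `batch` each time it reaches `batch_size`, as the docstring describes). Note that the new `else` branch only prints and never assigns `_q`, so an unrecognized query type would append a stale value or raise `NameError`; raising a `TypeError` there would fail faster.

```python
from alfred.fm.utils import static_batch

# Plain strings are now accepted alongside Query objects.
prompts = [f"Summarize document {i}." for i in range(1300)]

batches = static_batch(prompts, batch_size=512)
# Expected chunking: 512 + 512 + 276 = 1300 queries.
assert [len(b) for b in batches] == [512, 512, 276]
```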
10 changes: 5 additions & 5 deletions docs/alfred/fm/huggingface.md
@@ -47,7 +47,7 @@ class HuggingFaceModel(LocalAccessFoundationModel):
 
 ### HuggingFaceModel()._encode_batch
 
-[Show source in huggingface.py:438](../../../alfred/fm/huggingface.py#L438)
+[Show source in huggingface.py:441](../../../alfred/fm/huggingface.py#L441)
 
 Encode given batch of instances.
 
@@ -71,7 +71,7 @@ def _encode_batch(self, batch_instance, **kwargs) -> List[torch.Tensor]: ...
 
 ### HuggingFaceModel()._generate_batch
 
-[Show source in huggingface.py:348](../../../alfred/fm/huggingface.py#L348)
+[Show source in huggingface.py:350](../../../alfred/fm/huggingface.py#L350)
 
 Generate completions for a batch of prompts using the model.
 
@@ -114,7 +114,7 @@ def _generate_batch(
 
 ### HuggingFaceModel()._get_hidden_states
 
-[Show source in huggingface.py:173](../../../alfred/fm/huggingface.py#L173)
+[Show source in huggingface.py:175](../../../alfred/fm/huggingface.py#L175)
 
 Get the hidden states of the inputs.
 For encoder-decoder models (e.g.) T5, this returns the encoder hidden states.
@@ -140,7 +140,7 @@ def _get_hidden_states(self, inputs, reduction="mean") -> torch.Tensor: ...
 
 ### HuggingFaceModel()._score_batch
 
-[Show source in huggingface.py:212](../../../alfred/fm/huggingface.py#L212)
+[Show source in huggingface.py:214](../../../alfred/fm/huggingface.py#L214)
 
 Score a batch of prompts and candidates using the model.
 
@@ -180,7 +180,7 @@ def _score_batch(
 
 ### HuggingFaceModel().chat
 
-[Show source in huggingface.py:464](../../../alfred/fm/huggingface.py#L464)
+[Show source in huggingface.py:467](../../../alfred/fm/huggingface.py#L467)
 
 Launch an interactive chat session
 
6 changes: 4 additions & 2 deletions docs/alfred/fm/utils.md
@@ -380,7 +380,7 @@ Batch queries into fixed size batches
 #### Arguments
 
 - `queries` - A list of queries to be batched
-  :type queries: List[Query]
+  :type queries: Union[Query, str]
 - `batch_sz` - The batch size
   :type batch_sz: int
 
@@ -392,7 +392,9 @@ Type: *List[List[Query]]*
 #### Signature
 
 ```python
-def static_batch(queries: Query, batch_sz: int = 1024) -> List[List[Query]]: ...
+def static_batch(
+    queries: Union[Query, str], batch_size: int = 512
+) -> List[List[Query]]: ...
 ```

