Commit 6b40e52 (parent: 62d3793)

Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Connor Holmes <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>

Showing 132 changed files with 5,496 additions and 1,034 deletions.
@@ -0,0 +1,2 @@
# MII Examples
Please see [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/mii) for a few examples on using MII.
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .ragged_batching import MIIAsyncPipeline, MIIPipeline
File renamed without changes.
@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import abc
from typing import List, Optional

import torch
import torch.nn.functional as F

FLOAT_PAD = -float("inf")


class BaseLogitProcessor(abc.ABC):
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return self.forward(logits)

    @abc.abstractmethod
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        ...

    def get_key(self) -> str:
        return self.__class__.__name__


class TopKLogitProcessor(BaseLogitProcessor):
    def __init__(self, top_k: int) -> None:
        self.top_k = top_k

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, self.top_k)[0][..., -1, None]
        logits[indices_to_remove] = FLOAT_PAD
        return logits

    def get_key(self) -> str:
        return super().get_key() + f"_top_k={self.top_k}"


class TopPLogitProcessor(BaseLogitProcessor):
    def __init__(self, top_p: float) -> None:
        assert 0.0 <= top_p <= 1.0
        self.top_p = top_p

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Sort logits in descending order along the vocab dimension
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > self.top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = FLOAT_PAD
        return logits

    def get_key(self) -> str:
        return super().get_key() + f"_top_p={self.top_p}"


class TemperatureLogitProcessor(BaseLogitProcessor):
    def __init__(self, temperature: float) -> None:
        self.temperature = temperature
        assert self.temperature > 0.0

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature

    def get_key(self) -> str:
        return super().get_key() + f"_temperature={self.temperature}"


class PipelineLogitProcessor(BaseLogitProcessor):
    def __init__(self, pipeline: List[BaseLogitProcessor]) -> None:
        assert all(isinstance(step, BaseLogitProcessor) for step in pipeline)
        self.pipeline = pipeline

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        for step in self.pipeline:
            logits = step(logits)
        return logits

    def get_key(self) -> str:
        return super().get_key() + f"_{'_'.join(step.get_key() for step in self.pipeline)}"


class NucleusSamplingLogitProcessor(BaseLogitProcessor):
    def __init__(self,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = None) -> None:
        assert top_k is not None or top_p is not None
        if top_k is None:
            self._processor = TopPLogitProcessor(top_p)
        elif top_p is None:
            self._processor = TopKLogitProcessor(top_k)
        else:
            self._processor = PipelineLogitProcessor(
                [TopKLogitProcessor(top_k), TopPLogitProcessor(top_p)])

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return self._processor(logits)

    def get_key(self) -> str:
        return super().get_key() + f"_{self._processor.get_key()}"
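The processors above compose through `PipelineLogitProcessor`, and `NucleusSamplingLogitProcessor` wires top-k and top-p filtering together on its own. A minimal usage sketch follows; the import path `logit_processor` is an assumption, since the file name is not visible in this extract:

```python
import torch

# Hypothetical import path; the real module name is not shown in this diff extract.
from logit_processor import (NucleusSamplingLogitProcessor,
                             PipelineLogitProcessor,
                             TemperatureLogitProcessor)

# Toy batch of 2 sequences over a 10-token vocabulary.
logits = torch.randn(2, 10)

# Temperature scaling followed by combined top-k / top-p filtering.
processor = PipelineLogitProcessor([
    TemperatureLogitProcessor(temperature=0.8),
    NucleusSamplingLogitProcessor(top_k=5, top_p=0.9),
])

filtered = processor(logits.clone())  # masked-out entries become -inf
print(processor.get_key())            # key string encoding the full pipeline
```

Note that the top-k and top-p processors write `FLOAT_PAD` into the tensor they receive, so callers that still need the raw logits should pass a copy, as in the sketch.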
@@ -0,0 +1,57 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import abc
from typing import Tuple

import torch
from torch.distributions import Categorical


class BaseGenerationSampler(abc.ABC):
    @abc.abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Given the logits, return the next token to add to the sequence, as
        well as the log probability of the token.

        Args:
            logits (torch.Tensor):
                The logits from the model. Shape: (batch_size, vocab_size)

        Returns:
            Tuple[torch.LongTensor, torch.Tensor]:
                The next token to add to the sequence, and the log probability
                of the token. Shape: (batch_size,) and (batch_size,)
        """
        ...

    def get_key(self) -> str:
        return self.__class__.__name__


class LogitsSampler(BaseGenerationSampler):
    def __call__(
        self,
        logits: torch.Tensor,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        logits = logits.float()
        sampler = Categorical(logits=logits)
        next_tokens = sampler.sample()
        logprobs = sampler.log_prob(next_tokens)
        return next_tokens, logprobs


class GreedySampler(BaseGenerationSampler):
    def __call__(self, logits: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
        logits = logits.float()
        sampler = Categorical(logits=logits)
        next_tokens = logits.argmax(dim=-1)
        logprobs = sampler.log_prob(next_tokens)
        return next_tokens, logprobs
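Both samplers follow the `__call__` contract documented above: logits in, `(next_tokens, logprobs)` out. A quick sketch, again assuming a hypothetical module name `samplers` for this file:

```python
import torch

# Hypothetical import path; the real module name is not shown in this diff extract.
from samplers import GreedySampler, LogitsSampler

logits = torch.randn(4, 32000)  # (batch_size, vocab_size)

stochastic = LogitsSampler()
tokens, logprobs = stochastic(logits)            # sampled ids, both shaped (4,)

greedy = GreedySampler()
argmax_tokens, argmax_logprobs = greedy(logits)  # argmax ids, both shaped (4,)
```

A small design note: `GreedySampler` still builds a `Categorical` distribution, but only to recover the log probability of the argmax token; the token choice itself is deterministic.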
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import abc
from typing import List, Union

import torch

# from megatron import get_tokenizer
# from megatron.tokenizer.tokenizer import AbstractTokenizer


class BaseGenerationStopCriterion(abc.ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        return self.forward(tokens)

    @abc.abstractmethod
    def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        ...

    def get_key(self) -> str:
        return self.__class__.__name__


class TokenStopCriterion(BaseGenerationStopCriterion):
    def __init__(self, token: Union[str, int], tokenizer) -> None:
        super().__init__(tokenizer=tokenizer)
        if isinstance(token, str):
            token_id = self.tokenizer.tokenize(token)[0]
        else:
            token_id = token
        self.stop_token_id = token_id

    def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        retval = torch.zeros_like(tokens, dtype=torch.bool)
        retval |= tokens == self.stop_token_id
        return retval

    def get_key(self) -> str:
        return self.__class__.__name__ + f"_token_id={self.stop_token_id}"


class EosGenerationStopCriterion(BaseGenerationStopCriterion):
    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer)
        if hasattr(self.tokenizer, "eod"):
            self.eos_id = self.tokenizer.eod
        elif hasattr(self.tokenizer, "eos_token_id"):
            self.eos_id = self.tokenizer.eos_token_id
        elif hasattr(self.tokenizer, "eos_token"):
            self.eos_id = self.tokenizer.eos_token
        else:
            raise ValueError(
                "Tokenizer must have an `eod`, `eos_token_id`, or `eos_token` attribute.")

    def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        return tokens == self.eos_id


class NewLineDelimitedStopCriterion(BaseGenerationStopCriterion):
    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer)
        self.stop_token_ids = list(
            set([self.tokenizer.tokenize(x)[0] for x in ["\n", "\r\n", "\n\n", ".\n\n"]]))

    def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        retval = torch.zeros_like(tokens, dtype=torch.bool)
        for stop_token_id in self.stop_token_ids:
            retval |= tokens == stop_token_id
        return retval


class PipelinedCriterion(BaseGenerationStopCriterion):
    def __init__(
        self,
        criteria: List[BaseGenerationStopCriterion],
        tokenizer,
    ):
        super().__init__(tokenizer=tokenizer)
        self.criteria = criteria

    def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
        retval = torch.zeros_like(tokens, dtype=torch.bool)
        for criterion in self.criteria:
            retval |= criterion(tokens)
        return retval

    def get_key(self) -> str:
        return super().get_key() + f"_{'_'.join(criterion.get_key() for criterion in self.criteria)}"