Add decoder custom modeling for inference based on NxD #840
base: main
Conversation
It is not strictly necessary to use the GenerationMixin to implement the generate method, so it is moved to the child class that uses it.
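For illustration, a minimal sketch of that layout (class names here are hypothetical, not the ones introduced by this PR):

```python
from transformers import GenerationMixin


class NeuronDecoderBase:
    """Base class: holds the exported model and the forward pass, but no generate()."""

    def forward(self, input_ids, **kwargs):
        ...


class NeuronDecoderForCausalLM(NeuronDecoderBase, GenerationMixin):
    """Only the child class that actually exposes generation pulls in
    GenerationMixin, which layers generate() on top of forward()."""
```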
This removes obsolete or redundant attributes and adds explanations about AutoModel registration for pipelines.
This makes it possible to force the NeuronModelForCausalLM factory methods to export/load a specific model with the NxD backend even if an equivalent HLO model exists.
The generate_neff method from torch_neuronx that is used by the ModelBuilder class does not support caching. This patches the method, replacing it with a wrapper that uses the caching mechanism implemented in the libneuronxla package.
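This is not the PR's implementation, but a minimal sketch of the wrapping pattern it describes, assuming for illustration that the compile entry point takes and returns raw bytes and substituting a plain local directory for the libneuronxla cache:

```python
import hashlib
import os


def patch_generate_neff(module, cache_dir):
    """Illustrative wrapper: replace a compile entry point (here a hypothetical
    `module.generate_neff`) with a version that reuses previously compiled NEFFs."""
    original = module.generate_neff

    def generate_neff_with_cache(hlo_bytes, *args, **kwargs):
        key = hashlib.sha256(hlo_bytes).hexdigest()
        path = os.path.join(cache_dir, f"{key}.neff")
        if os.path.exists(path):  # cache hit: skip compilation entirely
            with open(path, "rb") as f:
                return f.read()
        neff = original(hlo_bytes, *args, **kwargs)  # cache miss: compile as usual
        os.makedirs(cache_dir, exist_ok=True)
        with open(path, "wb") as f:
            f.write(neff)
        return neff

    module.generate_neff = generate_neff_with_cache
```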
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
NEURON_CONFIG_FILE = "neuron_config.json"

def to_torch_dtype(dtype_str: str) -> torch.dtype:
There are so many mappings that are more or less the same in Optimum Neuron, couldn't we put them in a utils module that everyone could leverage?
optimum-neuron/optimum/neuron/models/inference/hlo/backend/dtypes.py (lines 19 to 23 in a8b9035):

    def to_torch_dtype(dtype):
        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,

optimum-neuron/optimum/neuron/utils/misc.py (lines 626 to 629 in a8b9035):

    def map_torch_dtype(dtype: Union[str, torch.dtype]):
        dtype_mapping = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,

and elsewhere:

    mapper = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
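As a sketch of what such a shared helper could look like (names and placement hypothetical; the existing map_torch_dtype in optimum/neuron/utils/misc.py already covers part of this):

```python
from typing import Union

import torch

# Hypothetical shared utility: one alias table for both directions, so the HLO
# backend, the NxD backend and utils/misc.py can all import the same mapping.
_STR_TO_DTYPE = {
    "fp32": torch.float32, "float32": torch.float32,
    "fp16": torch.float16, "float16": torch.float16,
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
}
_DTYPE_TO_STR = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}


def map_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Accept either a string alias or an actual torch.dtype."""
    if isinstance(dtype, torch.dtype):
        return dtype
    try:
        return _STR_TO_DTYPE[dtype]
    except KeyError:
        raise ValueError(f"Unknown dtype alias: {dtype!r}")


def dtype_to_str(dtype: torch.dtype) -> str:
    """Reverse mapping, e.g. for serializing neuron configs."""
    return _DTYPE_TO_STR[dtype]
```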
@register_neuron_config
class NxDNeuronConfig(NeuronConfig):
Is it the same as the NeuronConfig of NxDI? Do we need all of the following features?
SHARDED_KERNEL = 2

class NeuronAttentionBase(nn.Module):
Maybe add a comment referring to the corresponding part in NxDI, so we won't forget.
_traced_qkv_kernel = nki_jit()(rmsnorm_qkv_isa_kernel)

class GQA(enum.Enum):
same here
weight_cache = {}

def _get_weight_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch.Tensor:
same
import torch

def generate_buckets(min_length: int, max_length: int):
same
Flash decoding supports long context inference by reducing KV cache memory. This is done by sharding and distributing
cache storage instead of replicating it on multiple devices (cores).

Flash decoding lives in context of GQA (group query attention). This means it is a feature on top of GQA and not
traditional MHA (multi-head attention). In GQA we replicate the KV cache in the devices within the same KV group.
Now instead of replicating, we shard the KV and distribute them in each device of the group. To accommodate this setup,
we modify the attention computation as below:
1) Gather all query heads in the group,
2) Compute partial softmax on each device,
3) Reduce-scatter in the end to get the complete result.
Suggested change:

Flash decoding supports long context inference by reducing KV cache memory. This is done by sharding and distributing
cache storage instead of replicating it on multiple devices (cores).
Flash decoding lives in the context of GQA (group query attention). This means it is a feature on top of GQA and not
traditional MHA (multi-head attention). In GQA, we replicate the KV cache in the devices within the same KV group.
Now, instead of replicating, we shard the KV and distribute them to each device in the group. To accommodate this setup, we modify the attention computation as follows:
1) Gather all query heads in the group,
2) Compute partial softmax on each device,
3) Reduce-scatter in the end to get the complete result.
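To make steps 2) and 3) concrete, here is a small single-process sketch of the math being distributed: each KV shard produces a partial softmax (running max, normalizer and unnormalized output), and the combine that a reduce-scatter would perform across devices is emulated in-process. This is illustrative only and is not the kernel used in the PR.

```python
import torch


def partial_attention(q, k_shard, v_shard):
    # Per-shard statistics: running max, softmax normalizer and unnormalized output.
    scores = q @ k_shard.transpose(-1, -2) / q.shape[-1] ** 0.5  # [heads, 1, shard_len]
    m = scores.max(dim=-1, keepdim=True).values                  # [heads, 1, 1]
    p = torch.exp(scores - m)
    l = p.sum(dim=-1, keepdim=True)                              # [heads, 1, 1]
    o = p @ v_shard                                              # [heads, 1, head_dim]
    return m, l, o


heads, head_dim, seq_len, shards = 4, 8, 32, 2
q = torch.randn(heads, 1, head_dim)
k = torch.randn(heads, seq_len, head_dim)
v = torch.randn(heads, seq_len, head_dim)

# Each "device" holds one KV shard and computes a partial softmax.
partials = [partial_attention(q, ks, vs)
            for ks, vs in zip(k.chunk(shards, dim=1), v.chunk(shards, dim=1))]

# The combine step that the cross-device collective would perform.
m_all = torch.stack([m for m, _, _ in partials]).max(dim=0).values
l_all = sum(l * torch.exp(m - m_all) for m, l, _ in partials)
o_all = sum(o * torch.exp(m - m_all) for m, _, o in partials) / l_all

# Reference: attention over the full (unsharded) KV cache.
ref = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1) @ v
assert torch.allclose(o_all, ref, atol=1e-5)
```

The log-sum-exp style combine is what makes sharding the KV cache numerically equivalent to attention over the full cache.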
Directory to store the cache. If not provided, a default directory will be used.
"""

def generate_neff_with_cache(
That's awesome, did not think of overriding this. Will check how to use it for other traced models :D
What does this PR do?
This adds support for the export and inference of decoder models on top of neuronx-distributed.
For now only two model architectures are supported: llama and mixtral.
Note that the existing custom modeling for llama on top of TnX is still chosen by default when using NeuronModelForCausalLM.
To export or instantiate a llama model on top of NxD instead, either:
- use LlamaNxDModelForCausalLM directly,
- or force the NxD backend in the NeuronModelForCausalLM factory methods (see the sketch below).
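A minimal usage sketch of those two options; the exact export arguments of LlamaNxDModelForCausalLM and the mechanism for forcing the NxD backend on NeuronModelForCausalLM are assumptions here, not confirmed by this excerpt:

```python
from optimum.neuron import NeuronModelForCausalLM
# Hypothetical import path for the NxD-specific class added by this PR:
from optimum.neuron import LlamaNxDModelForCausalLM

# Option 1 (sketch): use the NxD class directly; export arguments assumed to
# mirror the existing NeuronModelForCausalLM export flow.
model_nxd = LlamaNxDModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    export=True,
    batch_size=1,
    sequence_length=4096,
)

# Option 2 (sketch): go through the generic factory and force the NxD backend.
# The actual flag/mechanism is defined by this PR and not shown in the excerpt
# above, so treat this keyword as a placeholder.
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    export=True,
    backend="nxd",  # placeholder name
    batch_size=1,
    sequence_length=4096,
)
```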
The basic features are all implemented:
A new cool feature has been added: assisted/speculative decoding.
What's missing:
Performance: