
Add decoder custom modeling for inference based on NxD #840


Draft · wants to merge 12 commits into main

Conversation

@dacorvo (Collaborator) commented Apr 30, 2025

What does this PR do?

This adds support for the export and inference of decoder models on top of neuronx-distributed.

For now, only two model architectures are supported: llama and mixtral.

Note that the existing custom modeling for llama on top of TnX is still chosen by default when using NeuronModelForCausalLM.

To export or instantiate a llama model on top of NxD instead, either:

  • use LlamaNxDModelForCausalLM directly,
  • export OPTIMUM_PRIORITIZE_NXD_BACKEND=1, then use NeuronModelForCausalLM as usual (see the sketch below).
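
For illustration, a minimal usage sketch of the second option (the checkpoint name and the compilation arguments are indicative only and may differ from what this PR finally accepts):

```python
import os

# Force the NxD backend before using the generic factory class.
os.environ["OPTIMUM_PRIORITIZE_NXD_BACKEND"] = "1"

from optimum.neuron import NeuronModelForCausalLM

# Export a llama checkpoint with the NxD backend.
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # any supported llama checkpoint
    export=True,
    batch_size=1,
    sequence_length=2048,
)
model.save_pretrained("llama-nxd")
```

The first option would instead import LlamaNxDModelForCausalLM directly; its exact import path is not shown here.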

The basic features are all implemented:

  • export/save/reload,
  • generation with greedy search or multinomial sampling (see the sketch below),
  • transparent local cache,
  • transparent hub cache.
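
A sketch of reload and generation, assuming the artifacts exported above and the standard transformers generation API:

```python
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM

# Reload the exported artifacts from disk (path from the previous sketch).
model = NeuronModelForCausalLM.from_pretrained("llama-nxd")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Greedy decoding.
greedy = model.generate(**inputs, do_sample=False, max_new_tokens=32)

# Multinomial sampling.
sampled = model.generate(**inputs, do_sample=True, temperature=0.8, max_new_tokens=32)

print(tokenizer.batch_decode(greedy, skip_special_tokens=True))
print(tokenizer.batch_decode(sampled, skip_special_tokens=True))
```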

A cool new feature has also been added: assisted/speculative decoding (see the sketch below).
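
A sketch of what this could look like, assuming the interface mirrors transformers' assisted generation (an assistant_model argument passed to generate); the draft checkpoint path is hypothetical and the actual entry point introduced here may differ:

```python
# A small draft model proposes tokens that the larger target model verifies.
draft = NeuronModelForCausalLM.from_pretrained("llama-draft-nxd")  # hypothetical draft export
target = NeuronModelForCausalLM.from_pretrained("llama-nxd")

outputs = target.generate(
    **inputs,  # tokenized prompt, as in the previous sketch
    assistant_model=draft,
    max_new_tokens=64,
)
```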

What's missing:

  • support for continuous batching,
  • integration in TGI,
  • several modeling optimizations are not working (yet): they will be fixed in individual pull requests or removed.

Performance:

  • inference time seems in line with the numbers obtained using the direct HLO modeling,
  • device memory usage is much higher, quickly leading to saturation as the batch size increases. This might be because some outputs/cached values are stored at a higher precision.

dacorvo added 12 commits April 17, 2025 12:04
It is not strictly necessary to use the GenerationMixin to implement
the generate method, so it is moved to the child class that uses it.
This removes obsolete or redundant attributes and adds explanations
about AutoModel registration for pipelines.
This makes it possible to force the NeuronModelForCausalLM factory methods to
export/load a specific model with the NxD backend even if an
equivalent HLO model exists.
The generate_neff method from torch_neuronx used by the ModelBuilder
class does not support caching.
This patches the method, replacing it with a wrapper that uses the
caching mechanism implemented in the libneuronxla package (see the sketch below).
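
A rough sketch of that wrapper pattern (the cache key and storage here are naive in-memory stand-ins; the real implementation relies on the caching mechanism from libneuronxla):

```python
import functools
import hashlib
import pickle

_neff_cache = {}  # stand-in for the persistent libneuronxla cache


def with_neff_cache(generate_neff):
    """Wrap a generate_neff-like callable so identical compilations are reused."""

    @functools.wraps(generate_neff)
    def wrapper(*args, **kwargs):
        # Hypothetical key: a hash of the (picklable) call arguments.
        key = hashlib.sha256(pickle.dumps((args, sorted(kwargs.items())))).hexdigest()
        if key not in _neff_cache:
            _neff_cache[key] = generate_neff(*args, **kwargs)
        return _neff_cache[key]

    return wrapper

# The patch then boils down to replacing torch_neuronx's generate_neff
# (as used by ModelBuilder) with with_neff_cache(generate_neff).
```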
@dacorvo requested a review from tengomucho on April 30, 2025 13:58
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

NEURON_CONFIG_FILE = "neuron_config.json"


def to_torch_dtype(dtype_str: str) -> torch.dtype:
Collaborator commented:

There are so many more or less identical mappings in Optimum Neuron; couldn't we put them in a utils module that everyone could leverage?
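
For reference, the kind of mapping being discussed, as a sketch (the exact set of accepted dtype strings in the PR may differ):

```python
import torch

_DTYPE_MAP = {
    "float32": torch.float32,
    "fp32": torch.float32,
    "float16": torch.float16,
    "fp16": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
}


def to_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype string (e.g. read from neuron_config.json) to a torch.dtype."""
    try:
        return _DTYPE_MAP[dtype_str]
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {dtype_str}") from None
```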



@register_neuron_config
class NxDNeuronConfig(NeuronConfig):
Collaborator commented:

Is it the same as the NeuronConfig of NxDI? Do we need all of the following features?

SHARDED_KERNEL = 2


class NeuronAttentionBase(nn.Module):
Collaborator commented:

Maybe add a comment referring to the corresponding part in NxDI, so we won't forget.

_traced_qkv_kernel = nki_jit()(rmsnorm_qkv_isa_kernel)


class GQA(enum.Enum):
Collaborator commented:

same here

weight_cache = {}


def _get_weight_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch.Tensor:
Collaborator commented:

same

import torch


def generate_buckets(min_length: int, max_length: int):
Collaborator commented:

same
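
For context, a common bucketing scheme (powers of two between the bounds); only a sketch, the strategy actually used here may differ:

```python
import math


def generate_buckets(min_length: int, max_length: int):
    """Return sorted power-of-two bucket sizes covering [min_length, max_length]."""
    if min_length >= max_length:
        return [max_length]
    buckets = []
    size = 2 ** math.ceil(math.log2(min_length))
    while size < max_length:
        buckets.append(size)
        size *= 2
    buckets.append(max_length)
    return buckets


# generate_buckets(128, 2048) -> [128, 256, 512, 1024, 2048]
```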

Comment on lines +3 to +12
Flash decoding supports long context inference by reducing KV cache memory. This is done by sharding and distributing
cache storage instead of replicating it on multiple devices (cores).

Flash decoding lives in context of GQA (group query attention). This means it is a feature on top of GQA and not
traditional MHA (multi-head attention). In GQA we replicate the KV cache in the devices within the same KV group.
Now instead of replicating, we shard the KV and distribute them in each device of the group. To accommodate this setup,
we modify the attention computation as below:
1) Gather all query heads in the group,
2) Compute partial softmax on each device,
3) Reduce-scatter in the end to get the complete result.
Collaborator commented:

Suggested change
Flash decoding supports long context inference by reducing KV cache memory. This is done by sharding and distributing
cache storage instead of replicating it on multiple devices (cores).
Flash decoding lives in context of GQA (group query attention). This means it is a feature on top of GQA and not
traditional MHA (multi-head attention). In GQA we replicate the KV cache in the devices within the same KV group.
Now instead of replicating, we shard the KV and distribute them in each device of the group. To accommodate this setup,
we modify the attention computation as below:
1) Gather all query heads in the group,
2) Compute partial softmax on each device,
3) Reduce-scatter in the end to get the complete result.
Flash decoding supports long context inference by reducing KV cache memory. This is done by sharding and distributing
cache storage instead of replicating it on multiple devices (cores).
Flash decoding lives in the context of GQA (group query attention). This means it is a feature on top of GQA and not
traditional MHA (multi-head attention). In GQA, we replicate the KV cache in the devices within the same KV group.
Now, instead of replicating, we shard the KV and distribute them to each device in the group. To accommodate this setup, we modify the attention computation as follows:
1) Gather all query heads in the group,
2) Compute partial softmax on each device,
3) Reduce-scatter in the end to get the complete result.
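
For intuition, the partial softmax plus reduction described in steps 2 and 3 can be checked numerically on a single device (no collectives here; the real kernel shards the KV cache across devices and combines the partial results with a reduce-scatter):

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 64)    # one query head
k = torch.randn(128, 64)  # full key cache
v = torch.randn(128, 64)  # full value cache

# Reference: standard attention over the full KV cache.
scores = q @ k.T / 64**0.5
ref = torch.softmax(scores, dim=-1) @ v

# Sharded version: per-shard (unnormalized) softmax statistics, then a combine step.
outputs, maxes, sums = [], [], []
for ks, vs in zip(k.chunk(4), v.chunk(4)):
    s = q @ ks.T / 64**0.5
    m = s.max(dim=-1, keepdim=True).values   # per-shard max, for numerical stability
    p = torch.exp(s - m)
    outputs.append(p @ vs)                   # unnormalized partial output
    maxes.append(m)
    sums.append(p.sum(dim=-1, keepdim=True))

m_all = torch.stack(maxes).max(dim=0).values  # global max across shards
scales = [torch.exp(m - m_all) for m in maxes]
num = sum(o * s for o, s in zip(outputs, scales))
den = sum(l * s for l, s in zip(sums, scales))

print(torch.allclose(num / den, ref, atol=1e-5))  # True
```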

Directory to store the cache. If not provided, a default directory will be used.
"""

def generate_neff_with_cache(
Collaborator commented:

That's awesome, did not think of overriding this. Will check how to use it for other traced models :D
