Reply:
@apbose is there a way to create a hard subset of complex that we can support easily and grow from there?
Complex number handling in Torch-TensorRT
TL;DR
This RFC proposes adding complex number support to Torch-TensorRT. TensorRT does not support complex numbers natively, but rotary positional embeddings rely on complex arithmetic, so complex numbers play an important role in how these embeddings are applied.
Goal
To support the multi-GPU Llama 3 example running end to end.
Use case
With this feature we intend to demonstrate the end-to-end forward pass of a Torch-TensorRT-compiled, distributed Llama 3 model on multiple GPUs. The following illustrates how complex numbers become inputs to the Llama 3 model: the query and key vectors are viewed as complex tensors, while the frequency vectors are precomputed in polar form as complex frequencies.
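For reference, below is a condensed version of how the reference Llama implementation applies rotary embeddings through complex arithmetic (the broadcast reshaping of freqs_cis is omitted; freqs_cis is assumed to be precomputed with torch.polar and already broadcastable to the reshaped query/key):

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the trailing head dimension as (real, imag) pairs -> complex64 tensors.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis holds e^{i*theta} terms (polar form); the complex multiply rotates each pair.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```

It is this freqs_cis tensor, together with the complex views of the query and key, that ends up as a complex-valued input to the compiled graph.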
We encounter this only in the distributed examples because the model is compiled with torch.compile(distributed_model, backend="torch_tensorrt"). When the model is wrapped by aot_autograd, the distributed tensors are hoisted to graph inputs, which leads to complex inputs to the Torch-TensorRT-compiled graph (ref: pytorch/pytorch#136289).
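A minimal sketch of that compile path is shown below. The tiny stand-in module and the absence of any distributed setup are simplifications; in the actual use case the module is the tensor-parallel Llama 3 model whose complex freqs_cis buffer gets hoisted to a graph input.

```python
import torch
import torch.nn as nn
import torch_tensorrt  # noqa: F401  # importing torch_tensorrt makes the "torch_tensorrt" backend available

# Stand-in for the distributed Llama 3 module used only to illustrate the call pattern.
model = nn.Linear(16, 16).cuda().eval()

compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(torch.randn(4, 16, device="cuda"))
```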
Implementation Stages
Complex unpacking
Convert complex numbers into a tuple of their real and imaginary parts: a complex number x + iy is provided as the input pair (x, y).
This involves modifying the metadata (shape and data type) of the complex nodes, as well as the subsequent operations that consume these complex values.
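A minimal sketch of the unpacking itself, using torch.view_as_real; the comment about node metadata refers to the FX node's meta["val"] entry:

```python
import torch

x = torch.randn(2, 8, dtype=torch.complex64)  # x = a + i*b

# Reinterpret the complex tensor as real/imag pairs without a copy:
# (2, 8) complex64  ->  (2, 8, 2) float32
x_unpacked = torch.view_as_real(x)
real, imag = x_unpacked.unbind(-1)

assert x_unpacked.dtype == torch.float32 and x_unpacked.shape == (2, 8, 2)

# The lowering pass must apply the same shape/dtype change to the complex
# placeholder's metadata (e.g. node.meta["val"]) and to every op consuming it,
# so that downstream shape propagation sees the real-valued layout.
```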
Numeric truncation
In the above, complex64 should be unpacked into a pair of float32 tensors. Similarly, complex128 should also end up as a pair of float32 tensors: it unpacks to float64, which then has to be truncated to float32 using the truncate flag, since TensorRT does not support FP64.
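A small hypothetical helper illustrating the intended dtype mapping; the truncate parameter stands in for the existing truncation setting and the name is illustrative:

```python
import torch

# complex64  -> float32 pair (lossless)
# complex128 -> float64 pair, truncated to float32 when the truncate flag is set,
#               since TensorRT does not support FP64.
def unpacked_real_dtype(complex_dtype: torch.dtype, truncate: bool) -> torch.dtype:
    real_dtype = {torch.complex64: torch.float32,
                  torch.complex128: torch.float64}[complex_dtype]
    if truncate and real_dtype is torch.float64:
        real_dtype = torch.float32
    return real_dtype
```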
Function signature modification
Identify the boundary of the operations affected by the complex inputs. In the Llama 3 model this boundary is the rotary embedding operation (see the sketch below).
The signatures of these complex operations need to be modified so that there are no graph breaks and so that they also handle the complex unpacking.
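As an illustration of what the rewritten operations inside that boundary can look like, the complex multiply of the rotary embedding can be expressed with real-valued ops only. This is a sketch, not the actual pass output:

```python
import torch

def complex_mul_unpacked(a_re: torch.Tensor, a_im: torch.Tensor,
                         b_re: torch.Tensor, b_im: torch.Tensor):
    # (a_re + i*a_im) * (b_re + i*b_im) = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
    # Only elementwise real ops remain, all of which TensorRT supports.
    return a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re
```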
Unification of pre_lowering and post_lowering passes for distributed and non-distributed cases
The pre_lowering and post_lowering passes need to be uniform across both the distributed and non-distributed cases.
[Diagram]
In addition, there has to be extra handling in the Torch-TensorRT runtime. All of the above will be invoked via an API in the post-lowering passes; a sketch of what such a unified pass mechanism could look like follows.
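The names below are illustrative, not existing Torch-TensorRT APIs; the sketch only shows the idea of one shared, ordered pass list so that distributed and non-distributed compilation run through identical lowering:

```python
from typing import Callable, List
from torch.fx import GraphModule

LoweringPass = Callable[[GraphModule], GraphModule]

# A single registry used by both the distributed and non-distributed code paths.
POST_LOWERING_PASSES: List[LoweringPass] = []

def register_post_lowering_pass(p: LoweringPass) -> LoweringPass:
    POST_LOWERING_PASSES.append(p)
    return p

def run_post_lowering(gm: GraphModule) -> GraphModule:
    for p in POST_LOWERING_PASSES:
        gm = p(gm)
    return gm
```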
API changes
Detection stage
torch_tensorrt/dynamo/lowering/passes/pass_utils.py
torch_tensorrt/dynamo/utils.py
Decomposition stage
torch_tensorrt/dynamo/lowering/passes/reshape_complex_placeholder_nodes
Graph Rewrite stage
torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite
All of the above need to be called sequentially in
torch_tensorrt/dynamo/backend/backends.py; a rough sketch of this wiring is shown below.
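In this sketch the detection helper is runnable, while the two pass calls are stubs standing in for the modules listed above; all names are assumptions for illustration:

```python
import torch
from torch.fx import GraphModule, Node

# Detection stage (sketch): find placeholders carrying complex fake tensors.
def detect_complex_placeholders(gm: GraphModule) -> list[Node]:
    return [n for n in gm.graph.nodes
            if n.op == "placeholder"
            and (val := n.meta.get("val")) is not None
            and isinstance(val, torch.Tensor)
            and val.dtype.is_complex]

# Stubs standing in for the decomposition and graph-rewrite passes named above.
def reshape_complex_placeholder_nodes(gm: GraphModule, nodes: list[Node]) -> GraphModule:
    return gm

def complex_graph_rewrite(gm: GraphModule, nodes: list[Node]) -> GraphModule:
    return gm

def lower_complex(gm: GraphModule) -> GraphModule:
    complex_nodes = detect_complex_placeholders(gm)
    if complex_nodes:
        gm = reshape_complex_placeholder_nodes(gm, complex_nodes)
        gm = complex_graph_rewrite(gm, complex_nodes)
    return gm
```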
Still to be explored are the changes to the runtimes in _PythonTorchTensorRTModule.py and _TorchTensorRTModule.py, since we are modifying the inputs.
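On the runtime side, one option is a pre-processing step in the module's forward() so callers can keep passing complex tensors; the helper below is a hypothetical sketch, not the existing runtime API:

```python
import torch

def split_complex_inputs(inputs: tuple) -> list:
    # Split any complex input into contiguous (real, imag) tensors, matching the
    # unpacked signature the rewritten TensorRT graph now expects.
    flat = []
    for t in inputs:
        if isinstance(t, torch.Tensor) and t.is_complex():
            as_real = torch.view_as_real(t)
            flat.extend((as_real[..., 0].contiguous(), as_real[..., 1].contiguous()))
        else:
            flat.append(t)
    return flat
```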