I have searched the YOLOv6 issues and found no similar feature requests.
Description
YOLOv6 currently uses deprecated PyTorch API calls for automatic mixed precision (AMP) training. This causes deprecation warnings when using PyTorch 2.0 or newer.
When training with PyTorch 2.0+, the following warnings appear:
FutureWarning: torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.
FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
Proposed Solution
Update the following code patterns throughout the codebase:
Replace torch.cuda.amp.GradScaler() with torch.amp.GradScaler('cuda')
Replace torch.cuda.amp.autocast() with torch.amp.autocast('cuda')
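This change preserves the existing mixed-precision behavior while eliminating the deprecation warnings on newer PyTorch releases. Because torch.amp.GradScaler is not available in older PyTorch releases, the project's minimum supported PyTorch version should be checked before switching.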
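Because both replacements are textual substitutions, they can be applied mechanically. The following is a minimal sketch of such a rewrite, not the change that was actually merged; migrate_amp_calls and migrate_file are hypothetical helper names. It only handles fully qualified torch.cuda.amp.GradScaler(...) and torch.cuda.amp.autocast(...) calls, and it assumes existing arguments are passed by keyword: torch.cuda.amp.autocast takes enabled as its first parameter while torch.amp.autocast takes device_type then dtype, so positional arguments to autocast need manual review.

```python
import re
from pathlib import Path

# Fully qualified deprecated calls. Group 2 captures an immediately closing ")"
# so that zero-argument calls become X('cuda') rather than X('cuda', ).
_DEPRECATED_CALL = re.compile(r"torch\.cuda\.amp\.(GradScaler|autocast)\((\s*\))?")


def migrate_amp_calls(source: str) -> str:
    """Rewrite torch.cuda.amp.GradScaler/autocast calls to torch.amp with 'cuda'."""
    def _rewrite(match: re.Match) -> str:
        name = match.group(1)
        if match.group(2):  # torch.cuda.amp.X() with no arguments
            return f"torch.amp.{name}('cuda')"
        # Existing arguments follow the inserted device type unchanged; this is
        # only safe for keyword arguments (see the note on autocast above).
        return f"torch.amp.{name}('cuda', "

    return _DEPRECATED_CALL.sub(_rewrite, source)


def migrate_file(path: Path) -> bool:
    """Apply migrate_amp_calls to one file in place; return True if it changed."""
    original = path.read_text(encoding="utf-8")
    updated = migrate_amp_calls(original)
    if updated != original:
        path.write_text(updated, encoding="utf-8")
    return updated != original
```

Running migrate_file(Path(name)) for each file listed below and then reviewing the resulting diff keeps the change easy to audit. Already-migrated torch.amp calls do not match the pattern, so a second run leaves files unchanged.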
Files that need changes
yolov6/core/engine.py
yolov6/models/losses/loss.py
I'm happy to contribute a PR for this change if desired.
Use case
No response
Additional
No response
Are you willing to submit a PR?
Yes I'd like to help by submitting a PR!
…lity meituan#1077
- Replace torch.cuda.amp with torch.amp
- Add device specification ('cuda') to amp.autocast and GradScaler
- Update imports to use torch.amp directly
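The last commit bullet covers imports. If a file uses the module alias "from torch.cuda import amp" (an assumption here; the actual imports in the affected files should be checked), its calls look like amp.autocast(...) and a rewrite keyed on the fully qualified name will miss them. A hedged sketch of handling that alias style follows; migrate_amp_alias is a hypothetical helper name, and as above it assumes keyword arguments.

```python
import re

# Assumed import style: "from torch.cuda import amp" (possibly indented).
_CUDA_AMP_IMPORT = re.compile(r"^([ \t]*)from torch\.cuda import amp\b", re.MULTILINE)
# amp.GradScaler( / amp.autocast( not preceded by a word character or a dot,
# so torch.amp.autocast(...) and names like other_amp.autocast(...) are skipped.
_ALIAS_CALL = re.compile(r"(?<![\w.])amp\.(GradScaler|autocast)\((\s*\))?")


def migrate_amp_alias(source: str) -> str:
    """Switch 'from torch.cuda import amp' to 'from torch import amp' and add 'cuda'."""
    if not _CUDA_AMP_IMPORT.search(source):
        # "amp" may be unrelated (e.g. NVIDIA Apex) or already migrated; this
        # guard also makes a second run a no-op.
        return source
    source = _CUDA_AMP_IMPORT.sub(r"\1from torch import amp", source)

    def _rewrite(match: re.Match) -> str:
        name = match.group(1)
        if match.group(2):  # amp.X() with no arguments
            return f"amp.{name}('cuda')"
        return f"amp.{name}('cuda', "  # existing keyword arguments follow

    return _ALIAS_CALL.sub(_rewrite, source)
```

Gating the call rewrite on the import keeps the sketch from touching an unrelated object that happens to be named amp.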
Search before asking
Description
YOLOv6 currently uses deprecated PyTorch API calls for automatic mixed precision (AMP) training. This causes deprecation warnings when using PyTorch 2.0 or newer.
When training with PyTorch 2.0+, the following warnings appear:
FutureWarning: torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.
FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
Proposed Solution
Update the following code patterns throughout the codebase:
Replace torch.cuda.amp.GradScaler() with torch.amp.GradScaler('cuda')
Replace torch.cuda.amp.autocast() with torch.amp.autocast('cuda')
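This change preserves the existing mixed-precision behavior while eliminating the deprecation warnings on newer PyTorch releases. Because torch.amp.GradScaler is not available in older PyTorch releases, the project's minimum supported PyTorch version should be checked before switching.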
Files that need changes
I'm happy to contribute a PR for this change if desired.
Use case
No response
Additional
No response
Are you willing to submit a PR?