diff --git a/demo/inference.py b/demo/inference.py index efff86a7..4ff05bab 100644 --- a/demo/inference.py +++ b/demo/inference.py @@ -223,7 +223,7 @@ def main(): if cfg.TEST.MODEL_FILE: print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) - pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location=CTX), strict=False) else: print('expected model defined in config at TEST.MODEL_FILE')