diff --git a/tllib/modules/grl.py b/tllib/modules/grl.py index fb07e915..0ea7d154 100644 --- a/tllib/modules/grl.py +++ b/tllib/modules/grl.py @@ -68,7 +68,7 @@ def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: def forward(self, input: torch.Tensor) -> torch.Tensor: """""" - coeff = np.float( + coeff = float( 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) - (self.hi - self.lo) + self.lo )