Skip to content

Commit 6350725

Browse files
authored
simpler leaky_relu (tinygrad#9271)
rendered as `*(data0+alu0) = ((val0<0.0f)?(0.01f*val0):val0);` instead of two wheres. possible to update rewrite rules too
1 parent 86b737a commit 6350725

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tinygrad/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3091,7 +3091,7 @@ def leaky_relu(self, neg_slope=0.01):
30913091
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
30923092
```
30933093
"""
3094-
return self.relu() - (-neg_slope*self).relu()
3094+
return (self<0).where(neg_slope*self, self)
30953095

30963096
def mish(self):
30973097
"""

0 commit comments

Comments
 (0)