@@ -80,9 +80,6 @@ pub trait DType:
80
80
fn spowf ( self , rhs : Self ) -> Self ;
81
81
fn rem_euclid ( self , rhs : Self ) -> Self ;
82
82
83
- fn sqrt ( self ) -> Self ;
84
- fn cbrt ( self ) -> Self ;
85
-
86
83
fn abs ( self ) -> Self ;
87
84
fn trunc ( self ) -> Self ;
88
85
fn max ( self , other : Self ) -> Self ;
@@ -94,6 +91,19 @@ pub trait DType:
94
91
fn to_radians ( self ) -> Self ;
95
92
fn atan2 ( self , rhs : Self ) -> Self ;
96
93
94
/// Provided square root: raises `self` to the power 1/2 through `powf`.
/// The exponent is written as a float literal expression and converted with `to_dt()`.
/// Implementors may override this; `impl_float` supplies its own `sqrt`.
/// NOTE(review): results for negative inputs follow whatever `powf` does for the implementing type — confirm.
+ fn sqrt ( self ) -> Self {
95
+ self . powf ( ( 1.0 / 2.0 ) . to_dt ( ) )
96
+ }
97
/// Provided cube root: raises `self` to the power 1/3 through `powf`.
/// NOTE(review): the `impl_float` override of `cbrt` is commented out, so float types appear to use this default.
/// Presumably `powf` yields NaN for negative inputs there, unlike a native cube root — confirm, and prefer `scbrt`
/// where negative values can occur.
+ fn cbrt ( self ) -> Self {
98
+ self . powf ( ( 1.0 / 3.0 ) . to_dt ( ) )
99
+ }
100
/// Square root computed with `spowf` (exponent 1/2) instead of `powf`.
/// NOTE(review): `spowf` is only declared in this trait, its body is not visible here; presumably it is a
/// sign-preserving power — confirm before relying on the result for negative inputs.
+ fn ssqrt ( self ) -> Self {
101
+ self . spowf ( ( 1.0 / 2.0 ) . to_dt ( ) )
102
+ }
103
/// Cube root computed with `spowf` (exponent 1/3) instead of `powf`.
/// The Oklab conversion in this file applies it to each LMS component produced with `OKLAB_M1`.
/// NOTE(review): `spowf` presumably preserves the sign of `self`, which would make this safe for negative
/// components — confirm against the `spowf` implementations.
+ fn scbrt ( self ) -> Self {
104
+ self . spowf ( ( 1.0 / 3.0 ) . to_dt ( ) )
105
+ }
106
+
97
107
fn _fma ( self , mul : Self , add : Self ) -> Self ;
98
108
/// Fused multiply-add if "fma" is enabled in rustc
99
109
fn fma ( self , mul : Self , add : Self ) -> Self {
@@ -121,12 +131,6 @@ macro_rules! impl_float {
121
131
// Forwards to the `rem_euclid` method of the type `impl_float` is expanded for.
// NOTE(review): relies on `self.rem_euclid(rhs)` resolving to the type's own method rather than this trait method — confirm.
fn rem_euclid( self , rhs: Self ) -> Self {
122
132
self . rem_euclid( rhs)
123
133
}
124
- fn sqrt( self ) -> Self {
125
- self . sqrt( )
126
- }
127
- fn cbrt( self ) -> Self {
128
- self . cbrt( )
129
- }
130
134
// Forwards to the `abs` method of the type `impl_float` is expanded for.
// NOTE(review): relies on `self.abs()` resolving to the type's own method rather than this trait method — confirm.
fn abs( self ) -> Self {
131
135
self . abs( )
132
136
}
@@ -154,6 +158,13 @@ macro_rules! impl_float {
154
158
// Forwards to the `atan2` method of the type `impl_float` is expanded for, with `self` as receiver and `rhs` as argument.
// NOTE(review): relies on `self.atan2(rhs)` resolving to the type's own method rather than this trait method — confirm.
fn atan2( self , rhs: Self ) -> Self {
155
159
self . atan2( rhs)
156
160
}
161
// Overrides the trait's `powf`-based `sqrt` default by forwarding to the `sqrt` of the type this macro is expanded for.
// NOTE(review): relies on `self.sqrt()` resolving to the type's own method rather than back to this trait method —
// confirm for every type passed to `impl_float`.
+ fn sqrt( self ) -> Self {
162
+ self . sqrt( )
163
+ }
164
+ // Native cbrt override left disabled so the trait default is used: it seemed ~50% slower than powf/spowf (unconfirmed).
165
+ //fn cbrt(self) -> Self {
166
+ // self.cbrt()
167
+ //}
157
168
// Backing implementation for the trait's `_fma`: forwards to `mul_add(mul, add)` on the type `impl_float` is expanded for.
// NOTE(review): for the standard float types `mul_add` computes `self * mul + add` with a single rounding — confirm that
// only such types are passed to the macro.
fn _fma( self , mul: Self , add: Self ) -> Self {
158
169
self . mul_add( mul, add)
159
170
}
@@ -1133,7 +1144,7 @@ where
1133
1144
Channels < N > : ValidChannels ,
1134
1145
{
1135
1146
let mut lms = matmul3t ( [ pixel[ 0 ] , pixel[ 1 ] , pixel[ 2 ] ] , OKLAB_M1 ) ;
1136
- lms. iter_mut ( ) . for_each ( |c| * c = c. cbrt ( ) ) ;
1147
+ lms. iter_mut ( ) . for_each ( |c| * c = c. scbrt ( ) ) ;
1137
1148
[ pixel[ 0 ] , pixel[ 1 ] , pixel[ 2 ] ] = matmul3t ( lms, OKLAB_M2 ) ;
1138
1149
}
1139
1150
0 commit comments